gradio/test/test_queueing.py
aliabid94 79c8156ebb
Queue concurrency count (#6172)
* changes

* add changeset

* changes

* changes

* changes

* Update client/python/test/conftest.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2023-10-31 02:41:19 -07:00

177 lines
5.3 KiB
Python

import time
import gradio_client as grc
from fastapi.testclient import TestClient
import gradio as gr
class TestQueueing:
def test_single_request(self):
with gr.Blocks() as demo:
name = gr.Textbox()
output = gr.Textbox()
def greet(x):
return f"Hello, {x}!"
name.submit(greet, name, output)
demo.launch(prevent_thread_lock=True)
client = grc.Client(f"http://localhost:{demo.server_port}")
job = client.submit("x", fn_index=0)
assert job.result() == "Hello, x!"
def test_all_status_messages(self):
with gr.Blocks() as demo:
name = gr.Textbox()
output = gr.Textbox()
def greet(x):
time.sleep(2)
return f"Hello, {x}!"
name.submit(greet, name, output, concurrency_limit=2)
app, _, _ = demo.launch(prevent_thread_lock=True)
test_client = TestClient(app)
client = grc.Client(f"http://localhost:{demo.server_port}")
client.submit("a", fn_index=0)
job2 = client.submit("b", fn_index=0)
client.submit("c", fn_index=0)
job4 = client.submit("d", fn_index=0)
sizes = []
while job4.status().code.value != "FINISHED":
queue_status = test_client.get("/queue/status").json()
queue_size = queue_status["queue_size"]
if len(sizes) == 0 or queue_size != sizes[-1]:
sizes.append(queue_size)
time.sleep(0.01)
time.sleep(0.1)
queue_status = test_client.get("/queue/status").json()
queue_size = queue_status["queue_size"]
if queue_size != sizes[-1]:
sizes.append(queue_size)
assert max(sizes) in [
2,
3,
4,
] # Can be 2 - 4, depending on if the workers have picked up jobs before the queue status is checked
assert min(sizes) == 0
assert sizes[-1] == 0
assert job2.result() == "Hello, b!"
assert job4.result() == "Hello, d!"
def test_concurrency_limits(self):
with gr.Blocks() as demo:
a = gr.Number()
b = gr.Number()
output = gr.Number()
add_btn = gr.Button("Add")
@add_btn.click(inputs=[a, b], outputs=output, concurrency_limit=2)
def add(x, y):
time.sleep(4)
return x + y
sub_btn = gr.Button("Subtract")
@sub_btn.click(inputs=[a, b], outputs=output, concurrency_limit=None)
def sub(x, y):
time.sleep(4)
return x - y
mul_btn = gr.Button("Multiply")
@mul_btn.click(
inputs=[a, b],
outputs=output,
concurrency_limit=2,
concurrency_id="muldiv",
)
def mul(x, y):
time.sleep(4)
return x * y
div_btn = gr.Button("Divide")
@div_btn.click(
inputs=[a, b],
outputs=output,
concurrency_limit=2,
concurrency_id="muldiv",
)
def div(x, y):
time.sleep(4)
return x / y
app, _, _ = demo.launch(prevent_thread_lock=True)
client = grc.Client(f"http://localhost:{demo.server_port}")
add_job_1 = client.submit(1, 1, fn_index=0)
add_job_2 = client.submit(1, 1, fn_index=0)
add_job_3 = client.submit(1, 1, fn_index=0)
sub_job_1 = client.submit(1, 1, fn_index=1)
sub_job_2 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
mul_job_1 = client.submit(1, 1, fn_index=2)
div_job_1 = client.submit(1, 1, fn_index=3)
mul_job_2 = client.submit(1, 1, fn_index=2)
time.sleep(2)
add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
assert sorted([s.code.value for s in add_job_statuses]) == [
"IN_QUEUE",
"PROCESSING",
"PROCESSING",
]
sub_job_statuses = [sub_job_1.status(), sub_job_2.status(), sub_job_3.status()]
assert [s.code.value for s in sub_job_statuses] == [
"PROCESSING",
"PROCESSING",
"PROCESSING",
]
muldiv_job_statuses = [
mul_job_1.status(),
div_job_1.status(),
mul_job_2.status(),
]
assert sorted([s.code.value for s in muldiv_job_statuses]) == [
"IN_QUEUE",
"PROCESSING",
"PROCESSING",
]
def test_every_does_not_block_queue(self):
with gr.Blocks() as demo:
num = gr.Number(value=0)
num2 = gr.Number(value=0)
num.submit(lambda n: 2 * n, num, num, every=0.5)
num2.submit(lambda n: 3 * n, num, num)
app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
test_client = TestClient(app)
client = grc.Client(f"http://localhost:{demo.server_port}")
job = client.submit(1, fn_index=1)
for _ in range(5):
status = test_client.get("/queue/status").json()
assert status["queue_size"] == 0
time.sleep(0.5)
assert job.result() == 3