Allow setting a default_concurrency_limit other than 1 (#6439)

* changes

* format

* add changeset

* add changeset

* fix

* format

* fix

* add test

* change

* update test

* moved to queue()

* typo

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-11-16 12:42:49 -08:00 committed by GitHub
parent 179f5bcde1
commit a1e3c61f41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 65 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Allow setting a `default_concurrency_limit` other than 1

View File

@ -232,6 +232,8 @@ Previously, in Gradio 3.x, there was a single global `concurrency_count` paramet
In Gradio 4.0, the `concurrency_count` parameter has been removed. You can still control the number of total threads by using the `max_threads` parameter. The default value of this parameter is `40`, but you don't have worry (as much) about OOM errors, because even though there are 40 threads, we use a single-worker-single-event model, which means each worker thread only executes a specific function. So effectively, each function has its own "concurrency count" of 1. If you'd like to change this behavior, you can do so by setting a parameter `concurrency_limit`, which is now a parameter of *each event*, not a global parameter. By default this is `1` for each event, but you can set it to a higher value, or to `None` if you'd like to allow an arbitrary number of executions of this event simultaneously. Events can also be grouped together using the `concurrency_id` parameter so that they share the same limit, and by default, events that call the same function share the same `concurrency_id`.
Lastly, it should be noted that the default value of the `concurrency_limit` of all events in a Blocks (which is normally 1) can be changed using the `default_concurrency_limit` parameter in `Blocks.queue()`. You can set this to a higher integer or to `None`. This in turn sets the `concurrency_limit` of all events that don't have an explicit `conurrency_limit` specified.
To summarize migration:
* For events that execute quickly or don't use much CPU or GPU resources, you should set `concurrency_limit=None` in Gradio 4.0. (Previously you would set `queue=False`.)

View File

@ -366,7 +366,7 @@ class BlockFunction:
inputs_as_dict: bool,
batch: bool = False,
max_batch_size: int = 4,
concurrency_limit: int | None = 1,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
tracks_progress: bool = False,
):
@ -592,6 +592,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.output_components = None
self.__name__ = None
self.api_mode = None
self.progress_tracking = None
self.ssl_verify = True
@ -822,7 +823,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
trigger_after: int | None = None,
trigger_only_on_success: bool = False,
trigger_mode: Literal["once", "multiple", "always_last"] | None = "once",
concurrency_limit: int | None = 1,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
) -> tuple[dict[str, Any], int]:
"""
@ -848,7 +849,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
trigger_after: if set, this event will be triggered after 'trigger_after' function index
trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set)
trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Extra events triggered by this listener will be queued. On Spaces, this is set to 1 by default.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
Returns: dependency information, dependency index
"""
@ -1649,6 +1650,8 @@ Received outputs:
api_open: bool | None = None,
max_size: int | None = None,
concurrency_count: int | None = None,
*,
default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
):
"""
By enabling the queue you can control when users know their position in the queue, and set a limit on maximum number of events allowed.
@ -1657,6 +1660,7 @@ Received outputs:
api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited.
concurrency_count: Deprecated and has no effect. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via `max_threads` in launch().
default_concurrency_limit: The default value of `concurrency_limit` to use for event listeners that don't specify a value. Can be set by environment variable GRADIO_DEFAULT_CONCURRENCY_LIMIT. Defaults to 1 if not set otherwise.
Example: (Blocks)
with gr.Blocks() as demo:
button = gr.Button(label="Generate Image")
@ -1682,6 +1686,7 @@ Received outputs:
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
max_size=max_size,
block_fns=self.fns,
default_concurrency_limit=default_concurrency_limit,
)
self.config = self.get_config_file()
self.app = routes.App.create_app(self)

View File

@ -207,7 +207,7 @@ class EventListener(str):
every: float | None = None,
trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
js: str | None = None,
concurrency_limit: int | None = 1,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
) -> Dependency:
"""
@ -227,7 +227,7 @@ class EventListener(str):
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete.
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
"""
@ -351,7 +351,7 @@ def on(
cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
every: float | None = None,
js: str | None = None,
concurrency_limit: int | None = 1,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
) -> Dependency:
"""
@ -371,7 +371,7 @@ def on(
cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish.
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs', return should be a list of values for output components.
concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
"""
from gradio.components.base import Component

View File

@ -114,7 +114,7 @@ class Interface(Blocks):
api_name: str | Literal[False] | None = "predict",
_api_mode: bool = False,
allow_duplication: bool = False,
concurrency_limit: int | None = 1,
concurrency_limit: int | None | Literal["default"] = "default",
**kwargs,
):
"""
@ -141,7 +141,7 @@ class Interface(Blocks):
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None, the name of the prediction function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event.
allow_duplication: If True, then will show a 'Duplicate Spaces' button on Hugging Face Spaces.
concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which itself is 1 by default).
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -312,7 +312,7 @@ class Interface(Blocks):
self.batch = batch
self.max_batch_size = max_batch_size
self.allow_duplication = allow_duplication
self.concurrency_limit = concurrency_limit
self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
self.share = None
self.share_url = None

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import copy
import json
import os
import time
import traceback
import uuid
@ -78,6 +79,7 @@ class Queue:
update_intervals: float,
max_size: int | None,
block_fns: list[BlockFunction],
default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
):
self.event_queue: list[Event] = []
self.awaiting_data_events: dict[str, Event] = {}
@ -99,20 +101,27 @@ class Queue:
self.block_fns = block_fns
self.continuous_tasks: list[Event] = []
self._asyncio_tasks: list[asyncio.Task] = []
self.default_concurrency_limit = self._resolve_concurrency_limit(
default_concurrency_limit
)
self.concurrency_limit_per_concurrency_id = {}
def start(self):
self.active_jobs = [None] * self.max_thread_count
for block_fn in self.block_fns:
if block_fn.concurrency_limit is not None:
concurrency_limit = (
self.default_concurrency_limit
if block_fn.concurrency_limit == "default"
else block_fn.concurrency_limit
)
if concurrency_limit is not None:
self.concurrency_limit_per_concurrency_id[
block_fn.concurrency_id
] = min(
self.concurrency_limit_per_concurrency_id.get(
block_fn.concurrency_id, block_fn.concurrency_limit
block_fn.concurrency_id, concurrency_limit
),
block_fn.concurrency_limit,
concurrency_limit,
)
run_coro_in_background(self.start_processing)
@ -123,6 +132,26 @@ class Queue:
def close(self):
self.stopped = True
def _resolve_concurrency_limit(self, default_concurrency_limit):
"""
Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination
of the `default_concurrency_limit` parameter of the `Blocks.queue()` or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT`
environment variable. The parameter in `Blocks.queue()` takes precedence over the environment variable.
Parameters:
default_concurrency_limit: The default concurrency limit, as specified by a user in `Blocks.queu()`.
"""
if default_concurrency_limit != "not_set":
return default_concurrency_limit
if default_concurrency_limit_env := os.environ.get(
"GRADIO_DEFAULT_CONCURRENCY_LIMIT"
):
if default_concurrency_limit_env.lower() == "none":
return None
else:
return int(default_concurrency_limit_env)
else:
return 1
def attach_data(self, body: PredictBody):
event_id = body.event_id
if event_id in self.awaiting_data_events:

View File

@ -1,13 +1,14 @@
import time
import gradio_client as grc
import pytest
from fastapi.testclient import TestClient
import gradio as gr
class TestQueueing:
def test_single_request(self):
def test_single_request(self, connect):
with gr.Blocks() as demo:
name = gr.Textbox()
output = gr.Textbox()
@ -19,12 +20,11 @@ class TestQueueing:
demo.launch(prevent_thread_lock=True)
client = grc.Client(f"http://localhost:{demo.server_port}")
job = client.submit("x", fn_index=0)
with connect(demo) as client:
job = client.submit("x", fn_index=0)
assert job.result() == "Hello, x!"
assert job.result() == "Hello, x!"
def test_all_status_messages(self):
def test_all_status_messages(self, connect):
with gr.Blocks() as demo:
name = gr.Textbox()
output = gr.Textbox()
@ -35,10 +35,10 @@ class TestQueueing:
name.submit(greet, name, output, concurrency_limit=2)
app, _, _ = demo.launch(prevent_thread_lock=True)
app, local_url, _ = demo.launch(prevent_thread_lock=True)
test_client = TestClient(app)
client = grc.Client(local_url)
client = grc.Client(f"http://localhost:{demo.server_port}")
client.submit("a", fn_index=0)
job2 = client.submit("b", fn_index=0)
client.submit("c", fn_index=0)
@ -70,7 +70,44 @@ class TestQueueing:
assert job2.result() == "Hello, b!"
assert job4.result() == "Hello, d!"
def test_concurrency_limits(self):
@pytest.mark.parametrize(
"default_concurrency_limit, statuses",
[
("not_set", ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
(None, ["PROCESSING", "PROCESSING", "PROCESSING"]),
(1, ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
(2, ["IN_QUEUE", "PROCESSING", "PROCESSING"]),
],
)
def test_default_concurrency_limits(self, default_concurrency_limit, statuses):
with gr.Blocks() as demo:
a = gr.Number()
b = gr.Number()
output = gr.Number()
add_btn = gr.Button("Add")
@add_btn.click(inputs=[a, b], outputs=output)
def add(x, y):
time.sleep(2)
return x + y
demo.queue(default_concurrency_limit=default_concurrency_limit)
_, local_url, _ = demo.launch(
prevent_thread_lock=True,
)
client = grc.Client(local_url)
add_job_1 = client.submit(1, 1, fn_index=0)
add_job_2 = client.submit(1, 1, fn_index=0)
add_job_3 = client.submit(1, 1, fn_index=0)
time.sleep(1)
add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
assert sorted([s.code.value for s in add_job_statuses]) == statuses
def test_concurrency_limits(self, connect):
with gr.Blocks() as demo:
a = gr.Number()
b = gr.Number()
@ -80,14 +117,14 @@ class TestQueueing:
@add_btn.click(inputs=[a, b], outputs=output, concurrency_limit=2)
def add(x, y):
time.sleep(4)
time.sleep(2)
return x + y
sub_btn = gr.Button("Subtract")
@sub_btn.click(inputs=[a, b], outputs=output, concurrency_limit=None)
def sub(x, y):
time.sleep(4)
time.sleep(2)
return x - y
mul_btn = gr.Button("Multiply")
@ -99,7 +136,7 @@ class TestQueueing:
concurrency_id="muldiv",
)
def mul(x, y):
time.sleep(4)
time.sleep(2)
return x * y
div_btn = gr.Button("Divide")
@ -111,49 +148,55 @@ class TestQueueing:
concurrency_id="muldiv",
)
def div(x, y):
time.sleep(4)
time.sleep(2)
return x / y
app, _, _ = demo.launch(prevent_thread_lock=True)
with connect(demo) as client:
add_job_1 = client.submit(1, 1, fn_index=0)
add_job_2 = client.submit(1, 1, fn_index=0)
add_job_3 = client.submit(1, 1, fn_index=0)
sub_job_1 = client.submit(1, 1, fn_index=1)
sub_job_2 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
mul_job_1 = client.submit(1, 1, fn_index=2)
div_job_1 = client.submit(1, 1, fn_index=3)
mul_job_2 = client.submit(1, 1, fn_index=2)
client = grc.Client(f"http://localhost:{demo.server_port}")
add_job_1 = client.submit(1, 1, fn_index=0)
add_job_2 = client.submit(1, 1, fn_index=0)
add_job_3 = client.submit(1, 1, fn_index=0)
sub_job_1 = client.submit(1, 1, fn_index=1)
sub_job_2 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
mul_job_1 = client.submit(1, 1, fn_index=2)
div_job_1 = client.submit(1, 1, fn_index=3)
mul_job_2 = client.submit(1, 1, fn_index=2)
time.sleep(1)
time.sleep(2)
add_job_statuses = [
add_job_1.status(),
add_job_2.status(),
add_job_3.status(),
]
assert sorted([s.code.value for s in add_job_statuses]) == [
"IN_QUEUE",
"PROCESSING",
"PROCESSING",
]
add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
assert sorted([s.code.value for s in add_job_statuses]) == [
"IN_QUEUE",
"PROCESSING",
"PROCESSING",
]
sub_job_statuses = [
sub_job_1.status(),
sub_job_2.status(),
sub_job_3.status(),
]
assert [s.code.value for s in sub_job_statuses] == [
"PROCESSING",
"PROCESSING",
"PROCESSING",
]
sub_job_statuses = [sub_job_1.status(), sub_job_2.status(), sub_job_3.status()]
assert [s.code.value for s in sub_job_statuses] == [
"PROCESSING",
"PROCESSING",
"PROCESSING",
]
muldiv_job_statuses = [
mul_job_1.status(),
div_job_1.status(),
mul_job_2.status(),
]
assert sorted([s.code.value for s in muldiv_job_statuses]) == [
"IN_QUEUE",
"PROCESSING",
"PROCESSING",
]
muldiv_job_statuses = [
mul_job_1.status(),
div_job_1.status(),
mul_job_2.status(),
]
assert sorted([s.code.value for s in muldiv_job_statuses]) == [
"IN_QUEUE",
"PROCESSING",
"PROCESSING",
]
def test_every_does_not_block_queue(self):
with gr.Blocks() as demo:
@ -162,10 +205,10 @@ class TestQueueing:
num.submit(lambda n: 2 * n, num, num, every=0.5)
num2.submit(lambda n: 3 * n, num, num)
app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
app, local_url, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
test_client = TestClient(app)
client = grc.Client(f"http://localhost:{demo.server_port}")
client = grc.Client(local_url)
job = client.submit(1, fn_index=1)
for _ in range(5):