From a1e3c61f41b16166656b46254a201b37abcf20a8 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 16 Nov 2023 12:42:49 -0800 Subject: [PATCH] Allow setting a `default_concurrency_limit` other than 1 (#6439) * changes * format * add changeset * add changeset * fix * format * fix * add test * change * update test * moved to queue() * typo --------- Co-authored-by: gradio-pr-bot --- .changeset/dirty-guests-suffer.md | 5 ++ CHANGELOG.md | 2 + gradio/blocks.py | 11 ++- gradio/events.py | 8 +- gradio/interface.py | 6 +- gradio/queueing.py | 37 +++++++- test/test_queueing.py | 145 +++++++++++++++++++----------- 7 files changed, 149 insertions(+), 65 deletions(-) create mode 100644 .changeset/dirty-guests-suffer.md diff --git a/.changeset/dirty-guests-suffer.md b/.changeset/dirty-guests-suffer.md new file mode 100644 index 0000000000..cb1762e6d1 --- /dev/null +++ b/.changeset/dirty-guests-suffer.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Allow setting a `default_concurrency_limit` other than 1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 31c7f12fbd..25d82148af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -232,6 +232,8 @@ Previously, in Gradio 3.x, there was a single global `concurrency_count` paramet In Gradio 4.0, the `concurrency_count` parameter has been removed. You can still control the number of total threads by using the `max_threads` parameter. The default value of this parameter is `40`, but you don't have worry (as much) about OOM errors, because even though there are 40 threads, we use a single-worker-single-event model, which means each worker thread only executes a specific function. So effectively, each function has its own "concurrency count" of 1. If you'd like to change this behavior, you can do so by setting a parameter `concurrency_limit`, which is now a parameter of *each event*, not a global parameter. By default this is `1` for each event, but you can set it to a higher value, or to `None` if you'd like to allow an arbitrary number of executions of this event simultaneously. Events can also be grouped together using the `concurrency_id` parameter so that they share the same limit, and by default, events that call the same function share the same `concurrency_id`. +Lastly, it should be noted that the default value of the `concurrency_limit` of all events in a Blocks (which is normally 1) can be changed using the `default_concurrency_limit` parameter in `Blocks.queue()`. You can set this to a higher integer or to `None`. This in turn sets the `concurrency_limit` of all events that don't have an explicit `conurrency_limit` specified. + To summarize migration: * For events that execute quickly or don't use much CPU or GPU resources, you should set `concurrency_limit=None` in Gradio 4.0. (Previously you would set `queue=False`.) diff --git a/gradio/blocks.py b/gradio/blocks.py index 75366cba50..2e26fe9e7a 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -366,7 +366,7 @@ class BlockFunction: inputs_as_dict: bool, batch: bool = False, max_batch_size: int = 4, - concurrency_limit: int | None = 1, + concurrency_limit: int | None | Literal["default"] = "default", concurrency_id: str | None = None, tracks_progress: bool = False, ): @@ -592,6 +592,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): self.output_components = None self.__name__ = None self.api_mode = None + self.progress_tracking = None self.ssl_verify = True @@ -822,7 +823,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): trigger_after: int | None = None, trigger_only_on_success: bool = False, trigger_mode: Literal["once", "multiple", "always_last"] | None = "once", - concurrency_limit: int | None = 1, + concurrency_limit: int | None | Literal["default"] = "default", concurrency_id: str | None = None, ) -> tuple[dict[str, Any], int]: """ @@ -848,7 +849,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): trigger_after: if set, this event will be triggered after 'trigger_after' function index trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set) trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete. - concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Extra events triggered by this listener will be queued. On Spaces, this is set to 1 by default. + concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `queue()`, which itself is 1 by default). concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit. Returns: dependency information, dependency index """ @@ -1649,6 +1650,8 @@ Received outputs: api_open: bool | None = None, max_size: int | None = None, concurrency_count: int | None = None, + *, + default_concurrency_limit: int | None | Literal["not_set"] = "not_set", ): """ By enabling the queue you can control when users know their position in the queue, and set a limit on maximum number of events allowed. @@ -1657,6 +1660,7 @@ Received outputs: api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue. max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited. concurrency_count: Deprecated and has no effect. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via `max_threads` in launch(). + default_concurrency_limit: The default value of `concurrency_limit` to use for event listeners that don't specify a value. Can be set by environment variable GRADIO_DEFAULT_CONCURRENCY_LIMIT. Defaults to 1 if not set otherwise. Example: (Blocks) with gr.Blocks() as demo: button = gr.Button(label="Generate Image") @@ -1682,6 +1686,7 @@ Received outputs: update_intervals=status_update_rate if status_update_rate != "auto" else 1, max_size=max_size, block_fns=self.fns, + default_concurrency_limit=default_concurrency_limit, ) self.config = self.get_config_file() self.app = routes.App.create_app(self) diff --git a/gradio/events.py b/gradio/events.py index 5b215b0916..12f2bde712 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -207,7 +207,7 @@ class EventListener(str): every: float | None = None, trigger_mode: Literal["once", "multiple", "always_last"] | None = None, js: str | None = None, - concurrency_limit: int | None = 1, + concurrency_limit: int | None | Literal["default"] = "default", concurrency_id: str | None = None, ) -> Dependency: """ @@ -227,7 +227,7 @@ class EventListener(str): every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete. js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. - concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued. + concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default). concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit. """ @@ -351,7 +351,7 @@ def on( cancels: dict[str, Any] | list[dict[str, Any]] | None = None, every: float | None = None, js: str | None = None, - concurrency_limit: int | None = 1, + concurrency_limit: int | None | Literal["default"] = "default", concurrency_id: str | None = None, ) -> Dependency: """ @@ -371,7 +371,7 @@ def on( cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish. every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs', return should be a list of values for output components. - concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued. + concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default). concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit. """ from gradio.components.base import Component diff --git a/gradio/interface.py b/gradio/interface.py index 7539b0eff8..ce46cfd331 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -114,7 +114,7 @@ class Interface(Blocks): api_name: str | Literal[False] | None = "predict", _api_mode: bool = False, allow_duplication: bool = False, - concurrency_limit: int | None = 1, + concurrency_limit: int | None | Literal["default"] = "default", **kwargs, ): """ @@ -141,7 +141,7 @@ class Interface(Blocks): max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True) api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None, the name of the prediction function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event. allow_duplication: If True, then will show a 'Duplicate Spaces' button on Hugging Face Spaces. - concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued. + concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which itself is 1 by default). """ super().__init__( analytics_enabled=analytics_enabled, @@ -312,7 +312,7 @@ class Interface(Blocks): self.batch = batch self.max_batch_size = max_batch_size self.allow_duplication = allow_duplication - self.concurrency_limit = concurrency_limit + self.concurrency_limit: int | None | Literal["default"] = concurrency_limit self.share = None self.share_url = None diff --git a/gradio/queueing.py b/gradio/queueing.py index ae347512da..1509331684 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import copy import json +import os import time import traceback import uuid @@ -78,6 +79,7 @@ class Queue: update_intervals: float, max_size: int | None, block_fns: list[BlockFunction], + default_concurrency_limit: int | None | Literal["not_set"] = "not_set", ): self.event_queue: list[Event] = [] self.awaiting_data_events: dict[str, Event] = {} @@ -99,20 +101,27 @@ class Queue: self.block_fns = block_fns self.continuous_tasks: list[Event] = [] self._asyncio_tasks: list[asyncio.Task] = [] - + self.default_concurrency_limit = self._resolve_concurrency_limit( + default_concurrency_limit + ) self.concurrency_limit_per_concurrency_id = {} def start(self): self.active_jobs = [None] * self.max_thread_count for block_fn in self.block_fns: - if block_fn.concurrency_limit is not None: + concurrency_limit = ( + self.default_concurrency_limit + if block_fn.concurrency_limit == "default" + else block_fn.concurrency_limit + ) + if concurrency_limit is not None: self.concurrency_limit_per_concurrency_id[ block_fn.concurrency_id ] = min( self.concurrency_limit_per_concurrency_id.get( - block_fn.concurrency_id, block_fn.concurrency_limit + block_fn.concurrency_id, concurrency_limit ), - block_fn.concurrency_limit, + concurrency_limit, ) run_coro_in_background(self.start_processing) @@ -123,6 +132,26 @@ class Queue: def close(self): self.stopped = True + def _resolve_concurrency_limit(self, default_concurrency_limit): + """ + Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination + of the `default_concurrency_limit` parameter of the `Blocks.queue()` or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT` + environment variable. The parameter in `Blocks.queue()` takes precedence over the environment variable. + Parameters: + default_concurrency_limit: The default concurrency limit, as specified by a user in `Blocks.queu()`. + """ + if default_concurrency_limit != "not_set": + return default_concurrency_limit + if default_concurrency_limit_env := os.environ.get( + "GRADIO_DEFAULT_CONCURRENCY_LIMIT" + ): + if default_concurrency_limit_env.lower() == "none": + return None + else: + return int(default_concurrency_limit_env) + else: + return 1 + def attach_data(self, body: PredictBody): event_id = body.event_id if event_id in self.awaiting_data_events: diff --git a/test/test_queueing.py b/test/test_queueing.py index 83c2491c69..e65bb060fa 100644 --- a/test/test_queueing.py +++ b/test/test_queueing.py @@ -1,13 +1,14 @@ import time import gradio_client as grc +import pytest from fastapi.testclient import TestClient import gradio as gr class TestQueueing: - def test_single_request(self): + def test_single_request(self, connect): with gr.Blocks() as demo: name = gr.Textbox() output = gr.Textbox() @@ -19,12 +20,11 @@ class TestQueueing: demo.launch(prevent_thread_lock=True) - client = grc.Client(f"http://localhost:{demo.server_port}") - job = client.submit("x", fn_index=0) + with connect(demo) as client: + job = client.submit("x", fn_index=0) + assert job.result() == "Hello, x!" - assert job.result() == "Hello, x!" - - def test_all_status_messages(self): + def test_all_status_messages(self, connect): with gr.Blocks() as demo: name = gr.Textbox() output = gr.Textbox() @@ -35,10 +35,10 @@ class TestQueueing: name.submit(greet, name, output, concurrency_limit=2) - app, _, _ = demo.launch(prevent_thread_lock=True) + app, local_url, _ = demo.launch(prevent_thread_lock=True) test_client = TestClient(app) + client = grc.Client(local_url) - client = grc.Client(f"http://localhost:{demo.server_port}") client.submit("a", fn_index=0) job2 = client.submit("b", fn_index=0) client.submit("c", fn_index=0) @@ -70,7 +70,44 @@ class TestQueueing: assert job2.result() == "Hello, b!" assert job4.result() == "Hello, d!" - def test_concurrency_limits(self): + @pytest.mark.parametrize( + "default_concurrency_limit, statuses", + [ + ("not_set", ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]), + (None, ["PROCESSING", "PROCESSING", "PROCESSING"]), + (1, ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]), + (2, ["IN_QUEUE", "PROCESSING", "PROCESSING"]), + ], + ) + def test_default_concurrency_limits(self, default_concurrency_limit, statuses): + with gr.Blocks() as demo: + a = gr.Number() + b = gr.Number() + output = gr.Number() + + add_btn = gr.Button("Add") + + @add_btn.click(inputs=[a, b], outputs=output) + def add(x, y): + time.sleep(2) + return x + y + + demo.queue(default_concurrency_limit=default_concurrency_limit) + _, local_url, _ = demo.launch( + prevent_thread_lock=True, + ) + client = grc.Client(local_url) + + add_job_1 = client.submit(1, 1, fn_index=0) + add_job_2 = client.submit(1, 1, fn_index=0) + add_job_3 = client.submit(1, 1, fn_index=0) + + time.sleep(1) + + add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()] + assert sorted([s.code.value for s in add_job_statuses]) == statuses + + def test_concurrency_limits(self, connect): with gr.Blocks() as demo: a = gr.Number() b = gr.Number() @@ -80,14 +117,14 @@ class TestQueueing: @add_btn.click(inputs=[a, b], outputs=output, concurrency_limit=2) def add(x, y): - time.sleep(4) + time.sleep(2) return x + y sub_btn = gr.Button("Subtract") @sub_btn.click(inputs=[a, b], outputs=output, concurrency_limit=None) def sub(x, y): - time.sleep(4) + time.sleep(2) return x - y mul_btn = gr.Button("Multiply") @@ -99,7 +136,7 @@ class TestQueueing: concurrency_id="muldiv", ) def mul(x, y): - time.sleep(4) + time.sleep(2) return x * y div_btn = gr.Button("Divide") @@ -111,49 +148,55 @@ class TestQueueing: concurrency_id="muldiv", ) def div(x, y): - time.sleep(4) + time.sleep(2) return x / y - app, _, _ = demo.launch(prevent_thread_lock=True) + with connect(demo) as client: + add_job_1 = client.submit(1, 1, fn_index=0) + add_job_2 = client.submit(1, 1, fn_index=0) + add_job_3 = client.submit(1, 1, fn_index=0) + sub_job_1 = client.submit(1, 1, fn_index=1) + sub_job_2 = client.submit(1, 1, fn_index=1) + sub_job_3 = client.submit(1, 1, fn_index=1) + sub_job_3 = client.submit(1, 1, fn_index=1) + mul_job_1 = client.submit(1, 1, fn_index=2) + div_job_1 = client.submit(1, 1, fn_index=3) + mul_job_2 = client.submit(1, 1, fn_index=2) - client = grc.Client(f"http://localhost:{demo.server_port}") - add_job_1 = client.submit(1, 1, fn_index=0) - add_job_2 = client.submit(1, 1, fn_index=0) - add_job_3 = client.submit(1, 1, fn_index=0) - sub_job_1 = client.submit(1, 1, fn_index=1) - sub_job_2 = client.submit(1, 1, fn_index=1) - sub_job_3 = client.submit(1, 1, fn_index=1) - sub_job_3 = client.submit(1, 1, fn_index=1) - mul_job_1 = client.submit(1, 1, fn_index=2) - div_job_1 = client.submit(1, 1, fn_index=3) - mul_job_2 = client.submit(1, 1, fn_index=2) + time.sleep(1) - time.sleep(2) + add_job_statuses = [ + add_job_1.status(), + add_job_2.status(), + add_job_3.status(), + ] + assert sorted([s.code.value for s in add_job_statuses]) == [ + "IN_QUEUE", + "PROCESSING", + "PROCESSING", + ] - add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()] - assert sorted([s.code.value for s in add_job_statuses]) == [ - "IN_QUEUE", - "PROCESSING", - "PROCESSING", - ] + sub_job_statuses = [ + sub_job_1.status(), + sub_job_2.status(), + sub_job_3.status(), + ] + assert [s.code.value for s in sub_job_statuses] == [ + "PROCESSING", + "PROCESSING", + "PROCESSING", + ] - sub_job_statuses = [sub_job_1.status(), sub_job_2.status(), sub_job_3.status()] - assert [s.code.value for s in sub_job_statuses] == [ - "PROCESSING", - "PROCESSING", - "PROCESSING", - ] - - muldiv_job_statuses = [ - mul_job_1.status(), - div_job_1.status(), - mul_job_2.status(), - ] - assert sorted([s.code.value for s in muldiv_job_statuses]) == [ - "IN_QUEUE", - "PROCESSING", - "PROCESSING", - ] + muldiv_job_statuses = [ + mul_job_1.status(), + div_job_1.status(), + mul_job_2.status(), + ] + assert sorted([s.code.value for s in muldiv_job_statuses]) == [ + "IN_QUEUE", + "PROCESSING", + "PROCESSING", + ] def test_every_does_not_block_queue(self): with gr.Blocks() as demo: @@ -162,10 +205,10 @@ class TestQueueing: num.submit(lambda n: 2 * n, num, num, every=0.5) num2.submit(lambda n: 3 * n, num, num) - app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True) + app, local_url, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True) test_client = TestClient(app) - client = grc.Client(f"http://localhost:{demo.server_port}") + client = grc.Client(local_url) job = client.submit(1, fn_index=1) for _ in range(5):