Add concurrency_limit to ChatInterface, add IDE support for concurrency_limit (#6653)

* concurrency limit chat interface

* add changeset

* Update gradio/chat_interface.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-12-04 17:19:53 -05:00 committed by GitHub
parent 19c9d26522
commit d92c819419
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 3 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Add concurrency_limit to ChatInterface, add IDE support for concurrency_limit

View File

@ -6,7 +6,7 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
from __future__ import annotations
import inspect
from typing import AsyncGenerator, Callable
from typing import AsyncGenerator, Callable, Literal, Union, cast
import anyio
from gradio_client import utils as client_utils
@ -75,6 +75,7 @@ class ChatInterface(Blocks):
undo_btn: str | None | Button = "↩️ Undo",
clear_btn: str | None | Button = "🗑️ Clear",
autofocus: bool = True,
concurrency_limit: int | None | Literal["default"] = "default",
):
"""
Parameters:
@ -97,6 +98,7 @@ class ChatInterface(Blocks):
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
autofocus: If True, autofocuses to the textbox when the page loads.
concurrency_limit: If set, this this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -105,6 +107,7 @@ class ChatInterface(Blocks):
title=title or "Gradio",
theme=theme,
)
self.concurrency_limit = concurrency_limit
self.fn = fn
self.is_async = inspect.iscoroutinefunction(
self.fn
@ -304,6 +307,9 @@ class ChatInterface(Blocks):
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
)
)
self._setup_stop_events(submit_triggers, submit_event)
@ -329,6 +335,9 @@ class ChatInterface(Blocks):
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
)
)
self._setup_stop_events([self.retry_btn.click], retry_event)
@ -412,6 +421,9 @@ class ChatInterface(Blocks):
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state],
api_name="chat",
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
)
def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:

View File

@ -21,7 +21,6 @@ INTERFACE_TEMPLATE = '''
inputs: Component | Sequence[Component] | set[Component] | None = None,
outputs: Component | Sequence[Component] | None = None,
api_name: str | None | Literal[False] = None,
status_tracker: None = None,
scroll_to_output: bool = False,
show_progress: Literal["full", "minimal", "hidden"] = "full",
queue: bool | None = None,
@ -32,7 +31,9 @@ INTERFACE_TEMPLATE = '''
cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
every: float | None = None,
trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
js: str | None = None,) -> Dependency:
js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None) -> Dependency:
"""
Parameters:
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
@ -50,6 +51,8 @@ INTERFACE_TEMPLATE = '''
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete.
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
"""
...
{% endfor %}

View File

@ -49,6 +49,12 @@ class TestInit:
assert chatbot.submit_btn is None
assert chatbot.retry_btn is None
def test_concurrency_limit(self):
chat = gr.ChatInterface(double, concurrency_limit=10)
assert chat.concurrency_limit == 10
fns = [fn for fn in chat.fns if fn.name in {"_submit_fn", "_api_submit_fn"}]
assert all(fn.concurrency_limit == 10 for fn in fns)
def test_events_attached(self):
chatbot = gr.ChatInterface(double)
dependencies = chatbot.dependencies

View File

@ -1,3 +1,7 @@
import ast
import inspect
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
@ -159,3 +163,23 @@ class TestEventErrors:
with pytest.raises(AttributeError):
textbox.change(lambda x: x + x, textbox, textbox)
def test_event_pyi_file_matches_source_code():
"""Test that the template used to create pyi files (search INTERFACE_TEMPLATE in component_meta) matches the source code of EventListener._setup."""
code = (
Path(__file__).parent / ".." / "gradio" / "components" / "button.pyi"
).read_text()
mod = ast.parse(code)
segment = None
for node in ast.walk(mod):
if isinstance(node, ast.FunctionDef) and node.name == "click":
segment = ast.get_source_segment(code, node)
# This would fail if Button no longer has a click method
assert segment
sig = inspect.signature(gr.Button.click)
for param in sig.parameters.values():
if param.name == "block":
continue
assert param.name in segment