mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Converts "Submit" button to "Stop" button in ChatInterface while streaming (#4971)
* initial work * changes * changes * undo routes * cancel fix * cancel fix * remove unnecessary test * formatting * changelog * remove extraneous test * simplify * remove no streaming * changelog * no progress on textbox * show progress revert * clog * lint * fixes based on review * updated textbox props * fix * format * test fix
This commit is contained in:
parent
b68aeea412
commit
f8e5bfa2d4
@ -4,6 +4,11 @@ No changes to highlight.
|
||||
|
||||
## New Features:
|
||||
|
||||
- The `gr.ChatInterface` UI now converts the "Submit" button to a "Stop" button in ChatInterface while streaming, which can be used to pause generation. By [@abidlabs](https://github.com/abidlabs) in [PR 4971](https://github.com/gradio-app/gradio/pull/4971).
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
- Fixes `cancels` for generators so that if a generator is canceled before it is complete, subsequent runs of the event do not continue from the previous iteration, but rather start from the beginning. By [@abidlabs](https://github.com/abidlabs) in [PR 4969](https://github.com/gradio-app/gradio/pull/4969).
|
||||
- Add `show_download_button` param to allow the download button in static Image components to be hidden by [@hannahblair](https://github.com/hannahblair) in [PR 4959](https://github.com/gradio-app/gradio/pull/4959)
|
||||
- Added autofocus argument to Textbox by [@aliabid94](https://github.com/aliabid94) in [PR 4978](https://github.com/gradio-app/gradio/pull/4978)
|
||||
- Use `gr.State` in `gr.ChatInterface` to reduce latency by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4976](https://github.com/gradio-app/gradio/pull/4976)
|
||||
|
@ -19,6 +19,7 @@ from gradio.components import (
|
||||
State,
|
||||
Textbox,
|
||||
)
|
||||
from gradio.events import Dependency, EventListenerMethod
|
||||
from gradio.helpers import create_examples as Examples # noqa: N812
|
||||
from gradio.layouts import Column, Group, Row
|
||||
from gradio.themes import ThemeClass as Theme
|
||||
@ -60,6 +61,7 @@ class ChatInterface(Blocks):
|
||||
css: str | None = None,
|
||||
analytics_enabled: bool | None = None,
|
||||
submit_btn: str | None | Button = "Submit",
|
||||
stop_btn: str | None | Button = "Stop",
|
||||
retry_btn: str | None | Button = "🔄 Retry",
|
||||
undo_btn: str | None | Button = "↩️ Undo",
|
||||
clear_btn: str | None | Button = "🗑️ Clear",
|
||||
@ -77,6 +79,7 @@ class ChatInterface(Blocks):
|
||||
css: custom css or path to custom css file to use with interface.
|
||||
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
|
||||
submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
|
||||
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
@ -95,6 +98,7 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
|
||||
self.fn = fn
|
||||
self.is_generator = inspect.isgeneratorfunction(self.fn)
|
||||
self.examples = examples
|
||||
if self.space_id and cache_examples is None:
|
||||
self.cache_examples = True
|
||||
@ -119,13 +123,15 @@ class ChatInterface(Blocks):
|
||||
with Group():
|
||||
with Row():
|
||||
if textbox:
|
||||
textbox.container = False
|
||||
textbox.show_label = False
|
||||
self.textbox = textbox.render()
|
||||
else:
|
||||
self.textbox = Textbox(
|
||||
container=False,
|
||||
show_label=False,
|
||||
placeholder="Type a message...",
|
||||
scale=10,
|
||||
scale=7,
|
||||
autofocus=True,
|
||||
)
|
||||
if submit_btn:
|
||||
@ -139,17 +145,31 @@ class ChatInterface(Blocks):
|
||||
raise ValueError(
|
||||
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
||||
)
|
||||
self.buttons.append(submit_btn)
|
||||
if stop_btn:
|
||||
if isinstance(stop_btn, Button):
|
||||
stop_btn.visible = False
|
||||
stop_btn.render()
|
||||
elif isinstance(stop_btn, str):
|
||||
stop_btn = Button(
|
||||
stop_btn,
|
||||
variant="stop",
|
||||
visible=False,
|
||||
scale=1,
|
||||
min_width=150,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
||||
)
|
||||
self.buttons.extend([submit_btn, stop_btn])
|
||||
|
||||
with Row():
|
||||
self.stop_btn = Button("Stop", variant="stop", visible=False)
|
||||
|
||||
for btn in [retry_btn, undo_btn, clear_btn]:
|
||||
if btn:
|
||||
if isinstance(btn, Button):
|
||||
btn.render()
|
||||
elif isinstance(btn, str):
|
||||
btn = Button(btn, variant="secondary", size="sm")
|
||||
btn = Button(btn, variant="secondary")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
||||
@ -162,13 +182,14 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
(
|
||||
self.submit_btn,
|
||||
self.stop_btn,
|
||||
self.retry_btn,
|
||||
self.undo_btn,
|
||||
self.clear_btn,
|
||||
) = self.buttons
|
||||
|
||||
if examples:
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
if self.is_generator:
|
||||
examples_fn = self._examples_stream_fn
|
||||
else:
|
||||
examples_fn = self._examples_fn
|
||||
@ -187,76 +208,87 @@ class ChatInterface(Blocks):
|
||||
self._setup_events()
|
||||
self._setup_api()
|
||||
|
||||
def _setup_events(self):
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
submit_fn = self._stream_fn
|
||||
else:
|
||||
submit_fn = self._submit_fn
|
||||
|
||||
self.textbox.submit(
|
||||
self._clear_and_save_textbox,
|
||||
[self.textbox],
|
||||
[self.textbox, self.saved_input],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
if self.submit_btn:
|
||||
self.submit_btn.click(
|
||||
def _setup_events(self) -> None:
|
||||
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
||||
submit_event = (
|
||||
self.textbox.submit(
|
||||
self._clear_and_save_textbox,
|
||||
[self.textbox],
|
||||
[self.textbox, self.saved_input],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(self.textbox.submit, submit_event)
|
||||
|
||||
if self.submit_btn:
|
||||
click_event = (
|
||||
self.submit_btn.click(
|
||||
self._clear_and_save_textbox,
|
||||
[self.textbox],
|
||||
[self.textbox, self.saved_input],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(self.submit_btn.click, click_event)
|
||||
|
||||
if self.retry_btn:
|
||||
self.retry_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
retry_event = (
|
||||
self.retry_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(self.retry_btn.click, retry_event)
|
||||
|
||||
if self.undo_btn:
|
||||
self.undo_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input],
|
||||
[self.chatbot, self.saved_input, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
@ -276,11 +308,50 @@ class ChatInterface(Blocks):
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
def _setup_api(self):
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
api_fn = self._api_stream_fn
|
||||
else:
|
||||
api_fn = self._api_submit_fn
|
||||
def _setup_stop_events(
|
||||
self, event_trigger: EventListenerMethod, event_to_cancel: Dependency
|
||||
) -> None:
|
||||
if self.stop_btn and self.is_generator:
|
||||
if self.submit_btn:
|
||||
event_trigger(
|
||||
lambda: (Button.update(visible=False), Button.update(visible=True)),
|
||||
None,
|
||||
[self.submit_btn, self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
event_to_cancel.then(
|
||||
lambda: (Button.update(visible=True), Button.update(visible=False)),
|
||||
None,
|
||||
[self.submit_btn, self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
else:
|
||||
event_trigger(
|
||||
lambda: Button.update(visible=True),
|
||||
None,
|
||||
[self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
event_to_cancel.then(
|
||||
lambda: Button.update(visible=False),
|
||||
None,
|
||||
[self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
self.stop_btn.click(
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
cancels=event_to_cancel,
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
def _setup_api(self) -> None:
|
||||
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
|
||||
|
||||
self.fake_api_btn.click(
|
||||
api_fn,
|
||||
|
@ -110,6 +110,7 @@ class App(FastAPI):
|
||||
self.blocks: gradio.Blocks | None = None
|
||||
self.state_holder = {}
|
||||
self.iterators = defaultdict(dict)
|
||||
self.iterators_to_reset = defaultdict(set)
|
||||
self.lock = asyncio.Lock()
|
||||
self.queue_token = secrets.token_urlsafe(32)
|
||||
self.startup_events_triggered = False
|
||||
@ -393,7 +394,7 @@ class App(FastAPI):
|
||||
return {"success": False}
|
||||
async with app.lock:
|
||||
app.iterators[body.session_hash][body.fn_index] = None
|
||||
app.iterators[body.session_hash]["should_reset"].add(body.fn_index)
|
||||
app.iterators_to_reset[body.session_hash].add(body.fn_index)
|
||||
return {"success": True}
|
||||
|
||||
async def run_predict(
|
||||
@ -401,6 +402,7 @@ class App(FastAPI):
|
||||
request: Request | List[Request],
|
||||
fn_index_inferred: int,
|
||||
):
|
||||
fn_index = body.fn_index
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
app.state_holder[body.session_hash] = {
|
||||
@ -409,21 +411,22 @@ class App(FastAPI):
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
iterators = app.iterators[body.session_hash]
|
||||
# The should_reset set keeps track of the fn_indices
|
||||
# that have been cancelled. When a job is cancelled,
|
||||
# the /reset route will mark the jobs as having been reset.
|
||||
# That way if the cancel job finishes BEFORE the job being cancelled
|
||||
# the job being cancelled will not overwrite the state of the iterator.
|
||||
# In all cases, should_reset will be the empty set the next time
|
||||
# the fn_index is run.
|
||||
app.iterators[body.session_hash]["should_reset"] = set()
|
||||
if fn_index in app.iterators_to_reset[body.session_hash]:
|
||||
iterators = {}
|
||||
app.iterators_to_reset[body.session_hash].remove(fn_index)
|
||||
else:
|
||||
iterators = app.iterators[body.session_hash]
|
||||
else:
|
||||
session_state = {}
|
||||
iterators = {}
|
||||
|
||||
event_id = getattr(body, "event_id", None)
|
||||
raw_input = body.data
|
||||
fn_index = body.fn_index
|
||||
|
||||
dependency = app.get_blocks().dependencies[fn_index_inferred]
|
||||
target = dependency["targets"][0] if len(dependency["targets"]) else None
|
||||
@ -447,10 +450,7 @@ class App(FastAPI):
|
||||
)
|
||||
iterator = output.pop("iterator", None)
|
||||
if hasattr(body, "session_hash"):
|
||||
if fn_index in app.iterators[body.session_hash]["should_reset"]:
|
||||
app.iterators[body.session_hash][fn_index] = None
|
||||
else:
|
||||
app.iterators[body.session_hash][fn_index] = iterator
|
||||
app.iterators[body.session_hash][fn_index] = iterator
|
||||
if isinstance(output, Error):
|
||||
raise output
|
||||
except BaseException as error:
|
||||
|
@ -12,7 +12,6 @@ import os
|
||||
import pkgutil
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
@ -720,9 +719,6 @@ def get_function_with_locals(fn: Callable, blocks: Blocks, event_id: str | None)
|
||||
|
||||
|
||||
async def cancel_tasks(task_ids: set[str]):
|
||||
if sys.version_info < (3, 8):
|
||||
return None
|
||||
|
||||
matching_tasks = [
|
||||
task for task in asyncio.all_tasks() if task.get_name() in task_ids
|
||||
]
|
||||
@ -732,9 +728,7 @@ async def cancel_tasks(task_ids: set[str]):
|
||||
|
||||
|
||||
def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
|
||||
if sys.version_info >= (3, 8) and not (
|
||||
batch
|
||||
): # You shouldn't be able to cancel a task if it's part of a batch
|
||||
if not batch:
|
||||
task.set_name(f"{session_hash}_{fn_index}")
|
||||
|
||||
|
||||
|
@ -1241,10 +1241,6 @@ class TestRender:
|
||||
|
||||
|
||||
class TestCancel:
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Tasks dont have names in 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_function(self, capsys):
|
||||
async def long_job():
|
||||
@ -1266,10 +1262,6 @@ class TestCancel:
|
||||
captured = capsys.readouterr()
|
||||
assert "HELLO FROM LONG JOB" not in captured.out
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Tasks dont have names in 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_function_with_multiple_blocks(self, capsys):
|
||||
async def long_job():
|
||||
@ -1326,24 +1318,6 @@ class TestCancel:
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
demo.queue().launch(prevent_thread_lock=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_button_for_interfaces(self, connect):
|
||||
def generate(x):
|
||||
for i in range(4):
|
||||
yield i
|
||||
time.sleep(0.2)
|
||||
|
||||
io = gr.Interface(generate, gr.Textbox(), gr.Textbox()).queue()
|
||||
stop_btn_id = next(
|
||||
i for i, k in io.blocks.items() if getattr(k, "value", None) == "Stop"
|
||||
)
|
||||
assert not io.blocks[stop_btn_id].visible
|
||||
|
||||
with connect(io) as client:
|
||||
job = client.submit("freddy", fn_index=1)
|
||||
wait([job])
|
||||
assert job.outputs()[-1] == "3"
|
||||
|
||||
|
||||
class TestEvery:
|
||||
def test_raise_exception_if_parameters_invalid(self):
|
||||
|
@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -168,10 +167,6 @@ class TestQueueEstimation:
|
||||
|
||||
|
||||
class TestQueueProcessEvents:
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
|
||||
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
|
||||
@ -314,10 +309,6 @@ class TestQueueProcessEvents:
|
||||
mock_event.disconnect.assert_called_once()
|
||||
assert queue.clean_event.call_count >= 1
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
|
||||
async def test_process_event_handles_exception_during_disconnect(
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from contextlib import closing
|
||||
from unittest.mock import patch
|
||||
@ -469,10 +468,6 @@ class TestAuthenticatedRoutes:
|
||||
|
||||
class TestQueueRoutes:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks don't work with async context managers in 3.7",
|
||||
)
|
||||
@patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
|
||||
async def test_queue_join_routes_sets_url_if_none_set(self, mock_get_url):
|
||||
io = Interface(lambda x: x, "text", "text").queue()
|
||||
|
Loading…
x
Reference in New Issue
Block a user