Converts "Submit" button to "Stop" button in ChatInterface while streaming (#4971)

* initial work

* changes

* changes

* undo routes

* cancel fix

* cancel fix

* remove unnecessary test

* formatting

* changelog

* remove extraneous test

* simplify

* remove no streaming

* changelog

* no progress on textbox

* show progress revert

* clog

* lint

* fixes based on review

* updated textbox props

* fix

* format

* test fix
This commit is contained in:
Abubakar Abid 2023-07-20 15:01:53 -04:00 committed by GitHub
parent b68aeea412
commit f8e5bfa2d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 145 additions and 115 deletions

View File

@ -4,6 +4,11 @@ No changes to highlight.
## New Features:
- The `gr.ChatInterface` UI now converts the "Submit" button to a "Stop" button in ChatInterface while streaming, which can be used to pause generation. By [@abidlabs](https://github.com/abidlabs) in [PR 4971](https://github.com/gradio-app/gradio/pull/4971).
## Bug Fixes:
- Fixes `cancels` for generators so that if a generator is canceled before it is complete, subsequent runs of the event do not continue from the previous iteration, but rather start from the beginning. By [@abidlabs](https://github.com/abidlabs) in [PR 4969](https://github.com/gradio-app/gradio/pull/4969).
- Add `show_download_button` param to allow the download button in static Image components to be hidden by [@hannahblair](https://github.com/hannahblair) in [PR 4959](https://github.com/gradio-app/gradio/pull/4959)
- Added autofocus argument to Textbox by [@aliabid94](https://github.com/aliabid94) in [PR 4978](https://github.com/gradio-app/gradio/pull/4978)
- Use `gr.State` in `gr.ChatInterface` to reduce latency by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4976](https://github.com/gradio-app/gradio/pull/4976)

View File

@ -19,6 +19,7 @@ from gradio.components import (
State,
Textbox,
)
from gradio.events import Dependency, EventListenerMethod
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.layouts import Column, Group, Row
from gradio.themes import ThemeClass as Theme
@ -60,6 +61,7 @@ class ChatInterface(Blocks):
css: str | None = None,
analytics_enabled: bool | None = None,
submit_btn: str | None | Button = "Submit",
stop_btn: str | None | Button = "Stop",
retry_btn: str | None | Button = "🔄 Retry",
undo_btn: str | None | Button = "↩️ Undo",
clear_btn: str | None | Button = "🗑️ Clear",
@ -77,6 +79,7 @@ class ChatInterface(Blocks):
css: custom css or path to custom css file to use with interface.
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
@ -95,6 +98,7 @@ class ChatInterface(Blocks):
)
self.fn = fn
self.is_generator = inspect.isgeneratorfunction(self.fn)
self.examples = examples
if self.space_id and cache_examples is None:
self.cache_examples = True
@ -119,13 +123,15 @@ class ChatInterface(Blocks):
with Group():
with Row():
if textbox:
textbox.container = False
textbox.show_label = False
self.textbox = textbox.render()
else:
self.textbox = Textbox(
container=False,
show_label=False,
placeholder="Type a message...",
scale=10,
scale=7,
autofocus=True,
)
if submit_btn:
@ -139,17 +145,31 @@ class ChatInterface(Blocks):
raise ValueError(
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
)
self.buttons.append(submit_btn)
if stop_btn:
if isinstance(stop_btn, Button):
stop_btn.visible = False
stop_btn.render()
elif isinstance(stop_btn, str):
stop_btn = Button(
stop_btn,
variant="stop",
visible=False,
scale=1,
min_width=150,
)
else:
raise ValueError(
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
)
self.buttons.extend([submit_btn, stop_btn])
with Row():
self.stop_btn = Button("Stop", variant="stop", visible=False)
for btn in [retry_btn, undo_btn, clear_btn]:
if btn:
if isinstance(btn, Button):
btn.render()
elif isinstance(btn, str):
btn = Button(btn, variant="secondary", size="sm")
btn = Button(btn, variant="secondary")
else:
raise ValueError(
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
@ -162,13 +182,14 @@ class ChatInterface(Blocks):
)
(
self.submit_btn,
self.stop_btn,
self.retry_btn,
self.undo_btn,
self.clear_btn,
) = self.buttons
if examples:
if inspect.isgeneratorfunction(self.fn):
if self.is_generator:
examples_fn = self._examples_stream_fn
else:
examples_fn = self._examples_fn
@ -187,76 +208,87 @@ class ChatInterface(Blocks):
self._setup_events()
self._setup_api()
def _setup_events(self):
if inspect.isgeneratorfunction(self.fn):
submit_fn = self._stream_fn
else:
submit_fn = self._submit_fn
self.textbox.submit(
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
).then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
).then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
if self.submit_btn:
self.submit_btn.click(
def _setup_events(self) -> None:
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
submit_event = (
self.textbox.submit(
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
).then(
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
).then(
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
)
self._setup_stop_events(self.textbox.submit, submit_event)
if self.submit_btn:
click_event = (
self.submit_btn.click(
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
)
self._setup_stop_events(self.submit_btn.click, click_event)
if self.retry_btn:
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
).then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
retry_event = (
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
)
self._setup_stop_events(self.retry_btn.click, retry_event)
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
@ -276,11 +308,50 @@ class ChatInterface(Blocks):
api_name=False,
)
def _setup_api(self):
if inspect.isgeneratorfunction(self.fn):
api_fn = self._api_stream_fn
else:
api_fn = self._api_submit_fn
def _setup_stop_events(
self, event_trigger: EventListenerMethod, event_to_cancel: Dependency
) -> None:
if self.stop_btn and self.is_generator:
if self.submit_btn:
event_trigger(
lambda: (Button.update(visible=False), Button.update(visible=True)),
None,
[self.submit_btn, self.stop_btn],
api_name=False,
queue=False,
)
event_to_cancel.then(
lambda: (Button.update(visible=True), Button.update(visible=False)),
None,
[self.submit_btn, self.stop_btn],
api_name=False,
queue=False,
)
else:
event_trigger(
lambda: Button.update(visible=True),
None,
[self.stop_btn],
api_name=False,
queue=False,
)
event_to_cancel.then(
lambda: Button.update(visible=False),
None,
[self.stop_btn],
api_name=False,
queue=False,
)
self.stop_btn.click(
None,
None,
None,
cancels=event_to_cancel,
api_name=False,
)
def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
self.fake_api_btn.click(
api_fn,

View File

@ -110,6 +110,7 @@ class App(FastAPI):
self.blocks: gradio.Blocks | None = None
self.state_holder = {}
self.iterators = defaultdict(dict)
self.iterators_to_reset = defaultdict(set)
self.lock = asyncio.Lock()
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
@ -393,7 +394,7 @@ class App(FastAPI):
return {"success": False}
async with app.lock:
app.iterators[body.session_hash][body.fn_index] = None
app.iterators[body.session_hash]["should_reset"].add(body.fn_index)
app.iterators_to_reset[body.session_hash].add(body.fn_index)
return {"success": True}
async def run_predict(
@ -401,6 +402,7 @@ class App(FastAPI):
request: Request | List[Request],
fn_index_inferred: int,
):
fn_index = body.fn_index
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
@ -409,21 +411,22 @@ class App(FastAPI):
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body.session_hash]
iterators = app.iterators[body.session_hash]
# The should_reset set keeps track of the fn_indices
# that have been cancelled. When a job is cancelled,
# the /reset route will mark the jobs as having been reset.
# That way if the cancel job finishes BEFORE the job being cancelled
# the job being cancelled will not overwrite the state of the iterator.
# In all cases, should_reset will be the empty set the next time
# the fn_index is run.
app.iterators[body.session_hash]["should_reset"] = set()
if fn_index in app.iterators_to_reset[body.session_hash]:
iterators = {}
app.iterators_to_reset[body.session_hash].remove(fn_index)
else:
iterators = app.iterators[body.session_hash]
else:
session_state = {}
iterators = {}
event_id = getattr(body, "event_id", None)
raw_input = body.data
fn_index = body.fn_index
dependency = app.get_blocks().dependencies[fn_index_inferred]
target = dependency["targets"][0] if len(dependency["targets"]) else None
@ -447,10 +450,7 @@ class App(FastAPI):
)
iterator = output.pop("iterator", None)
if hasattr(body, "session_hash"):
if fn_index in app.iterators[body.session_hash]["should_reset"]:
app.iterators[body.session_hash][fn_index] = None
else:
app.iterators[body.session_hash][fn_index] = iterator
app.iterators[body.session_hash][fn_index] = iterator
if isinstance(output, Error):
raise output
except BaseException as error:

View File

@ -12,7 +12,6 @@ import os
import pkgutil
import random
import re
import sys
import time
import typing
import warnings
@ -720,9 +719,6 @@ def get_function_with_locals(fn: Callable, blocks: Blocks, event_id: str | None)
async def cancel_tasks(task_ids: set[str]):
if sys.version_info < (3, 8):
return None
matching_tasks = [
task for task in asyncio.all_tasks() if task.get_name() in task_ids
]
@ -732,9 +728,7 @@ async def cancel_tasks(task_ids: set[str]):
def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
if sys.version_info >= (3, 8) and not (
batch
): # You shouldn't be able to cancel a task if it's part of a batch
if not batch:
task.set_name(f"{session_hash}_{fn_index}")

View File

@ -1241,10 +1241,6 @@ class TestRender:
class TestCancel:
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Tasks dont have names in 3.7",
)
@pytest.mark.asyncio
async def test_cancel_function(self, capsys):
async def long_job():
@ -1266,10 +1262,6 @@ class TestCancel:
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Tasks dont have names in 3.7",
)
@pytest.mark.asyncio
async def test_cancel_function_with_multiple_blocks(self, capsys):
async def long_job():
@ -1326,24 +1318,6 @@ class TestCancel:
cancel.click(None, None, None, cancels=[click])
demo.queue().launch(prevent_thread_lock=True)
@pytest.mark.asyncio
async def test_cancel_button_for_interfaces(self, connect):
def generate(x):
for i in range(4):
yield i
time.sleep(0.2)
io = gr.Interface(generate, gr.Textbox(), gr.Textbox()).queue()
stop_btn_id = next(
i for i, k in io.blocks.items() if getattr(k, "value", None) == "Stop"
)
assert not io.blocks[stop_btn_id].visible
with connect(io) as client:
job = client.submit("freddy", fn_index=1)
wait([job])
assert job.outputs()[-1] == "3"
class TestEvery:
def test_raise_exception_if_parameters_invalid(self):

View File

@ -1,6 +1,5 @@
import asyncio
import os
import sys
from collections import deque
from unittest.mock import MagicMock, patch
@ -168,10 +167,6 @@ class TestQueueEstimation:
class TestQueueProcessEvents:
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
@ -314,10 +309,6 @@ class TestQueueProcessEvents:
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
async def test_process_event_handles_exception_during_disconnect(

View File

@ -1,7 +1,6 @@
"""Contains tests for networking.py and app.py"""
import json
import os
import sys
import tempfile
from contextlib import closing
from unittest.mock import patch
@ -469,10 +468,6 @@ class TestAuthenticatedRoutes:
class TestQueueRoutes:
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks don't work with async context managers in 3.7",
)
@patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
async def test_queue_join_routes_sets_url_if_none_set(self, mock_get_url):
io = Interface(lambda x: x, "text", "text").queue()