From 3dc9a65815c842612291726fd75ae19d59cf0999 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Thu, 20 Jul 2023 06:08:50 -0500 Subject: [PATCH] Use state in ChatInterface (#4976) * Use state in chatinterface * Changelog * Fix tests * use chatbot_state for everything * lint --------- Co-authored-by: Abubakar Abid --- CHANGELOG.md | 1 + gradio/chat_interface.py | 66 ++++++++++++++++++------------------- gradio/events.py | 6 ++-- test/test_chat_interface.py | 5 ++- 4 files changed, 39 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 07796da6d2..9f3451318d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ No changes to highlight. ## New Features: - Add `show_download_button` param to allow the download button in static Image components to be hidden by [@hannahblair](https://github.com/hannahblair) in [PR 4959](https://github.com/gradio-app/gradio/pull/4959) +- Use `gr.State` in `gr.ChatInterface` to reduce latency by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4976](https://github.com/gradio-app/gradio/pull/4976) ## Bug Fixes: diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 7c6bc63455..22113542e5 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -177,6 +177,7 @@ class ChatInterface(Blocks): ) self.saved_input = State() + self.chatbot_state = State([]) self._setup_events() self._setup_api() @@ -195,14 +196,14 @@ class ChatInterface(Blocks): queue=False, ).then( self._display_input, - [self.saved_input, self.chatbot], - [self.chatbot], + [self.saved_input, self.chatbot_state], + [self.chatbot, self.chatbot_state], api_name=False, queue=False, ).then( submit_fn, - [self.saved_input, self.chatbot], - [self.chatbot], + [self.saved_input, self.chatbot_state], + [self.chatbot, self.chatbot_state], api_name=False, ) @@ -215,41 +216,41 @@ class ChatInterface(Blocks): queue=False, ).then( self._display_input, - [self.saved_input, self.chatbot], - [self.chatbot], + [self.saved_input, self.chatbot_state], + [self.chatbot, self.chatbot_state], api_name=False, queue=False, ).then( submit_fn, - [self.saved_input, self.chatbot], - [self.chatbot], + [self.saved_input, self.chatbot_state], + [self.chatbot, self.chatbot_state], api_name=False, ) if self.retry_btn: self.retry_btn.click( self._delete_prev_fn, - [self.chatbot], - [self.chatbot, self.saved_input], + [self.chatbot_state], + [self.chatbot, self.saved_input, self.chatbot_state], api_name=False, queue=False, ).then( self._display_input, - [self.saved_input, self.chatbot], - [self.chatbot], + [self.saved_input, self.chatbot_state], + [self.chatbot, self.chatbot_state], api_name=False, queue=False, ).then( submit_fn, - [self.saved_input, self.chatbot], - [self.chatbot], + [self.saved_input, self.chatbot_state], + [self.chatbot, self.chatbot_state], api_name=False, ) if self.undo_btn: self.undo_btn.click( self._delete_prev_fn, - [self.chatbot], + [self.chatbot_state], [self.chatbot, self.saved_input], api_name=False, queue=False, @@ -263,9 +264,9 @@ class ChatInterface(Blocks): if self.clear_btn: self.clear_btn.click( - lambda: ([], None), + lambda: ([], [], None), None, - [self.chatbot, self.saved_input], + [self.chatbot, self.chatbot_state, self.saved_input], queue=False, api_name=False, ) @@ -276,14 +277,10 @@ class ChatInterface(Blocks): else: api_fn = self._api_submit_fn - # Use a gr.State() instead of self.chatbot so that the API doesn't require passing forth - # a chat history, instead it is just stored internally in the state. - history = State([]) - self.fake_api_btn.click( api_fn, - [self.textbox, history], - [self.textbox, history], + [self.textbox, self.chatbot_state], + [self.textbox, self.chatbot_state], api_name="chat", ) @@ -292,30 +289,33 @@ class ChatInterface(Blocks): def _display_input( self, message: str, history: list[list[str | None]] - ) -> list[list[str | None]]: + ) -> tuple[list[list[str | None]], list[list[str | None]]]: history.append([message, None]) - return history + return history, history def _submit_fn( self, message: str, history_with_input: list[list[str | None]] - ) -> list[list[str | None]]: + ) -> tuple[list[list[str | None]], list[list[str | None]]]: history = history_with_input[:-1] response = self.fn(message, history) history.append([message, response]) - return history + return history, history def _stream_fn( self, message: str, history_with_input: list[list[str | None]] - ) -> Generator[list[list[str | None]], None, None]: + ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]: history = history_with_input[:-1] generator = self.fn(message, history) try: first_response = next(generator) - yield history + [[message, first_response]] + update = history + [[message, first_response]] + yield update, update except StopIteration: - yield history + [[message, None]] + update = history + [[message, None]] + yield update, update for response in generator: - yield history + [[message, response]] + update = history + [[message, response]] + yield update, update def _api_submit_fn( self, message: str, history: list[list[str | None]] @@ -347,9 +347,9 @@ class ChatInterface(Blocks): def _delete_prev_fn( self, history: list[list[str | None]] - ) -> tuple[list[list[str | None]], str]: + ) -> tuple[list[list[str | None]], str, list[list[str | None]]]: try: message, _ = history.pop() except IndexError: message = "" - return history, message or "" + return history, message or "", history diff --git a/gradio/events.py b/gradio/events.py index 524bceb312..c5c4622d7e 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -3,7 +3,7 @@ of the on-page-load event, which is defined in gr.Blocks().load().""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence from gradio_client.documentation import document, set_documentation_group @@ -91,8 +91,8 @@ class EventListenerMethod: def __call__( self, fn: Callable | None, - inputs: Component | list[Component] | set[Component] | None = None, - outputs: Component | list[Component] | None = None, + inputs: Component | Sequence[Component] | set[Component] | None = None, + outputs: Component | Sequence[Component] | None = None, api_name: str | None | Literal[False] = None, status_tracker: None = None, scroll_to_output: bool = False, diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py index aba50c4a81..b398325da0 100644 --- a/test/test_chat_interface.py +++ b/test/test_chat_interface.py @@ -1,4 +1,4 @@ -import time +from concurrent.futures import wait import pytest @@ -96,8 +96,7 @@ class TestAPI: chatbot = gr.ChatInterface(stream).queue() with connect(chatbot) as client: job = client.submit("hello") - while not job.done(): - time.sleep(0.1) + wait([job]) assert job.outputs() == ["h", "he", "hel", "hell", "hello"] def test_non_streaming_api(self, connect):