Use state in ChatInterface (#4976)

* Use state in chatinterface

* Changelog

* Fix tests

* use chatbot_state for everything

* lint

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-07-20 06:08:50 -05:00 committed by GitHub
parent c93c317bb8
commit 3dc9a65815
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 39 additions and 39 deletions

View File

@ -5,6 +5,7 @@ No changes to highlight.
## New Features:
- Add `show_download_button` param to allow the download button in static Image components to be hidden by [@hannahblair](https://github.com/hannahblair) in [PR 4959](https://github.com/gradio-app/gradio/pull/4959)
- Use `gr.State` in `gr.ChatInterface` to reduce latency by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4976](https://github.com/gradio-app/gradio/pull/4976)
## Bug Fixes:

View File

@ -177,6 +177,7 @@ class ChatInterface(Blocks):
)
self.saved_input = State()
self.chatbot_state = State([])
self._setup_events()
self._setup_api()
@ -195,14 +196,14 @@ class ChatInterface(Blocks):
queue=False,
).then(
self._display_input,
[self.saved_input, self.chatbot],
[self.chatbot],
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
).then(
submit_fn,
[self.saved_input, self.chatbot],
[self.chatbot],
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
@ -215,41 +216,41 @@ class ChatInterface(Blocks):
queue=False,
).then(
self._display_input,
[self.saved_input, self.chatbot],
[self.chatbot],
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
).then(
submit_fn,
[self.saved_input, self.chatbot],
[self.chatbot],
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
if self.retry_btn:
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot],
[self.chatbot, self.saved_input],
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
self._display_input,
[self.saved_input, self.chatbot],
[self.chatbot],
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
).then(
submit_fn,
[self.saved_input, self.chatbot],
[self.chatbot],
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
)
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.chatbot],
[self.chatbot_state],
[self.chatbot, self.saved_input],
api_name=False,
queue=False,
@ -263,9 +264,9 @@ class ChatInterface(Blocks):
if self.clear_btn:
self.clear_btn.click(
lambda: ([], None),
lambda: ([], [], None),
None,
[self.chatbot, self.saved_input],
[self.chatbot, self.chatbot_state, self.saved_input],
queue=False,
api_name=False,
)
@ -276,14 +277,10 @@ class ChatInterface(Blocks):
else:
api_fn = self._api_submit_fn
# Use a gr.State() instead of self.chatbot so that the API doesn't require passing forth
# a chat history, instead it is just stored internally in the state.
history = State([])
self.fake_api_btn.click(
api_fn,
[self.textbox, history],
[self.textbox, history],
[self.textbox, self.chatbot_state],
[self.textbox, self.chatbot_state],
api_name="chat",
)
@ -292,30 +289,33 @@ class ChatInterface(Blocks):
def _display_input(
self, message: str, history: list[list[str | None]]
) -> list[list[str | None]]:
) -> tuple[list[list[str | None]], list[list[str | None]]]:
history.append([message, None])
return history
return history, history
def _submit_fn(
self, message: str, history_with_input: list[list[str | None]]
) -> list[list[str | None]]:
) -> tuple[list[list[str | None]], list[list[str | None]]]:
history = history_with_input[:-1]
response = self.fn(message, history)
history.append([message, response])
return history
return history, history
def _stream_fn(
self, message: str, history_with_input: list[list[str | None]]
) -> Generator[list[list[str | None]], None, None]:
) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
history = history_with_input[:-1]
generator = self.fn(message, history)
try:
first_response = next(generator)
yield history + [[message, first_response]]
update = history + [[message, first_response]]
yield update, update
except StopIteration:
yield history + [[message, None]]
update = history + [[message, None]]
yield update, update
for response in generator:
yield history + [[message, response]]
update = history + [[message, response]]
yield update, update
def _api_submit_fn(
self, message: str, history: list[list[str | None]]
@ -347,9 +347,9 @@ class ChatInterface(Blocks):
def _delete_prev_fn(
self, history: list[list[str | None]]
) -> tuple[list[list[str | None]], str]:
) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
try:
message, _ = history.pop()
except IndexError:
message = ""
return history, message or ""
return history, message or "", history

View File

@ -3,7 +3,7 @@ of the on-page-load event, which is defined in gr.Blocks().load()."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
from gradio_client.documentation import document, set_documentation_group
@ -91,8 +91,8 @@ class EventListenerMethod:
def __call__(
self,
fn: Callable | None,
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Component | list[Component] | None = None,
inputs: Component | Sequence[Component] | set[Component] | None = None,
outputs: Component | Sequence[Component] | None = None,
api_name: str | None | Literal[False] = None,
status_tracker: None = None,
scroll_to_output: bool = False,

View File

@ -1,4 +1,4 @@
import time
from concurrent.futures import wait
import pytest
@ -96,8 +96,7 @@ class TestAPI:
chatbot = gr.ChatInterface(stream).queue()
with connect(chatbot) as client:
job = client.submit("hello")
while not job.done():
time.sleep(0.1)
wait([job])
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
def test_non_streaming_api(self, connect):