mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
Use state in ChatInterface (#4976)
* Use state in chatinterface * Changelog * Fix tests * use chatbot_state for everything * lint --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
c93c317bb8
commit
3dc9a65815
@ -5,6 +5,7 @@ No changes to highlight.
|
||||
## New Features:
|
||||
|
||||
- Add `show_download_button` param to allow the download button in static Image components to be hidden by [@hannahblair](https://github.com/hannahblair) in [PR 4959](https://github.com/gradio-app/gradio/pull/4959)
|
||||
- Use `gr.State` in `gr.ChatInterface` to reduce latency by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4976](https://github.com/gradio-app/gradio/pull/4976)
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
|
@ -177,6 +177,7 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
|
||||
self.saved_input = State()
|
||||
self.chatbot_state = State([])
|
||||
|
||||
self._setup_events()
|
||||
self._setup_api()
|
||||
@ -195,14 +196,14 @@ class ChatInterface(Blocks):
|
||||
queue=False,
|
||||
).then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot],
|
||||
[self.chatbot],
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot],
|
||||
[self.chatbot],
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
@ -215,41 +216,41 @@ class ChatInterface(Blocks):
|
||||
queue=False,
|
||||
).then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot],
|
||||
[self.chatbot],
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot],
|
||||
[self.chatbot],
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
if self.retry_btn:
|
||||
self.retry_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot],
|
||||
[self.chatbot, self.saved_input],
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot],
|
||||
[self.chatbot],
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
).then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot],
|
||||
[self.chatbot],
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
|
||||
if self.undo_btn:
|
||||
self.undo_btn.click(
|
||||
self._delete_prev_fn,
|
||||
[self.chatbot],
|
||||
[self.chatbot_state],
|
||||
[self.chatbot, self.saved_input],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
@ -263,9 +264,9 @@ class ChatInterface(Blocks):
|
||||
|
||||
if self.clear_btn:
|
||||
self.clear_btn.click(
|
||||
lambda: ([], None),
|
||||
lambda: ([], [], None),
|
||||
None,
|
||||
[self.chatbot, self.saved_input],
|
||||
[self.chatbot, self.chatbot_state, self.saved_input],
|
||||
queue=False,
|
||||
api_name=False,
|
||||
)
|
||||
@ -276,14 +277,10 @@ class ChatInterface(Blocks):
|
||||
else:
|
||||
api_fn = self._api_submit_fn
|
||||
|
||||
# Use a gr.State() instead of self.chatbot so that the API doesn't require passing forth
|
||||
# a chat history, instead it is just stored internally in the state.
|
||||
history = State([])
|
||||
|
||||
self.fake_api_btn.click(
|
||||
api_fn,
|
||||
[self.textbox, history],
|
||||
[self.textbox, history],
|
||||
[self.textbox, self.chatbot_state],
|
||||
[self.textbox, self.chatbot_state],
|
||||
api_name="chat",
|
||||
)
|
||||
|
||||
@ -292,30 +289,33 @@ class ChatInterface(Blocks):
|
||||
|
||||
def _display_input(
|
||||
self, message: str, history: list[list[str | None]]
|
||||
) -> list[list[str | None]]:
|
||||
) -> tuple[list[list[str | None]], list[list[str | None]]]:
|
||||
history.append([message, None])
|
||||
return history
|
||||
return history, history
|
||||
|
||||
def _submit_fn(
|
||||
self, message: str, history_with_input: list[list[str | None]]
|
||||
) -> list[list[str | None]]:
|
||||
) -> tuple[list[list[str | None]], list[list[str | None]]]:
|
||||
history = history_with_input[:-1]
|
||||
response = self.fn(message, history)
|
||||
history.append([message, response])
|
||||
return history
|
||||
return history, history
|
||||
|
||||
def _stream_fn(
|
||||
self, message: str, history_with_input: list[list[str | None]]
|
||||
) -> Generator[list[list[str | None]], None, None]:
|
||||
) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
|
||||
history = history_with_input[:-1]
|
||||
generator = self.fn(message, history)
|
||||
try:
|
||||
first_response = next(generator)
|
||||
yield history + [[message, first_response]]
|
||||
update = history + [[message, first_response]]
|
||||
yield update, update
|
||||
except StopIteration:
|
||||
yield history + [[message, None]]
|
||||
update = history + [[message, None]]
|
||||
yield update, update
|
||||
for response in generator:
|
||||
yield history + [[message, response]]
|
||||
update = history + [[message, response]]
|
||||
yield update, update
|
||||
|
||||
def _api_submit_fn(
|
||||
self, message: str, history: list[list[str | None]]
|
||||
@ -347,9 +347,9 @@ class ChatInterface(Blocks):
|
||||
|
||||
def _delete_prev_fn(
|
||||
self, history: list[list[str | None]]
|
||||
) -> tuple[list[list[str | None]], str]:
|
||||
) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
|
||||
try:
|
||||
message, _ = history.pop()
|
||||
except IndexError:
|
||||
message = ""
|
||||
return history, message or ""
|
||||
return history, message or "", history
|
||||
|
@ -3,7 +3,7 @@ of the on-page-load event, which is defined in gr.Blocks().load()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
@ -91,8 +91,8 @@ class EventListenerMethod:
|
||||
def __call__(
|
||||
self,
|
||||
fn: Callable | None,
|
||||
inputs: Component | list[Component] | set[Component] | None = None,
|
||||
outputs: Component | list[Component] | None = None,
|
||||
inputs: Component | Sequence[Component] | set[Component] | None = None,
|
||||
outputs: Component | Sequence[Component] | None = None,
|
||||
api_name: str | None | Literal[False] = None,
|
||||
status_tracker: None = None,
|
||||
scroll_to_output: bool = False,
|
||||
|
@ -1,4 +1,4 @@
|
||||
import time
|
||||
from concurrent.futures import wait
|
||||
|
||||
import pytest
|
||||
|
||||
@ -96,8 +96,7 @@ class TestAPI:
|
||||
chatbot = gr.ChatInterface(stream).queue()
|
||||
with connect(chatbot) as client:
|
||||
job = client.submit("hello")
|
||||
while not job.done():
|
||||
time.sleep(0.1)
|
||||
wait([job])
|
||||
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
|
||||
|
||||
def test_non_streaming_api(self, connect):
|
||||
|
Loading…
Reference in New Issue
Block a user