Clean up gr.ChatInterface and fix API type discrepancy (#10185)

* changes

* add changeset

* changes

* clean

* changes

* format

* changes

* clean

* add

* changes

* add changeset

* mutate

* changes

* fix streaming

* test

* format

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* refactor

* Fixes

* fix

* revert

* revert

* revert

* revert

* changes

* revert^3

* format

* comment

* format

* fix example caching

* modify api test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Abubakar Abid authored on 2024-12-12 12:44:32 -08:00, committed by GitHub
commit e525680316, parent 22fe4ce5a1
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 158 additions and 241 deletions

View File

@@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Clean up `gr.ChatInterface` and fix API type discrepancy
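Note: the API type discrepancy fixed here is visible in the removed _setup_api further below: the auto-generated /chat endpoint built its returned history as [[message, response]] tuple pairs even when the ChatInterface was constructed with type="messages". The endpoint is now wired to the same submit_fn as the UI and returns the assistant response directly. A minimal sketch of exercising the endpoint with gradio_client (the URL and message are assumptions; any running ChatInterface app will do):

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")
    # Returns the assistant's reply for the submitted message.
    result = client.predict(message="Hello!", api_name="/chat")
    print(result)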

View File

@@ -5,7 +5,7 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
from __future__ import annotations
import builtins
- import functools
+ import copy
import inspect
import warnings
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
@@ -285,7 +285,7 @@ class ChatInterface(Blocks):
self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
label="Response", visible=False
- )
+ ) # Used to store the response from the API call
if self.examples:
self.examples_handler = Examples(
@@ -311,13 +311,11 @@ class ChatInterface(Blocks):
input_component.render()
self.saved_input = State() # Stores the most recent user message
self.previous_input = State(value=[]) # Stores all user messages
self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([])
)
self.show_progress = show_progress
self._setup_events()
- self._setup_api()
@staticmethod
def _setup_example_messages(
@@ -349,40 +347,60 @@ class ChatInterface(Blocks):
if hasattr(self.fn, "zerogpu"):
submit_fn.__func__.zerogpu = self.fn.zerogpu # type: ignore
+ synchronize_chat_state_kwargs = {
+ "fn": lambda x: x,
+ "inputs": [self.chatbot],
+ "outputs": [self.chatbot_state],
+ "show_api": False,
+ "queue": False,
+ }
+ submit_fn_kwargs = {
+ "fn": submit_fn,
+ "inputs": [self.saved_input, self.chatbot_state] + self.additional_inputs,
+ "outputs": [self.fake_response_textbox, self.chatbot]
+ + self.additional_outputs,
+ "show_api": False,
+ "concurrency_limit": cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ "show_progress": cast(
+ Literal["full", "minimal", "hidden"], self.show_progress
+ ),
+ }
submit_event = (
self.textbox.submit(
self._clear_and_save_textbox,
[self.textbox, self.previous_input],
[self.textbox, self.saved_input, self.previous_input],
[self.textbox],
[self.textbox, self.saved_input],
show_api=False,
queue=False,
)
- .then(
+ .then( # The reason we do this outside of the submit_fn is that we want to update the chatbot UI with the user message immediately, before the submit_fn is called
self._append_message_to_history,
[self.saved_input, self.chatbot],
[self.chatbot],
show_api=False,
queue=False,
)
- .then(
- submit_fn,
- [self.saved_input, self.chatbot] + self.additional_inputs,
- [self.chatbot] + self.additional_outputs,
- show_api=False,
- concurrency_limit=cast(
- Union[int, Literal["default"], None], self.concurrency_limit
- ),
- show_progress=cast(
- Literal["full", "minimal", "hidden"], self.show_progress
- ),
- )
+ .then(**submit_fn_kwargs)
)
- submit_event.then(
+ submit_event.then(**synchronize_chat_state_kwargs).then(
lambda: update(value=None, interactive=True),
None,
self.textbox,
show_api=False,
)
+ # Creates the "/chat" API endpoint
+ self.fake_api_btn.click(
+ submit_fn,
+ [self.textbox, self.chatbot_state] + self.additional_inputs,
+ [self.fake_response_textbox, self.chatbot_state] + self.additional_outputs,
+ api_name=cast(Union[str, Literal[False]], self.api_name),
+ concurrency_limit=cast(
+ Union[int, Literal["default"], None], self.concurrency_limit
+ ),
+ )
if (
isinstance(self.chatbot, Chatbot)
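Note: the hunk above collapses four near-identical .then(submit_fn, ...) wirings into a single submit_fn_kwargs dict that is splatted into each event chain, with synchronize_chat_state_kwargs doing the same for the chatbot-state sync step. A standalone sketch of the pattern with hypothetical handlers and components:

    import gradio as gr

    def add_user_message(message, history):
        # Hypothetical helper: append the user turn and clear the textbox.
        return "", history + [{"role": "user", "content": message}]

    def respond(history):
        # Hypothetical handler: append a canned assistant reply.
        return history + [{"role": "assistant", "content": "Hello!"}]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages")
        textbox = gr.Textbox()
        resend = gr.Button("Resend")
        # Define the event kwargs once, then splat them into every trigger
        # that should run the same pipeline.
        respond_kwargs = {
            "fn": respond,
            "inputs": [chatbot],
            "outputs": [chatbot],
            "show_api": False,
        }
        textbox.submit(
            add_user_message, [textbox, chatbot], [textbox, chatbot]
        ).then(**respond_kwargs)
        resend.click(**respond_kwargs)

    demo.launch()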
@@ -397,18 +415,8 @@ class ChatInterface(Blocks):
show_api=False,
)
if not self.cache_examples:
- example_select_event.then(
- submit_fn,
- [self.saved_input, self.chatbot],
- [self.chatbot] + self.additional_outputs,
- show_api=False,
- concurrency_limit=cast(
- Union[int, Literal["default"], None], self.concurrency_limit
- ),
- show_progress=cast(
- Literal["full", "minimal", "hidden"], self.show_progress
- ),
- )
+ example_select_event = example_select_event.then(**submit_fn_kwargs)
+ example_select_event.then(**synchronize_chat_state_kwargs)
else:
self.chatbot.example_select(
self.example_populated,
@@ -420,8 +428,15 @@ class ChatInterface(Blocks):
retry_event = (
self.chatbot.retry(
self._pop_last_user_message,
[self.chatbot_state],
[self.chatbot_state, self.saved_input],
show_api=False,
queue=False,
)
.then(
self._append_message_to_history,
[self.saved_input, self.chatbot_state],
[self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
queue=False,
)
@@ -430,27 +445,9 @@ class ChatInterface(Blocks):
outputs=[self.textbox],
show_api=False,
)
- .then(
- self._append_message_to_history,
- [self.saved_input, self.chatbot],
- [self.chatbot],
- show_api=False,
- queue=False,
- )
- .then(
- submit_fn,
- [self.saved_input, self.chatbot] + self.additional_inputs,
- [self.chatbot] + self.additional_outputs,
- show_api=False,
- concurrency_limit=cast(
- Union[int, Literal["default"], None], self.concurrency_limit
- ),
- show_progress=cast(
- Literal["full", "minimal", "hidden"], self.show_progress
- ),
- )
+ .then(**submit_fn_kwargs)
)
- retry_event.then(
+ retry_event.then(**synchronize_chat_state_kwargs).then(
lambda: update(interactive=True),
outputs=[self.textbox],
show_api=False,
@@ -461,34 +458,19 @@ class ChatInterface(Blocks):
self.chatbot.undo(
self._pop_last_user_message,
[self.chatbot],
- [self.chatbot, self.saved_input],
+ [self.chatbot, self.textbox],
show_api=False,
queue=False,
- ).then(
- lambda x: x,
- self.saved_input,
- self.textbox,
- show_api=False,
- queue=False,
- )
+ ).then(**synchronize_chat_state_kwargs)
self.chatbot.option_select(
self.option_clicked,
[self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
- ).then(
- submit_fn,
- [self.saved_input, self.chatbot],
- [self.chatbot] + self.additional_outputs,
- show_api=False,
- concurrency_limit=cast(
- Union[int, Literal["default"], None], self.concurrency_limit
- ),
- show_progress=cast(
- Literal["full", "minimal", "hidden"], self.show_progress
- ),
- )
+ ).then(**submit_fn_kwargs).then(**synchronize_chat_state_kwargs)
+ self.chatbot.clear(**synchronize_chat_state_kwargs)
def _setup_stop_events(
self, event_triggers: list[Callable], events_to_cancel: list[Dependency]
@@ -529,176 +511,122 @@ class ChatInterface(Blocks):
show_api=False,
)
- def _setup_api(self) -> None:
- if self.is_generator:
- @functools.wraps(self.fn)
- async def api_fn(message, history, *args, **kwargs): # type: ignore
- if self.is_async:
- generator = self.fn(message, history, *args, **kwargs)
- else:
- generator = await anyio.to_thread.run_sync(
- self.fn, message, history, *args, **kwargs, limiter=self.limiter
- )
- generator = utils.SyncToAsyncIterator(generator, self.limiter)
- try:
- first_response = await utils.async_iteration(generator)
- yield first_response, history + [[message, first_response]]
- except StopIteration:
- yield None, history + [[message, None]]
- async for response in generator:
- yield response, history + [[message, response]]
- else:
- @functools.wraps(self.fn)
- async def api_fn(message, history, *args, **kwargs):
- if self.is_async:
- response = await self.fn(message, history, *args, **kwargs)
- else:
- response = await anyio.to_thread.run_sync(
- self.fn, message, history, *args, **kwargs, limiter=self.limiter
- )
- history.append([message, response])
- return response, history
- self.fake_api_btn.click(
- api_fn,
- [self.textbox, self.chatbot_state] + self.additional_inputs,
- [self.fake_response_textbox, self.chatbot_state],
- api_name=cast(Union[str, Literal[False]], self.api_name),
- concurrency_limit=cast(
- Union[int, Literal["default"], None], self.concurrency_limit
- ),
- )
def _clear_and_save_textbox(
self,
message: str | MultimodalPostprocess,
previous_input: list[str | MultimodalPostprocess],
) -> tuple[
Textbox | MultimodalTextbox,
str | MultimodalPostprocess,
list[str | MultimodalPostprocess],
]:
previous_input += [message]
return (
type(self.textbox)("", interactive=False, placeholder=""),
message,
previous_input,
)
@staticmethod
def _messages_to_tuples(history_messages: list[MessageDict]) -> TupleFormat:
history_tuples = []
for message in history_messages:
if message["role"] == "user":
history_tuples.append((message["content"], None))
elif history_tuples and history_tuples[-1][1] is None:
history_tuples[-1] = (history_tuples[-1][0], message["content"])
else:
history_tuples.append((None, message["content"]))
return history_tuples
@staticmethod
def _tuples_to_messages(history_tuples: TupleFormat) -> list[MessageDict]:
history_messages = []
for message_tuple in history_tuples:
if message_tuple[0]:
history_messages.append({"role": "user", "content": message_tuple[0]})
if message_tuple[1]:
history_messages.append(
{"role": "assistant", "content": message_tuple[1]}
)
return history_messages
def _append_message_to_history(
self,
- message: MultimodalPostprocess | str,
+ message: MultimodalPostprocess | str | MessageDict,
history: list[MessageDict] | TupleFormat,
- ):
+ role: Literal["user", "assistant"] = "user",
+ ) -> list[MessageDict] | TupleFormat:
if isinstance(message, str):
message = {"text": message}
if self.type == "tuples":
- for x in message.get("files", []):
- if isinstance(x, dict):
- x = x.get("path")
- history.append([(x,), None]) # type: ignore
- if message["text"] is None or not isinstance(message["text"], str):
- pass
- elif message["text"] == "" and message.get("files", []) != []:
- history.append([None, None]) # type: ignore
- else:
- history.append([message["text"], None]) # type: ignore
+ history = self._tuples_to_messages(history) # type: ignore
else:
+ history = copy.deepcopy(history)
+ if "content" in message: # in MessageDict format already
+ history.append(message) # type: ignore
+ else: # in MultimodalPostprocess format
for x in message.get("files", []):
if isinstance(x, dict):
x = x.get("path")
- history.append({"role": "user", "content": (x,)}) # type: ignore
+ history.append({"role": role, "content": (x,)}) # type: ignore
if message["text"] is None or not isinstance(message["text"], str):
pass
else:
- history.append({"role": "user", "content": message["text"]}) # type: ignore
+ history.append({"role": role, "content": message["text"]}) # type: ignore
+ if self.type == "tuples":
+ history = self._messages_to_tuples(history) # type: ignore
+ return history
- def response_as_dict(self, response: MessageDict | Message | str) -> MessageDict:
+ def response_as_dict(
+ self, response: MessageDict | Message | str | Component
+ ) -> MessageDict:
if isinstance(response, Message):
new_response = response.model_dump()
- elif isinstance(response, str):
+ elif isinstance(response, (str, Component)):
return {"role": "assistant", "content": response}
else:
new_response = response
return cast(MessageDict, new_response)
- def _process_msg_and_trim_history(
- self,
- message: str | MultimodalPostprocess,
- history_with_input: TupleFormat | list[MessageDict],
- ) -> tuple[str | MultimodalPostprocess, TupleFormat | list[MessageDict]]:
- if isinstance(message, dict):
- remove_input = len(message.get("files", [])) + int(
- message["text"] is not None
- )
- history = history_with_input[:-remove_input]
- else:
- history = history_with_input[:-1]
- return message, history # type: ignore
- def _append_history(self, history, message, first_response=True):
- if self.type == "tuples":
- if history:
- history[-1][1] = message # type: ignore
- else:
- history.append([message, None])
- else:
- message = self.response_as_dict(message)
- if first_response:
- history.append(message) # type: ignore
- else:
- history[-1] = message
async def _submit_fn(
self,
message: str | MultimodalPostprocess,
- history_with_input: TupleFormat | list[MessageDict],
+ history: TupleFormat | list[MessageDict],
request: Request,
*args,
- ) -> TupleFormat | list[MessageDict] | tuple[TupleFormat | list[MessageDict], ...]:
- message_serialized, history = self._process_msg_and_trim_history(
- message, history_with_input
- )
+ ) -> tuple:
inputs, _, _ = special_args(
- self.fn, inputs=[message_serialized, history, *args], request=request
+ self.fn, inputs=[message, history, *args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
+ additional_outputs = None
if isinstance(response, tuple):
response, *additional_outputs = response
- self._append_history(history_with_input, response)
- else:
- additional_outputs = None
+ history = self._append_message_to_history(message, history, "user")
+ response_ = self.response_as_dict(response)
+ history = self._append_message_to_history(response_, history, "assistant") # type: ignore
if additional_outputs:
- return history_with_input, *additional_outputs
- return history_with_input
+ return response, history, *additional_outputs
+ return response, history
async def _stream_fn(
self,
message: str | MultimodalPostprocess,
- history_with_input: TupleFormat | list[MessageDict],
+ history: TupleFormat | list[MessageDict],
request: Request,
*args,
) -> AsyncGenerator[
- TupleFormat | list[MessageDict] | tuple[TupleFormat | list[MessageDict], ...],
+ tuple,
None,
]:
- message_serialized, history = self._process_msg_and_trim_history(
- message, history_with_input
- )
inputs, _, _ = special_args(
- self.fn, inputs=[message_serialized, history, *args], request=request
+ self.fn, inputs=[message, history, *args], request=request
)
if self.is_async:
generator = self.fn(*inputs)
else:
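Note: _messages_to_tuples and _tuples_to_messages above convert between the two history representations that gr.Chatbot accepts. An illustration with hypothetical data:

    # "tuples" format: one (user, assistant) pair per exchange.
    tuples_history = [
        ("Hi", "Hello! How can I help?"),
        ("What's 2 + 2?", None),  # None: no assistant reply yet
    ]

    # "messages" format: openai-style role/content dicts.
    messages_history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help?"},
        {"role": "user", "content": "What's 2 + 2?"},
    ]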
@@ -707,28 +635,29 @@ class ChatInterface(Blocks):
)
generator = utils.SyncToAsyncIterator(generator, self.limiter)
+ history = self._append_message_to_history(message, history, "user")
+ additional_outputs = None
try:
first_response = await utils.async_iteration(generator)
if isinstance(first_response, tuple):
first_response, *additional_outputs = first_response
- self._append_history(history_with_input, first_response)
- yield (
- history_with_input
- if not additional_outputs
- else (history_with_input, *additional_outputs)
+ history_ = self._append_message_to_history(
+ first_response, history, "assistant"
)
+ if not additional_outputs:
+ yield first_response, history_
+ else:
+ yield first_response, history_, *additional_outputs
except StopIteration:
- yield history_with_input
+ yield None, history
async for response in generator:
if isinstance(response, tuple):
response, *additional_outputs = response
- self._append_history(history_with_input, response, first_response=False)
- yield (
- history_with_input
- if not additional_outputs
- else (history_with_input, *additional_outputs)
- )
+ history_ = self._append_message_to_history(response, history, "assistant")
+ if not additional_outputs:
+ yield response, history_
+ else:
+ yield response, history_, *additional_outputs
def option_clicked(
self, history: list[MessageDict], option: SelectData
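Note: _stream_fn now yields (response, history) pairs, rebuilding the last assistant message on each iteration instead of mutating the caller's history in place. For reference, a minimal streaming chat function as a user would write one (hypothetical tokens and timing):

    import time
    import gradio as gr

    def stream_response(message, history):
        # Yield growing prefixes; ChatInterface replaces the last
        # assistant message with each successive yield.
        partial = ""
        for token in ["Echo", "ing: ", message]:
            time.sleep(0.05)
            partial += token
            yield partial

    gr.ChatInterface(stream_response, type="messages").launch()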
@@ -767,8 +696,7 @@ class ChatInterface(Blocks):
to the example message. Then, if example caching is enabled, the cached response is loaded
and added to the chat history as well.
"""
- history = []
- self._append_message_to_history(example.value, history)
+ history = self._append_message_to_history(example.value, [], "user")
example = self._flatten_example_files(example)
message = example.value if self.multimodal else example.value["text"]
yield history, message
@@ -836,7 +764,7 @@ class ChatInterface(Blocks):
async for response in generator:
yield self._process_example(message, response)
- async def _pop_last_user_message(
+ def _pop_last_user_message(
self,
history: list[MessageDict] | TupleFormat,
) -> tuple[list[MessageDict] | TupleFormat, str | MultimodalPostprocess]:
@@ -848,48 +776,32 @@ class ChatInterface(Blocks):
if not history:
return history, "" if not self.multimodal else {"text": "", "files": []}
- if self.type == "messages":
- # Skip the last message as it's always an assistant message
- i = len(history) - 2
- while i >= 0 and history[i]["role"] == "user": # type: ignore
- i -= 1
- last_messages = history[i + 1 :]
- last_user_message = ""
- files = []
- for msg in last_messages:
- assert isinstance(msg, dict) # noqa: S101
- if msg["role"] == "user":
- content = msg["content"]
- if isinstance(content, tuple):
- files.append(content[0])
- else:
- last_user_message = content
- return_message = (
- {"text": last_user_message, "files": files}
- if self.multimodal
- else last_user_message
- )
- return history[: i + 1], return_message # type: ignore
- else:
- # Skip the last message pair as it always includes an assistant message
- i = len(history) - 2
- while i >= 0 and history[i][1] is None: # type: ignore
- i -= 1
- last_messages = history[i + 1 :]
- last_user_message = ""
- files = []
- for msg in last_messages:
- assert isinstance(msg, (tuple, list)) # noqa: S101
- if isinstance(msg[0], tuple):
- files.append(msg[0][0])
- elif msg[0] is not None:
- last_user_message = msg[0]
- return_message = (
- {"text": last_user_message, "files": files}
- if self.multimodal
- else last_user_message
- )
- return history[: i + 1], return_message # type: ignore
+ if self.type == "tuples":
+ history = self._tuples_to_messages(history) # type: ignore
+ # Skip the last message as it's always an assistant message
+ i = len(history) - 2
+ while i >= 0 and history[i]["role"] == "user": # type: ignore
+ i -= 1
+ last_messages = history[i + 1 :]
+ last_user_message = ""
+ files = []
+ for msg in last_messages:
+ assert isinstance(msg, dict) # noqa: S101
+ if msg["role"] == "user":
+ content = msg["content"]
+ if isinstance(content, tuple):
+ files.append(content[0])
+ else:
+ last_user_message = content
+ return_message = (
+ {"text": last_user_message, "files": files}
+ if self.multimodal
+ else last_user_message
+ )
+ history_ = history[: i + 1]
+ if self.type == "tuples":
+ history_ = self._messages_to_tuples(history_) # type: ignore
+ return history_, return_message # type: ignore
def render(self) -> ChatInterface:
# If this is being rendered inside another Blocks, and the height is not explicitly set, set it to 400 instead of 200.
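Note: the rewritten _pop_last_user_message normalizes tuple histories to the messages format up front, walks back to the start of the last user turn once, and converts back only on return, replacing the two duplicated branches. Its contract, with hypothetical data:

    history = [
        {"role": "user", "content": "Tell me a joke"},
        {"role": "assistant", "content": "Why did the chicken cross the road?"},
    ]
    # Popping for retry/undo removes everything from the last user turn
    # onward and returns that message so it can be edited or resubmitted:
    #   trimmed history -> []
    #   message         -> "Tell me a joke"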

View File

@@ -294,7 +294,7 @@ def music(message, history):
return "Please provide the name of an artist"
gr.ChatInterface(
- fake,
+ music,
type="messages",
textbox=gr.Textbox(placeholder="Which artist's music do you want to listen to?", scale=7),
).launch()

View File

@@ -82,7 +82,7 @@ for (const test_case of cases) {
);
const api_recorder = await page.locator("#api-recorder");
await api_recorder.click();
- const n_calls = test_case.includes("non_stream") ? 4 : 6;
+ const n_calls = test_case.includes("non_stream") ? 5 : 7;
await expect(page.locator("#num-recorded-api-calls")).toContainText(
`🪄 Recorded API Calls [${n_calls}]`
);
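Note: the expected call counts rise by one in both branches, presumably because each chat interaction now chains one extra event (the chatbot-state synchronization step added in this commit), which the API recorder counts as an additional call.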