adds a run_examples_on_click parameter to gr.ChatInterface mirroring the run_on_click parameter in gr.Examples (#10109)

* add param

* add changeset

* add changeset

* more changes

* add changeset

* slight refactor

* add changeset

* fix

* fixes

* tweak

* clean

* clean

* lint

* upload

* notebook

* more testing

* changes

* notebook

* add changeset

* notebooks

* format

* format

* fix undo

* changes

* changes

* fix

* changes

* fix assert

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
commit 48e4aa9d62
parent 6645518a66
Abubakar Abid 2024-12-07 09:51:00 -06:00 committed by GitHub
13 changed files with 222 additions and 123 deletions


@@ -0,0 +1,6 @@
---
"gradio": minor
"website": minor
---
feat:adds a `run_examples_on_click` parameter to `gr.ChatInterface` mirroring the `run_on_click` parameter in `gr.Examples`
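A minimal usage sketch of the new flag (the `echo` handler below is illustrative, not taken from this diff):

import gradio as gr

def echo(message, history):
    return f"You said: {message}"

# run_examples_on_click=False: clicking an example only populates the input
# with the example message; the chat fn runs when the user submits.
# The default (True) keeps the previous behavior of running fn on click.
demo = gr.ChatInterface(
    fn=echo,
    type="messages",
    examples=["hello", "hola", "merhaba"],
    run_examples_on_click=False,
)

if __name__ == "__main__":
    demo.launch()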


@@ -1 +0,0 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatinterface_multimodal/files/avatar.png\n", "!wget -q -O files/cantina.wav https://github.com/gradio-app/gradio/raw/main/demo/chatinterface_multimodal/files/cantina.wav"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def echo(message, history):\n", " return message[\"text\"]\n", "\n", "demo = gr.ChatInterface(\n", " fn=echo,\n", " type=\"messages\",\n", " examples=[{\"text\": \"hello\"}, {\"text\": \"hola\"}, {\"text\": \"merhaba\"}],\n", " title=\"Echo Bot\",\n", " multimodal=True,\n", ")\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}


@@ -1,13 +0,0 @@
import gradio as gr

def echo(message, history):
    return message["text"]

demo = gr.ChatInterface(
    fn=echo,
    type="messages",
    examples=[{"text": "hello"}, {"text": "hola"}, {"text": "merhaba"}],
    title="Echo Bot",
    multimodal=True,
)

demo.launch()


@@ -0,0 +1,19 @@
from pathlib import Path
import gradio as gr

image = str(Path(__file__).parent / "files" / "avatar.png")
audio = str(Path(__file__).parent / "files" / "cantina.wav")

def echo(message, history):
    return f"You wrote: {message['text']} and uploaded {len(message['files'])} files."

demo = gr.ChatInterface(
    fn=echo,
    type="messages",
    examples=[{"text": "hello"}, {"text": "hola", "files": [image]}, {"text": "merhaba", "files": [image, audio]}],
    title="Echo Bot",
    multimodal=True,
)

if __name__ == "__main__":
    demo.launch()

Binary file: avatar.png (image, 5.2 KiB before and after)

@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_multimodal_examples"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_multimodal_examples/cached_testcase.py\n", "os.mkdir('files')\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_multimodal_examples/files/avatar.png\n", "!wget -q -O files/cantina.wav https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_multimodal_examples/files/cantina.wav"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["from pathlib import Path\n", "import gradio as gr\n", "\n", "image = str(Path(__file__).parent / \"files\" / \"avatar.png\")\n", "audio = str(Path(__file__).parent / \"files\" / \"cantina.wav\")\n", "\n", "def echo(message, history):\n", " return f\"You wrote: {message['text']} and uploaded {len(message['files'])} files.\"\n", "\n", "demo = gr.ChatInterface(\n", " fn=echo,\n", " type=\"messages\",\n", " examples=[{\"text\": \"hello\"}, {\"text\": \"hola\", \"files\": [image]}, {\"text\": \"merhaba\", \"files\": [image, audio]}],\n", " title=\"Echo Bot\",\n", " multimodal=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}


@@ -0,0 +1,19 @@
from pathlib import Path
import gradio as gr

image = str(Path(__file__).parent / "files" / "avatar.png")
audio = str(Path(__file__).parent / "files" / "cantina.wav")

def echo(message, history):
    return f"You wrote: {message['text']} and uploaded {len(message['files'])} files."

demo = gr.ChatInterface(
    fn=echo,
    type="messages",
    examples=[{"text": "hello"}, {"text": "hola", "files": [image]}, {"text": "merhaba", "files": [image, audio]}],
    title="Echo Bot",
    multimodal=True,
)

if __name__ == "__main__":
    demo.launch()


@@ -29,7 +29,6 @@ from gradio.components import (
)
from gradio.components.chatbot import (
ExampleMessage,
FileDataDict,
Message,
MessageDict,
TupleFormat,
@@ -60,7 +59,7 @@ class ChatInterface(Blocks):
demo = gr.ChatInterface(fn=echo, type="messages", examples=["hello", "hola", "merhaba"], title="Echo Bot")
demo.launch()
Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
Demos: chatinterface_random_response, chatinterface_streaming_echo, chatinterface_artifacts
Guides: creating-a-chatbot-fast, sharing-your-app
"""
@@ -78,6 +77,7 @@ class ChatInterface(Blocks):
examples: list[str] | list[MultimodalValue] | list[list] | None = None,
example_labels: list[str] | None = None,
example_icons: list[str] | None = None,
run_examples_on_click: bool = True,
cache_examples: bool | None = None,
cache_mode: Literal["eager", "lazy"] | None = None,
title: str | None = None,
@@ -115,6 +115,7 @@ class ChatInterface(Blocks):
example_icons: icons for the examples, to be displayed above the examples. If provided, should be a list of string URLs or local paths with the same length as the examples list. Only applies when examples are displayed within the chatbot (i.e. when `additional_inputs` is not provided).
cache_examples: if True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
cache_mode: if "eager", all examples are cached at app launch. If "lazy", examples are cached for all users after the first use by any user of the app. If None, will use the GRADIO_CACHE_MODE environment variable if defined, or default to "eager".
run_examples_on_click: if True, clicking on an example will run the example through the chatbot fn, and the response will be displayed in the chatbot. If False, clicking on an example will only populate the chatbot input with the example message. Has no effect if `cache_examples` is True.
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
theme: a Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
@@ -165,6 +166,7 @@ class ChatInterface(Blocks):
self.examples_messages = self._setup_example_messages(
examples, example_labels, example_icons
)
self.run_examples_on_click = run_examples_on_click
self.cache_examples = cache_examples
self.cache_mode = cache_mode
self.additional_inputs = [
@@ -308,11 +310,11 @@ class ChatInterface(Blocks):
if not input_component.is_rendered:
input_component.render()
self.saved_input = State()
self.saved_input = State() # Stores the most recent user message
self.previous_input = State(value=[]) # Stores all user messages
self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([])
)
self.previous_input = State(value=[])
self.show_progress = show_progress
self._setup_events()
self._setup_api()
@@ -356,7 +358,7 @@ class ChatInterface(Blocks):
queue=False,
)
.then(
self._display_input,
self._append_message_to_history,
[self.saved_input, self.chatbot],
[self.chatbot],
show_api=False,
@@ -387,36 +389,38 @@ class ChatInterface(Blocks):
and self.examples
and not self._additional_inputs_in_examples
):
if self.cache_examples:
self.chatbot.example_select(
if self.cache_examples or self.run_examples_on_click:
example_select_event = self.chatbot.example_select(
self.example_clicked,
None,
[self.chatbot, self.saved_input],
show_api=False,
)
if not self.cache_examples:
example_select_event.then(
submit_fn,
[self.saved_input, self.chatbot],
[self.chatbot] + self.additional_outputs,
show_api=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
show_progress=cast(
Literal["full", "minimal", "hidden"], self.show_progress
),
)
else:
self.chatbot.example_select(
self.example_clicked,
self.example_populated,
None,
[self.chatbot, self.saved_input],
[self.textbox],
show_api=False,
).then(
submit_fn,
[self.saved_input, self.chatbot],
[self.chatbot] + self.additional_outputs,
show_api=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
show_progress=cast(
Literal["full", "minimal", "hidden"], self.show_progress
),
)
retry_event = (
self.chatbot.retry(
self._delete_prev_fn,
[self.saved_input, self.chatbot],
self._pop_last_user_message,
[self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
queue=False,
@@ -427,7 +431,7 @@ class ChatInterface(Blocks):
show_api=False,
)
.then(
self._display_input,
self._append_message_to_history,
[self.saved_input, self.chatbot],
[self.chatbot],
show_api=False,
@@ -455,9 +459,15 @@ class ChatInterface(Blocks):
self._setup_stop_events(submit_triggers, [submit_event, retry_event])
self.chatbot.undo(
self._undo_msg,
[self.previous_input, self.chatbot],
[self.chatbot, self.textbox, self.saved_input, self.previous_input],
self._pop_last_user_message,
[self.chatbot],
[self.chatbot, self.saved_input],
show_api=False,
queue=False,
).then(
lambda x: x,
self.saved_input,
self.textbox,
show_api=False,
queue=False,
)
@@ -570,66 +580,41 @@ class ChatInterface(Blocks):
str | MultimodalPostprocess,
list[str | MultimodalPostprocess],
]:
if self.multimodal:
previous_input += [message]
return (
MultimodalTextbox("", interactive=False, placeholder=""),
message,
previous_input,
)
else:
previous_input += [message]
return (
Textbox("", interactive=False, placeholder=""),
message,
previous_input,
)
previous_input += [message]
return (
type(self.textbox)("", interactive=False, placeholder=""),
message,
previous_input,
)
def _append_multimodal_history(
def _append_message_to_history(
self,
message: MultimodalPostprocess,
response: MessageDict | str | None,
message: MultimodalPostprocess | str,
history: list[MessageDict] | TupleFormat,
):
if isinstance(message, str):
message = {"text": message}
if self.type == "tuples":
for x in message.get("files", []):
if isinstance(x, dict):
history.append([(x.get("path"),), None]) # type: ignore
else:
history.append([(x,), None]) # type: ignore
x = x.get("path")
history.append([(x,), None]) # type: ignore
if message["text"] is None or not isinstance(message["text"], str):
return
pass
elif message["text"] == "" and message.get("files", []) != []:
history.append([None, response]) # type: ignore
history.append([None, None]) # type: ignore
else:
history.append([message["text"], cast(str, response)]) # type: ignore
history.append([message["text"], None]) # type: ignore
else:
for x in message.get("files", []):
if isinstance(x, dict):
history.append( # type: ignore
{"role": "user", "content": cast(FileDataDict, x)} # type: ignore
)
else:
history.append({"role": "user", "content": (x,)}) # type: ignore
x = x.get("path")
history.append({"role": "user", "content": (x,)}) # type: ignore
if message["text"] is None or not isinstance(message["text"], str):
return
pass
else:
history.append({"role": "user", "content": message["text"]}) # type: ignore
if response:
history.append(cast(MessageDict, response)) # type: ignore
async def _display_input(
self,
message: str | MultimodalPostprocess,
history: TupleFormat | list[MessageDict],
) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
if self.multimodal and isinstance(message, dict):
self._append_multimodal_history(message, None, history)
elif isinstance(message, str) and self.type == "tuples":
history.append([message, None]) # type: ignore
elif isinstance(message, str) and self.type == "messages":
history.append({"role": "user", "content": message}) # type: ignore
return history # type: ignore
return history
def response_as_dict(self, response: MessageDict | Message | str) -> MessageDict:
if isinstance(response, Message):
@@ -652,7 +637,7 @@ class ChatInterface(Blocks):
history = history_with_input[:-remove_input]
else:
history = history_with_input[:-1]
return message, history
return message, history # type: ignore
def _append_history(self, history, message, first_response=True):
if self.type == "tuples":
@@ -757,6 +742,21 @@ class ChatInterface(Blocks):
history.append({"role": "user", "content": option.value})
return history, option.value
def _flatten_example_files(self, example: SelectData):
"""
Returns an example with the files flattened to just the file path.
Also ensures that the `files` key is always present in the example.
"""
example.value["files"] = [f["path"] for f in example.value.get("files", [])]
return example
def example_populated(self, example: SelectData):
if self.multimodal:
example = self._flatten_example_files(example)
return example.value
else:
return example.value["text"]
def example_clicked(
self, example: SelectData
) -> Generator[
@@ -767,14 +767,9 @@ class ChatInterface(Blocks):
to the example message. Then, if example caching is enabled, the cached response is loaded
and added to the chat history as well.
"""
if self.type == "tuples":
history = [(example.value["text"], None)]
for file in example.value.get("files", []):
history.append(((file["path"]), None))
else:
history = [MessageDict(role="user", content=example.value["text"])]
for file in example.value.get("files", []):
history.append(MessageDict(role="user", content=file))
history = []
self._append_message_to_history(example.value, history)
example = self._flatten_example_files(example)
message = example.value if self.multimodal else example.value["text"]
yield history, message
if self.cache_examples:
@@ -788,16 +783,18 @@ class ChatInterface(Blocks):
if self.multimodal:
message = cast(ExampleMessage, message)
if self.type == "tuples":
if "text" in message:
result.append([message["text"], None])
for file in message.get("files", []):
result.append([file, None])
if "text" in message:
result.append([message["text"], None])
result[-1][1] = response
else:
for file in message.get("files", []):
if isinstance(file, dict):
file = file.get("path")
result.append({"role": "user", "content": (file,)})
if "text" in message:
result.append({"role": "user", "content": message["text"]})
for file in message.get("files", []):
result.append({"role": "assistant", "content": file})
result.append({"role": "assistant", "content": response})
else:
message = cast(str, message)
@@ -839,33 +836,60 @@ class ChatInterface(Blocks):
async for response in generator:
yield self._process_example(message, response)
async def _delete_prev_fn(
async def _pop_last_user_message(
self,
message: str | MultimodalPostprocess | None,
history: list[MessageDict] | TupleFormat,
) -> tuple[list[MessageDict] | TupleFormat, str | MultimodalPostprocess]:
extra = 1 if self.type == "messages" else 0
if self.multimodal and isinstance(message, dict):
remove_input = (
len(message.get("files", [])) + 1
if message["text"] is not None
else len(message.get("files", []))
) + extra
history = history[:-remove_input]
"""
Removes the last user message from the chat history and returns it.
If self.multimodal is True, returns a MultimodalPostprocess (dict) object with text and files.
If self.multimodal is False, returns just the message text as a string.
"""
if not history:
return history, "" if not self.multimodal else {"text": "", "files": []}
if self.type == "messages":
# Skip the last message as it's always an assistant message
i = len(history) - 2
while i >= 0 and history[i]["role"] == "user": # type: ignore
i -= 1
last_messages = history[i + 1 :]
last_user_message = ""
files = []
for msg in last_messages:
assert isinstance(msg, dict) # noqa: S101
if msg["role"] == "user":
content = msg["content"]
if isinstance(content, tuple):
files.append(content[0])
else:
last_user_message = content
return_message = (
{"text": last_user_message, "files": files}
if self.multimodal
else last_user_message
)
return history[: i + 1], return_message # type: ignore
else:
history = history[: -(1 + extra)]
return history, message or "" # type: ignore
async def _undo_msg(
self,
previous_input: list[str | MultimodalPostprocess],
history: list[MessageDict] | TupleFormat,
):
msg = previous_input.pop() if previous_input else None
history, msg = await self._delete_prev_fn(msg, history)
previous_msg = previous_input[-1] if len(previous_input) else msg
return history, msg, previous_msg, previous_input
# Skip the last message pair as it always includes an assistant message
i = len(history) - 2
while i >= 0 and history[i][1] is None: # type: ignore
i -= 1
last_messages = history[i + 1 :]
last_user_message = ""
files = []
for msg in last_messages:
assert isinstance(msg, (tuple, list)) # noqa: S101
if isinstance(msg[0], tuple):
files.append(msg[0][0])
elif msg[0] is not None:
last_user_message = msg[0]
return_message = (
{"text": last_user_message, "files": files}
if self.multimodal
else last_user_message
)
return history[: i + 1], return_message # type: ignore
def render(self) -> ChatInterface:
# If this is being rendered inside another Blocks, and the height is not explicitly set, set it to 400 instead of 200.
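To make the new `_pop_last_user_message` contract concrete, here is a hedged worked example against a hypothetical `type="messages"` history with `multimodal=True`, traced from the code above:

# Hypothetical history before clicking Undo or Retry (messages format):
history = [
    {"role": "user", "content": ("cat.png",)},     # file from the last user turn
    {"role": "user", "content": "what is this?"},  # text from the last user turn
    {"role": "assistant", "content": "A cat."},
]

# _pop_last_user_message skips the trailing assistant message, walks back over
# the consecutive user messages, and removes that whole turn:
#   returned history -> []
#   returned message -> {"text": "what is this?", "files": ["cat.png"]}
#   (with multimodal=False it would return just "what is this?")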


@@ -2,6 +2,7 @@
from __future__ import annotations
import copy
import inspect
import warnings
from collections.abc import Callable, Sequence
@@ -512,6 +513,7 @@ class Chatbot(Component):
def _postprocess_message_messages(
self, message: MessageDict | ChatMessage
) -> Message:
message = copy.deepcopy(message)
if isinstance(message, dict):
message["content"] = self._postprocess_content(message["content"])
msg = Message(**message) # type: ignore


@@ -114,8 +114,8 @@ demos_by_category = [
"dir": "chatbot_with_tools",
},
{
"name": "Multimodal Chatbot",
"dir": "chatinterface_multimodal",
"name": "Chatinterface with Code",
"dir": "chatinterface_artifacts",
},
]
},


@@ -369,7 +369,7 @@ SYSTEM_PROMPT += "Below are examples of full end-to-end Gradio apps:\n\n"
# 'audio_component_events', 'audio_mixer', 'blocks_essay', 'blocks_chained_events', 'blocks_xray', 'chatbot_multimodal', 'sentence_builder', 'custom_css', 'blocks_update', 'fake_gan'
# important_demos = ["annotatedimage_component", "blocks_essay_simple", "blocks_flipper", "blocks_form", "blocks_hello", "blocks_js_load", "blocks_js_methods", "blocks_kinematics", "blocks_layout", "blocks_plug", "blocks_simple_squares", "calculator", "chatbot_consecutive", "chatbot_simple", "chatbot_streaming", "chatinterface_multimodal", "datetimes", "diff_texts", "dropdown_key_up", "fake_diffusion", "fake_gan", "filter_records", "function_values", "gallery_component_events", "generate_tone", "hangman", "hello_blocks", "hello_blocks_decorator", "hello_world", "image_editor", "matrix_transpose", "model3D", "on_listener_decorator", "plot_component", "render_merge", "render_split", "reverse_audio_2", "sales_projections", "sepia_filter", "sort_records", "streaming_simple", "tabbed_interface_lite", "tax_calculator", "theme_soft", "timer", "timer_simple", "variable_outputs", "video_identity"]
important_demos = ['custom_css', "annotatedimage_component", "blocks_essay_simple", "blocks_flipper", "blocks_form", "blocks_hello", "blocks_js_load", "blocks_js_methods", "blocks_kinematics", "blocks_layout", "blocks_plug", "blocks_simple_squares", "calculator", "chatbot_consecutive", "chatbot_simple", "chatbot_streaming", "chatinterface_multimodal", "datetimes", "diff_texts", "dropdown_key_up", "fake_diffusion", "filter_records", "function_values", "gallery_component_events", "generate_tone", "hangman", "hello_blocks", "hello_blocks_decorator", "hello_world", "image_editor", "matrix_transpose", "model3D", "on_listener_decorator", "plot_component", "render_merge", "render_split", "reverse_audio_2", "sales_projections", "sepia_filter", "sort_records", "streaming_simple", "tabbed_interface_lite", "tax_calculator", "theme_soft", "timer", "timer_simple", "variable_outputs", "video_identity"]
important_demos = ['custom_css', "annotatedimage_component", "blocks_essay_simple", "blocks_flipper", "blocks_form", "blocks_hello", "blocks_js_load", "blocks_js_methods", "blocks_kinematics", "blocks_layout", "blocks_plug", "blocks_simple_squares", "calculator", "chatbot_consecutive", "chatbot_simple", "chatbot_streaming", "chatinterface_artifacts", "datetimes", "diff_texts", "dropdown_key_up", "fake_diffusion", "filter_records", "function_values", "gallery_component_events", "generate_tone", "hangman", "hello_blocks", "hello_blocks_decorator", "hello_world", "image_editor", "matrix_transpose", "model3D", "on_listener_decorator", "plot_component", "render_merge", "render_split", "reverse_audio_2", "sales_projections", "sepia_filter", "sort_records", "streaming_simple", "tabbed_interface_lite", "tax_calculator", "theme_soft", "timer", "timer_simple", "variable_outputs", "video_identity"]
def length(demo):


@@ -0,0 +1,42 @@
import { test, expect, go_to_testcase } from "@self/tootils";

const cases = ["not_cached", "cached"];

for (const test_case of cases) {
	test(`case ${test_case}: clicked example is added to history and passed to chat function`, async ({
		page
	}) => {
		if (cases.slice(1).includes(test_case)) {
			await go_to_testcase(page, test_case);
		}

		// Click on an example and the input/response are correct
		await page.getByRole("button", { name: "hello" }).click();
		await expect(page.locator(".user p")).toContainText("hello");
		await expect(page.locator(".bot p")).toContainText(
			"You wrote: hello and uploaded 0 files."
		);
		await expect(page.locator(".user img")).not.toBeVisible();

		// Clear the chat and click on a different example, the input/response are correct
		await page.getByLabel("Clear").click();
		await page.getByRole("button", { name: "hola example-image" }).click();
		await expect(page.locator(".user img")).toBeVisible();
		await expect(page.locator(".user p")).toContainText("hola");
		await expect(page.locator(".bot p")).toContainText(
			"You wrote: hola and uploaded 1 files."
		);

		// Retry button works
		await page.getByLabel("Retry").click();
		await expect(page.locator(".user p")).toContainText("hola");
		await expect(page.locator(".bot p")).toContainText(
			"You wrote: hola and uploaded 1 files."
		);
		await expect(page.locator(".user img")).toBeVisible();

		// Undo message resets to the examples view
		await page.getByLabel("Undo", { exact: true }).click();
		await expect(page.getByRole("button", { name: "hello" })).toBeVisible();
	});
}
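The notebook above also downloads a `cached_testcase.py` for the "cached" branch that `go_to_testcase` loads. That file is not shown in this diff; a minimal sketch of what it presumably contains, mirroring `run.py` with example caching enabled:

from pathlib import Path
import gradio as gr

image = str(Path(__file__).parent / "files" / "avatar.png")
audio = str(Path(__file__).parent / "files" / "cantina.wav")

def echo(message, history):
    return f"You wrote: {message['text']} and uploaded {len(message['files'])} files."

# cache_examples=True: clicking an example loads the cached response
# instead of running fn live (run_examples_on_click has no effect here).
demo = gr.ChatInterface(
    fn=echo,
    type="messages",
    examples=[{"text": "hello"}, {"text": "hola", "files": [image]}, {"text": "merhaba", "files": [image, audio]}],
    title="Echo Bot",
    multimodal=True,
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch()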