Support Message API for chatbot and chatinterface (#8422)

* first commit

* Add code

* Tests + code

* lint

* Add code

* notebook

* add changeset

* type

* Add client test

* type

* Add code

* Chatbot type

* Add code

* test chatbot

* fix e2e test

* js tests

* Consolidate Error and Tool message. Allow Messages in postprocess

* Rename to messages

* fix tests

* notebook clean

* More tests and messages

* add changeset

* notebook

* client test

* Fix issues

* Chatbot docs

* add changeset

* Add image

* Add img tag

* Address comments

* Add code

* Revert chatinterface streaming change. Use title in metadata. Address pngwn comments

* Add code

* changelog highlight

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-07-10 13:08:06 +02:00 committed by GitHub
parent 936c7137a9
commit 4221290d84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 1856 additions and 656 deletions

View File

@ -0,0 +1,93 @@
---
"@gradio/chatbot": minor
"@gradio/tootils": minor
"gradio": minor
"website": minor
---
highlight:
#### Support message format in chatbot 💬
`gr.Chatbot` and `gr.ChatInterface` now support the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api#messages-api), which is fully compatible with LLM API providers such as Hugging Face Text Generation Inference, OpenAI's chat completions API, and Llama.cpp server.
Building Gradio applications around these LLM solutions is now even easier!
`gr.Chatbot` and `gr.ChatInterface` now have a `msg_format` parameter that accepts two values: `'tuples'` and `'messages'`. If set to `'tuples'`, the default chatbot data format is expected. If set to `'messages'`, a list of dictionaries with `content` and `role` keys is expected. See the example below:
```python
def chat_greeter(msg, history):
history.append({"role": "assistant", "content": "Hello!"})
return history
```
Additionally, gradio now exposes a `gr.ChatMessage` dataclass you can use for IDE type hints and auto completion.
<img width="852" alt="image" src="https://github.com/freddyaboulton/freddyboulton/assets/41651716/d283e8f3-b194-466a-8194-c7e697dca9ad">
#### Tool use in Chatbot 🛠️
The Gradio Chatbot can now natively display tool usage and intermediate thoughts common in Agent and chain-of-thought workflows!
If you are using the new "messages" format, simply add a `metadata` key whose value is a dictionary with a `title` key. This will display the assistant message in an expandable message box to show the result of a tool or intermediate step.
```python
import gradio as gr
from gradio import ChatMessage
import time
def generate_response(history):
history.append(ChatMessage(role="user", content="What is the weather in San Francisco right now?"))
yield history
time.sleep(0.25)
history.append(ChatMessage(role="assistant",
content="In order to find the current weather in San Francisco, I will need to use my weather tool.")
)
yield history
time.sleep(0.25)
history.append(ChatMessage(role="assistant",
content="API Error when connecting to weather service.",
metadata={"title": "💥 Error using tool 'Weather'"})
)
yield history
time.sleep(0.25)
history.append(ChatMessage(role="assistant",
content="I will try again",
))
yield history
time.sleep(0.25)
history.append(ChatMessage(role="assistant",
content="Weather 72 degrees Fahrenheit with 20% chance of rain.",
metadata={"title": "🛠️ Used tool 'Weather'"}
))
yield history
time.sleep(0.25)
history.append(ChatMessage(role="assistant",
content="Now that the API succeeded I can complete my task.",
))
yield history
time.sleep(0.25)
history.append(ChatMessage(role="assistant",
content="It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!",
))
yield history
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
button = gr.Button("Get San Francisco Weather")
button.click(generate_response, chatbot, chatbot)
if __name__ == "__main__":
demo.launch()
```
![tool-box-demo](https://github.com/freddyaboulton/freddyboulton/assets/41651716/cf73ecc9-90ac-42ce-bca5-768e0cc00a48)

View File

@ -19,7 +19,10 @@ const test_files = readdirSync(TEST_FILES_PATH)
!f.endsWith(".component.spec.ts") &&
!f.endsWith(".reload.spec.ts")
)
.map((f) => basename(f, ".spec.ts"));
.map((f) => ({
module_name: `${basename(f, ".spec.ts")}.run`,
dir_name: basename(f, ".spec.ts")
}));
export default async function global_setup() {
const verbose = process.env.GRADIO_TEST_VERBOSE;
@ -29,7 +32,24 @@ export default async function global_setup() {
process.stdout.write(kl.yellow("\nCreating test gradio app.\n\n"));
const test_app = make_app(test_files, port);
const test_cases = [];
// check if there is a testcase file in the same directory as the test file
// if there is, append that to the file
test_files.forEach((value) => {
const test_case_dir = join(ROOT, "demo", value.dir_name);
readdirSync(test_case_dir)
.filter((f) => f.endsWith("_testcase.py"))
.forEach((f) => {
test_cases.push({
module_name: `${value.dir_name}.${basename(f, ".py")}`,
dir_name: `${value.dir_name}_${basename(f, ".py")}`
});
});
});
const all_test_files = test_files.concat(test_cases);
const test_app = make_app(all_test_files, port);
process.stdout.write(kl.yellow("App created. Starting test server.\n\n"));
process.stdout.write(kl.bgBlue(" =========================== \n"));
@ -111,14 +131,14 @@ import uvicorn
from fastapi import FastAPI
import gradio as gr
${demos.map((d) => `from demo.${d}.run import demo as ${d}`).join("\n")}
${demos.map((obj) => `from demo.${obj.module_name} import demo as ${obj.dir_name}`).join("\n")}
app = FastAPI()
${demos
.map(
(d) =>
`app = gr.mount_gradio_app(app, ${d}, path="/${d}", max_file_size=${
d == "upload_file_limit_test" ? "'15kb'" : "None"
(obj) =>
`app = gr.mount_gradio_app(app, ${obj.dir_name}, path="/${obj.dir_name}", max_file_size=${
obj.dir_name == "upload_file_limit_test" ? "'15kb'" : "None"
})`
)
.join("\n")}

View File

@ -459,3 +459,26 @@ def max_file_size_demo():
)
return demo
@pytest.fixture
def chatbot_message_format():
# Pytest fixture: a minimal Blocks app whose Chatbot uses the new
# "messages" (role/content dict) format, exposed via api_name="chat".
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
msg = gr.Textbox()
def respond(message, chat_history: list):
# Pick a canned bot reply; tests assert membership in this same list.
bot_message = random.choice(
["How are you?", "I love you", "I'm very hungry"]
)
# Append one user message and one assistant message per submit.
chat_history.extend(
[
{"role": "user", "content": message},
{"role": "assistant", "content": bot_message},
]
)
# First output clears the textbox; second updates the chatbot history.
return "", chat_history
msg.submit(respond, [msg, chatbot], [msg, chatbot], api_name="chat")
return demo

View File

@ -682,6 +682,25 @@ class TestClientPredictionsWithKwargs:
):
client.predict(num1=3, operation="add", api_name="/predict")
def test_chatbot_message_format(self, chatbot_message_format):
# Verify that a msg_format="messages" chatbot round-trips history as
# a list of {"role", "content"} dicts through the client API.
with connect(chatbot_message_format) as client:
# First exchange: index 0 is the user turn, index 1 the assistant turn.
_, history = client.predict("hello", [], api_name="/chat")
assert history[1]["role"] == "assistant"
assert history[1]["content"] in [
"How are you?",
"I love you",
"I'm very hungry",
]
# Second exchange: prior history is preserved and extended in order.
_, history = client.predict("hi", history, api_name="/chat")
assert history[2]["role"] == "user"
assert history[2]["content"] == "hi"
assert history[3]["role"] == "assistant"
assert history[3]["content"] in [
"How are you?",
"I love you",
"I'm very hungry",
]
class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")

View File

@ -0,0 +1,101 @@
import gradio as gr
import random
# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
# CSS color used for the badge of each harm classification.
color_map = {
    "harmful": "crimson",
    "neutral": "gray",
    "beneficial": "green",
}


def html_src(harm_level):
    """Return an HTML snippet rendering *harm_level* as a colored badge."""
    badge_color = color_map[harm_level]
    return (
        "\n"
        '<div style="display: flex; gap: 5px;padding: 2px 4px;margin-top: -40px">\n'
        f'<div style="background-color: {badge_color}; padding: 2px; border-radius: 5px;">\n'
        f"{harm_level}\n"
        "</div>\n"
        "</div>\n"
    )
def print_like_dislike(x: gr.LikeData):
    """Log the index, value, and liked-state of a like/dislike event."""
    fields = (x.index, x.value, x.liked)
    print(*fields)
def add_message(history, message):
    """Append uploaded files and text from *message* to *history* as user turns,
    then clear and disable the multimodal input box."""
    # Each uploaded file becomes its own user message carrying a file path.
    history.extend(
        {"role": "user", "content": {"path": file_path}}
        for file_path in message["files"]
    )
    if message["text"] is not None:
        history.append({"role": "user", "content": message["text"]})
    return history, gr.MultimodalTextbox(value=None, interactive=False)
def bot(history, response_type):
    """Append an assistant reply of the requested media type to *history*.

    Unknown *response_type* values (including "text") fall back to a plain
    string reply.
    """
    # Lazy builders so only the selected component is actually constructed.
    builders = {
        "gallery": lambda: gr.Gallery(
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
             "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
        ),
        "image": lambda: gr.Image(
            "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
        ),
        "video": lambda: gr.Video(
            "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
        ),
        "audio": lambda: gr.Audio(
            "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
        ),
        "html": lambda: gr.HTML(
            html_src(random.choice(["harmful", "neutral", "beneficial"]))
        ),
    }
    make_content = builders.get(response_type, lambda: "Cool!")
    history.append({"role": "assistant", "content": make_content()})
    return history
# Demo app: multimodal chatbot using the new "messages" format with
# selectable assistant reply types (text/image/gallery/video/audio/html).
with gr.Blocks(fill_height=True) as demo:
chatbot = gr.Chatbot(
elem_id="chatbot",
bubble_full_width=False,
scale=1,
msg_format="messages"
)
# Radio controls which media type bot() appends as the assistant reply.
response_type = gr.Radio(
[
"image",
"text",
"gallery",
"video",
"audio",
"html",
],
value="text",
label="Response Type",
)
chat_input = gr.MultimodalTextbox(
interactive=True,
placeholder="Enter message or upload file...",
show_label=False,
)
# Submit chain: record the user turn, then generate the bot reply,
# then re-enable the (cleared) input box.
chat_msg = chat_input.submit(
add_message, [chatbot, chat_input], [chatbot, chat_input]
)
bot_msg = chat_msg.then(
bot, [chatbot, response_type], chatbot, api_name="bot_response"
)
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
chatbot.like(print_like_dislike, None, None)
demo.queue()
if __name__ == "__main__":
demo.launch()

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_core_components_simple"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/audio.wav https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/audio.wav\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/avatar.png\n", "!wget -q -O files/sample.txt https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/sample.txt\n", "!wget -q -O files/world.mp4 https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/world.mp4"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import random\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). 
Plus shows support for streaming text.\n", "\n", "\n", "color_map = {\n", " \"harmful\": \"crimson\",\n", " \"neutral\": \"gray\",\n", " \"beneficial\": \"green\",\n", "}\n", "\n", "\n", "def html_src(harm_level):\n", " return f\"\"\"\n", "<div style=\"display: flex; gap: 5px;\">\n", " <div style=\"background-color: {color_map[harm_level]}; padding: 2px; border-radius: 5px;\">\n", " {harm_level}\n", " </div>\n", "</div>\n", "\"\"\"\n", "\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "\n", "def bot(history, response_type):\n", " if response_type == \"gallery\":\n", " history[-1][1] = gr.Gallery(\n", " [\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " ]\n", " )\n", " elif response_type == \"image\":\n", " history[-1][1] = gr.Image(\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n", " )\n", " elif response_type == \"video\":\n", " history[-1][1] = gr.Video(\n", " \"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4\"\n", " )\n", " elif response_type == \"audio\":\n", " history[-1][1] = gr.Audio(\n", " \"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav\"\n", " )\n", " elif response_type == \"html\":\n", " history[-1][1] = gr.HTML(\n", " html_src(random.choice([\"harmful\", \"neutral\", \"beneficial\"]))\n", " )\n", " else:\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " 
bubble_full_width=False,\n", " scale=1,\n", " )\n", " response_type = gr.Radio(\n", " [\n", " \"image\",\n", " \"text\",\n", " \"gallery\",\n", " \"video\",\n", " \"audio\",\n", " \"html\",\n", " ],\n", " value=\"text\",\n", " label=\"Response Type\",\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(\n", " interactive=True,\n", " placeholder=\"Enter message or upload file...\",\n", " show_label=False,\n", " )\n", "\n", " chat_msg = chat_input.submit(\n", " add_message, [chatbot, chat_input], [chatbot, chat_input]\n", " )\n", " bot_msg = chat_msg.then(\n", " bot, [chatbot, response_type], chatbot, api_name=\"bot_response\"\n", " )\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_core_components_simple"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/audio.wav https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/audio.wav\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/avatar.png\n", "!wget -q -O files/sample.txt https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/sample.txt\n", "!wget -q -O files/world.mp4 https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/world.mp4\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import random\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). 
Plus shows support for streaming text.\n", "\n", "\n", "color_map = {\n", " \"harmful\": \"crimson\",\n", " \"neutral\": \"gray\",\n", " \"beneficial\": \"green\",\n", "}\n", "\n", "\n", "def html_src(harm_level):\n", " return f\"\"\"\n", "<div style=\"display: flex; gap: 5px;\">\n", " <div style=\"background-color: {color_map[harm_level]}; padding: 2px; border-radius: 5px;\">\n", " {harm_level}\n", " </div>\n", "</div>\n", "\"\"\"\n", "\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "\n", "def bot(history, response_type):\n", " if response_type == \"gallery\":\n", " history[-1][1] = gr.Gallery(\n", " [\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " ]\n", " )\n", " elif response_type == \"image\":\n", " history[-1][1] = gr.Image(\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n", " )\n", " elif response_type == \"video\":\n", " history[-1][1] = gr.Video(\n", " \"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4\"\n", " )\n", " elif response_type == \"audio\":\n", " history[-1][1] = gr.Audio(\n", " \"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav\"\n", " )\n", " elif response_type == \"html\":\n", " history[-1][1] = gr.HTML(\n", " html_src(random.choice([\"harmful\", \"neutral\", \"beneficial\"]))\n", " )\n", " else:\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " 
bubble_full_width=False,\n", " scale=1,\n", " )\n", " response_type = gr.Radio(\n", " [\n", " \"image\",\n", " \"text\",\n", " \"gallery\",\n", " \"video\",\n", " \"audio\",\n", " \"html\",\n", " ],\n", " value=\"text\",\n", " label=\"Response Type\",\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(\n", " interactive=True,\n", " placeholder=\"Enter message or upload file...\",\n", " show_label=False,\n", " )\n", "\n", " chat_msg = chat_input.submit(\n", " add_message, [chatbot, chat_input], [chatbot, chat_input]\n", " )\n", " bot_msg = chat_msg.then(\n", " bot, [chatbot, response_type], chatbot, api_name=\"bot_response\"\n", " )\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -0,0 +1,45 @@
import gradio as gr
import time
# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
def print_like_dislike(x: gr.LikeData):
    """Print the position, content, and liked-state of a like/dislike event."""
    event_index, event_value, event_liked = x.index, x.value, x.liked
    print(event_index, event_value, event_liked)
def add_message(history, message):
    """Record submitted files and text as user messages, then reset and
    lock the multimodal textbox while the bot responds."""
    for uploaded_path in message["files"]:
        history.append({"role": "user", "content": {"path": uploaded_path}})
    text = message["text"]
    if text is not None:
        history.append({"role": "user", "content": text})
    return history, gr.MultimodalTextbox(value=None, interactive=False)
def bot(history: list):
    """Stream a fixed assistant reply into *history* one character at a time,
    yielding the updated history after each character."""
    reply = "**That's cool!**"
    entry = {"role": "assistant", "content": ""}
    history.append(entry)
    for ch in reply:
        entry["content"] += ch
        time.sleep(0.05)
        yield history
# Testcase app: multimodal chatbot using the "messages" format with a
# character-streamed assistant reply.
with gr.Blocks() as demo:
chatbot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
msg_format="messages"
)
chat_input = gr.MultimodalTextbox(interactive=True,
file_count="multiple",
placeholder="Enter message or upload file...", show_label=False)
# Record the user turn, stream the bot reply, then re-enable the input.
chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
chatbot.like(print_like_dislike, None, None)
if __name__ == "__main__":
demo.launch()

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio plotly"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/files/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import plotly.express as px\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "def random_plot():\n", " df = px.data.iris()\n", " fig = px.scatter(df, x=\"sepal_width\", y=\"sepal_length\", color=\"species\",\n", " size='petal_length', hover_data=['petal_width'])\n", " return fig\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "def bot(history):\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "fig = random_plot()\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " scale=1,\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(interactive=True,\n", " file_count=\"multiple\",\n", " placeholder=\"Enter 
message or upload file...\", show_label=False)\n", "\n", " chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])\n", " bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name=\"bot_response\")\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio plotly"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/files/avatar.png\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import plotly.express as px\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). 
Plus shows support for streaming text.\n", "\n", "def random_plot():\n", " df = px.data.iris()\n", " fig = px.scatter(df, x=\"sepal_width\", y=\"sepal_length\", color=\"species\",\n", " size='petal_length', hover_data=['petal_width'])\n", " return fig\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "def bot(history):\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "fig = random_plot()\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " scale=1,\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(interactive=True,\n", " file_count=\"multiple\",\n", " placeholder=\"Enter message or upload file...\", show_label=False)\n", "\n", " chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])\n", " bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name=\"bot_response\")\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " def user(user_message, history):\n", " return \"\", history + [[user_message, None]]\n", "\n", " def bot(history):\n", " bot_message = random.choice([\"How are you?\", \"I love you\", \"I'm very hungry\"])\n", " history[-1][1] = \"\"\n", " for character in bot_message:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", " \n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_streaming/testcase_messages.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " def user(user_message, history):\n", " return \"\", history + [[user_message, None]]\n", "\n", " def bot(history):\n", " bot_message = random.choice([\"How are you?\", \"I love you\", \"I'm very hungry\"])\n", " history[-1][1] = \"\"\n", " for character in bot_message:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -22,7 +22,6 @@ with gr.Blocks() as demo:
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
demo.queue()
if __name__ == "__main__":
demo.launch()

View File

@ -0,0 +1,28 @@
import gradio as gr
import random
import time
# Testcase app: streaming chatbot using the "messages" (role/content dict)
# format instead of the legacy tuples format.
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
msg = gr.Textbox()
clear = gr.Button("Clear")
def user(user_message, history: list):
# Clear the textbox and append the user's turn as a message dict.
return "", history + [{"role": "user", "content": user_message}]
def bot(history: list):
# Stream a random canned reply into the last assistant message,
# one character at a time.
bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
history.append({"role": "assistant", "content": ""})
for character in bot_message:
history[-1]['content'] += character
time.sleep(0.05)
yield history
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
demo.launch()

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_with_tools"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "import time\n", "\n", "def generate_response(history):\n", " history.append(ChatMessage(role=\"user\", content=\"What is the weather in San Francisco right now?\"))\n", " yield history\n", " time.sleep(0.25)\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"In order to find the current weather in San Francisco, I will need to use my weather tool.\")\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"API Error when connecting to weather service.\",\n", " metadata={\"title\": \"\ud83d\udca5 Error using tool 'Weather'\"})\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"I will try again\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Weather 72 degrees Fahrenheit with 20% chance of rain.\",\n", " metadata={\"title\": \"\ud83d\udee0\ufe0f Used tool 'Weather'\"}\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Now that the API succeeded I can complete my task.\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. 
Enjoy the weather!\",\n", " ))\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(msg_format=\"messages\")\n", " button = gr.Button(\"Get San Francisco Weather\")\n", " button.click(generate_response, chatbot, chatbot)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -0,0 +1,53 @@
import gradio as gr
from gradio import ChatMessage
import time

# Scripted conversation replayed by the demo: (role, content, metadata title).
# A title of None means the message carries no tool/error metadata.
_SCRIPT = [
    ("user", "What is the weather in San Francisco right now?", None),
    ("assistant", "In order to find the current weather in San Francisco, I will need to use my weather tool.", None),
    ("assistant", "API Error when connecting to weather service.", "💥 Error using tool 'Weather'"),
    ("assistant", "I will try again", None),
    ("assistant", "Weather 72 degrees Fahrenheit with 20% chance of rain.", "🛠️ Used tool 'Weather'"),
    ("assistant", "Now that the API succeeded I can complete my task.", None),
    ("assistant", "It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!", None),
]


def generate_response(history):
    """Append the scripted tool-use conversation to *history*, yielding the
    updated history after every message and pausing briefly between them."""
    last = len(_SCRIPT) - 1
    for index, (role, content, title) in enumerate(_SCRIPT):
        extra = {"metadata": {"title": title}} if title else {}
        history.append(ChatMessage(role=role, content=content, **extra))
        yield history
        if index < last:  # original demo does not sleep after the final yield
            time.sleep(0.25)


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(msg_format="messages")
    button = gr.Button("Get San Francisco Weather")
    button.click(generate_response, chatbot, chatbot)

if __name__ == "__main__":
    demo.launch()

View File

@ -0,0 +1,15 @@
import time
import gradio as gr


def slow_echo(message, history):
    """Stream the user's message back one character at a time, prefixed with
    "You typed: ", sleeping 50 ms before each partial response."""
    buffer = ""
    for ch in message:
        time.sleep(0.05)
        buffer += ch
        yield "You typed: " + buffer


demo = gr.ChatInterface(slow_echo, msg_format="messages")

if __name__ == "__main__":
    demo.launch()

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "\n", "def slow_echo(message, history):\n", " for i in range(len(message)):\n", " time.sleep(0.05)\n", " yield \"You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo).queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatinterface_streaming_echo/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "\n", "def slow_echo(message, history):\n", " for i in range(len(message)):\n", " time.sleep(0.05)\n", " yield \"You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -8,7 +8,7 @@ def slow_echo(message, history):
yield "You typed: " + message[: i + 1]
demo = gr.ChatInterface(slow_echo).queue()
demo = gr.ChatInterface(slow_echo)
if __name__ == "__main__":
demo.launch()

View File

@ -0,0 +1,17 @@
import time
import gradio as gr

# Module-level run counter; the e2e test checks the run number in the reply.
runs = 0


def slow_echo(message, history):
    """Stream the message back with a per-call run counter in the prefix."""
    global runs  # module-level counter instead of gr.State, to keep the demo minimal
    runs += 1
    prefix = f"Run {runs} - You typed: "
    for end in range(1, len(message) + 1):
        time.sleep(0.05)
        yield prefix + message[:end]


demo = gr.ChatInterface(slow_echo, msg_format="messages").queue()

if __name__ == "__main__":
    demo.launch()

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo).queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo).queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -16,6 +16,7 @@ from gradio.components import (
BarPlot,
Button,
Chatbot,
ChatMessage,
Checkbox,
CheckboxGroup,
Checkboxgroup,
@ -42,6 +43,7 @@ from gradio.components import (
LoginButton,
LogoutButton,
Markdown,
MessageDict,
Model3D,
MultimodalTextbox,
Number,

View File

@ -6,6 +6,7 @@ from __future__ import annotations
import functools
import inspect
import warnings
from typing import AsyncGenerator, Callable, Literal, Union, cast
import anyio
@ -22,6 +23,8 @@ from gradio.components import (
Textbox,
get_component_instance,
)
from gradio.components.chatbot import FileDataDict, Message, MessageDict, TupleFormat
from gradio.components.multimodal_textbox import MultimodalData
from gradio.events import Dependency, on
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.helpers import special_args
@ -56,6 +59,7 @@ class ChatInterface(Blocks):
fn: Callable,
*,
multimodal: bool = False,
msg_format: Literal["messages", "tuples"] = "tuples",
chatbot: Chatbot | None = None,
textbox: Textbox | MultimodalTextbox | None = None,
additional_inputs: str | Component | list[str | Component] | None = None,
@ -123,6 +127,7 @@ class ChatInterface(Blocks):
fill_height=fill_height,
delete_cache=delete_cache,
)
self.msg_format: Literal["messages", "tuples"] = msg_format
self.multimodal = multimodal
self.concurrency_limit = concurrency_limit
self.fn = fn
@ -182,10 +187,19 @@ class ChatInterface(Blocks):
Markdown(description)
if chatbot:
if self.msg_format != chatbot.msg_format:
warnings.warn(
"The msg_format of the chatbot does not match the msg_format of the chat interface. The msg_format of the chat interface will be used."
f"Received msg_format of chatbot: {chatbot.msg_format}, msg_format of chat interface: {self.msg_format}"
)
chatbot.msg_format = self.msg_format
self.chatbot = get_component_instance(chatbot, render=True)
else:
self.chatbot = Chatbot(
label="Chatbot", scale=1, height=200 if fill_height else None
label="Chatbot",
scale=1,
height=200 if fill_height else None,
msg_format=self.msg_format,
)
with Row():
@ -329,6 +343,7 @@ class ChatInterface(Blocks):
[self.textbox, self.saved_input],
show_api=False,
queue=False,
preprocess=False,
)
.then(
self._display_input,
@ -383,6 +398,12 @@ class ChatInterface(Blocks):
)
self._setup_stop_events([self.retry_btn.click], retry_event)
async def format_textbox(data: str | MultimodalData) -> str | dict:
if isinstance(data, MultimodalData):
return {"text": data.text, "files": [x.path for x in data.files]}
else:
return data
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
@ -391,7 +412,7 @@ class ChatInterface(Blocks):
show_api=False,
queue=False,
).then(
async_lambda(lambda x: x),
format_textbox,
[self.saved_input],
[self.textbox],
show_api=False,
@ -498,54 +519,82 @@ class ChatInterface(Blocks):
),
)
def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
def _clear_and_save_textbox(
self, message: str | dict
) -> tuple[str | dict, str | MultimodalData]:
if self.multimodal:
return {"text": "", "files": []}, message
return {"text": "", "files": []}, MultimodalData(**cast(dict, message))
else:
return "", message
return "", cast(str, message)
def _append_multimodal_history(
self,
message: dict[str, list],
response: str | None,
history: list[list[str | tuple | None]],
message: MultimodalData,
response: MessageDict | str | None,
history: list[MessageDict] | TupleFormat,
):
for x in message["files"]:
history.append([(x,), None])
if message["text"] is None or not isinstance(message["text"], str):
return
elif message["text"] == "" and message["files"] != []:
history.append([None, response])
if self.msg_format == "tuples":
for x in message.files:
history.append([(x.path,), None]) # type: ignore
if message.text is None or not isinstance(message.text, str):
return
elif message.text == "" and message.files != []:
history.append([None, response]) # type: ignore
else:
history.append([message.text, cast(str, response)]) # type: ignore
else:
history.append([message["text"], response])
for x in message.files:
history.append(
{"role": "user", "content": cast(FileDataDict, x.model_dump())} # type: ignore
)
if message.text is None or not isinstance(message.text, str):
return
else:
history.append({"role": "user", "content": message.text}) # type: ignore
if response:
history.append(cast(MessageDict, response)) # type: ignore
async def _display_input(
self, message: str | dict[str, list], history: list[list[str | tuple | None]]
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
if self.multimodal and isinstance(message, dict):
self, message: str | MultimodalData, history: TupleFormat | list[MessageDict]
) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, None, history)
elif isinstance(message, str):
history.append([message, None])
return history, history
elif isinstance(message, str) and self.msg_format == "tuples":
history.append([message, None]) # type: ignore
elif isinstance(message, str) and self.msg_format == "messages":
history.append({"role": "user", "content": message}) # type: ignore
return history, history # type: ignore
def response_as_dict(self, response: MessageDict | Message | str) -> MessageDict:
if isinstance(response, Message):
new_response = response.model_dump()
elif isinstance(response, str):
return {"role": "assistant", "content": response}
else:
new_response = response
return cast(MessageDict, new_response)
async def _submit_fn(
self,
message: str | dict[str, list],
history_with_input: list[list[str | tuple | None]],
message: str | MultimodalData,
history_with_input: TupleFormat | list[MessageDict],
request: Request,
*args,
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
if self.multimodal and isinstance(message, dict):
) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
if self.multimodal and isinstance(message, MultimodalData):
remove_input = (
len(message["files"]) + 1
if message["text"] is not None
else len(message["files"])
len(message.files) + 1
if message.text is not None
else len(message.files)
)
history = history_with_input[:-remove_input]
message_serialized = message.model_dump()
else:
history = history_with_input[:-1]
message_serialized = message
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
self.fn, inputs=[message_serialized, history, *args], request=request
)
if self.is_async:
@ -555,24 +604,31 @@ class ChatInterface(Blocks):
self.fn, *inputs, limiter=self.limiter
)
if self.multimodal and isinstance(message, dict):
self._append_multimodal_history(message, response, history)
elif isinstance(message, str):
history.append([message, response])
return history, history
if self.msg_format == "messages":
new_response = self.response_as_dict(response)
else:
new_response = response
if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, new_response, history) # type: ignore
elif isinstance(message, str) and self.msg_format == "tuples":
history.append([message, new_response]) # type: ignore
elif isinstance(message, str) and self.msg_format == "messages":
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
return history, history # type: ignore
async def _stream_fn(
self,
message: str | dict[str, list],
history_with_input: list[list[str | tuple | None]],
message: str | MultimodalData,
history_with_input: TupleFormat | list[MessageDict],
request: Request,
*args,
) -> AsyncGenerator:
if self.multimodal and isinstance(message, dict):
if self.multimodal and isinstance(message, MultimodalData):
remove_input = (
len(message["files"]) + 1
if message["text"] is not None
else len(message["files"])
len(message.files) + 1
if message.text is not None
else len(message.files)
)
history = history_with_input[:-remove_input]
else:
@ -590,30 +646,136 @@ class ChatInterface(Blocks):
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
if self.multimodal and isinstance(message, dict):
for x in message["files"]:
history.append([(x,), None])
update = history + [[message["text"], first_response]]
if self.msg_format == "messages":
first_response = self.response_as_dict(first_response)
if (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "tuples"
):
for x in message.files:
history.append([(x,), None]) # type: ignore
update = history + [[message.text, first_response]]
yield update, update
else:
elif (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "messages"
):
for x in message.files:
history.append(
{"role": "user", "content": cast(FileDataDict, x.model_dump())} # type: ignore
)
update = history + [
{"role": "user", "content": message.text},
first_response,
]
yield update, update
elif self.msg_format == "tuples":
update = history + [[message, first_response]]
yield update, update
else:
update = history + [
{"role": "user", "content": message},
first_response,
]
yield update, update
except StopIteration:
if self.multimodal and isinstance(message, dict):
if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, None, history)
yield history, history
else:
update = history + [[message, None]]
yield update, update
async for response in generator:
if self.multimodal and isinstance(message, dict):
update = history + [[message["text"], response]]
if self.msg_format == "messages":
response = self.response_as_dict(response)
if (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "tuples"
):
update = history + [[message.text, response]]
yield update, update
else:
elif (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "messages"
):
update = history + [
{"role": "user", "content": message.text},
response,
]
yield update, update
elif self.msg_format == "tuples":
update = history + [[message, response]]
yield update, update
else:
update = history + [{"role": "user", "content": message}, response]
yield update, update
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
async def _api_submit_fn(
self,
message: str,
history: TupleFormat | list[MessageDict],
request: Request,
*args,
) -> tuple[str, TupleFormat | list[MessageDict]]:
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
if self.msg_format == "tuples":
history.append([message, response]) # type: ignore
else:
new_response = self.response_as_dict(response)
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
return response, history
async def _api_stream_fn(
self, message: str, history: list[list[str | None]], request: Request, *args
) -> AsyncGenerator:
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *args], request=request
)
if self.is_async:
generator = self.fn(*inputs)
else:
generator = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
if self.msg_format == "tuples":
yield first_response, history + [[message, first_response]]
else:
first_response = self.response_as_dict(first_response)
yield (
first_response,
history + [{"role": "user", "content": message}, first_response],
)
except StopIteration:
yield None, history + [[message, None]]
async for response in generator:
if self.msg_format == "tuples":
yield response, history + [[message, response]]
else:
new_response = self.response_as_dict(response)
yield (
new_response,
history + [{"role": "user", "content": message}, new_response],
)
async def _examples_fn(
self, message: str, *args
) -> TupleFormat | list[MessageDict]:
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
if self.is_async:
@ -622,7 +784,10 @@ class ChatInterface(Blocks):
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
return [[message, response]]
if self.msg_format == "tuples":
return [[message, response]]
else:
return [{"role": "user", "content": message}, response]
async def _examples_stream_fn(
self,
@ -639,24 +804,29 @@ class ChatInterface(Blocks):
)
generator = SyncToAsyncIterator(generator, self.limiter)
async for response in generator:
yield [[message, response]]
if self.msg_format == "tuples":
yield [[message, response]]
else:
new_response = self.response_as_dict(response)
yield [{"role": "user", "content": message}, new_response]
async def _delete_prev_fn(
self,
message: str | dict[str, list],
history: list[list[str | tuple | None]],
message: str | MultimodalData | None,
history: list[MessageDict] | TupleFormat,
) -> tuple[
list[list[str | tuple | None]],
str | dict[str, list],
list[list[str | tuple | None]],
list[MessageDict] | TupleFormat,
str | MultimodalData,
list[MessageDict] | TupleFormat,
]:
if self.multimodal and isinstance(message, dict):
extra = 1 if self.msg_format == "messages" else 0
if self.multimodal and isinstance(message, MultimodalData):
remove_input = (
len(message["files"]) + 1
if message["text"] is not None
else len(message["files"])
)
len(message.files) + 1
if message.text is not None
else len(message.files)
) + extra
history = history[:-remove_input]
else:
history = history[:-1]
history = history[: -(1 + extra)]
return history, message or "", history

View File

@ -11,7 +11,7 @@ from gradio.components.base import (
get_component_instance,
)
from gradio.components.button import Button
from gradio.components.chatbot import Chatbot
from gradio.components.chatbot import Chatbot, ChatMessage, MessageDict
from gradio.components.checkbox import Checkbox
from gradio.components.checkboxgroup import CheckboxGroup
from gradio.components.clear_button import ClearButton
@ -64,6 +64,7 @@ __all__ = [
"BarPlot",
"Button",
"Chatbot",
"ChatMessage",
"ClearButton",
"Component",
"component",
@ -92,6 +93,7 @@ __all__ = [
"LoginButton",
"LogoutButton",
"Markdown",
"MessageDict",
"Textbox",
"Dropdown",
"Model3D",

View File

@ -12,7 +12,7 @@ import warnings
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Type
import gradio_client.utils as client_utils
@ -166,7 +166,7 @@ class Component(ComponentBase, Block):
# This gets overridden when `select` is called
self._selectable = False
if not hasattr(self, "data_model"):
self.data_model: type[GradioDataModel] | None = None
self.data_model: Type[GradioDataModel] | None = None
Block.__init__(
self,

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import inspect
from dataclasses import dataclass, field
from pathlib import Path
from typing import (
TYPE_CHECKING,
@ -13,11 +14,16 @@ from typing import (
Literal,
Optional,
Tuple,
Type,
TypedDict,
Union,
cast,
)
from gradio_client import utils as client_utils
from gradio_client.documentation import document
from pydantic import Field
from typing_extensions import NotRequired
from gradio import utils
from gradio.component_meta import ComponentMeta
@ -27,6 +33,31 @@ from gradio.components import (
from gradio.components.base import Component
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.events import Events
from gradio.exceptions import Error
from gradio.processing_utils import move_resource_to_block_cache
class MetadataDict(TypedDict):
title: Union[str, None]
class FileDataDict(TypedDict):
path: str # server filepath
url: NotRequired[Optional[str]] # normalised server url
size: NotRequired[Optional[int]] # size in bytes
orig_name: NotRequired[Optional[str]] # original filename
mime_type: NotRequired[Optional[str]]
is_stream: NotRequired[bool]
meta: dict[Literal["_type"], Literal["gradio.FileData"]]
class MessageDict(TypedDict):
content: str | FileDataDict | tuple | Component
role: Literal["user", "assistant", "system"]
metadata: NotRequired[MetadataDict]
TupleFormat = List[List[Union[str, Tuple[str], Tuple[str, str], None]]]
if TYPE_CHECKING:
from gradio.components import Timer
@ -59,7 +90,7 @@ class ComponentMessage(GradioModel):
props: Dict[str, Any]
class ChatbotData(GradioRootModel):
class ChatbotDataTuples(GradioRootModel):
root: List[
Tuple[
Union[str, FileMessage, ComponentMessage, None],
@ -68,6 +99,27 @@ class ChatbotData(GradioRootModel):
]
class Metadata(GradioModel):
title: Optional[str] = None
class Message(GradioModel):
role: str
metadata: Metadata = Field(default_factory=Metadata)
content: Union[str, FileMessage, ComponentMessage]
@dataclass
class ChatMessage:
role: Literal["user", "assistant", "system"]
content: str | FileData | Component | FileDataDict | tuple | list
metadata: MetadataDict | Metadata = field(default_factory=Metadata)
class ChatbotDataMessages(GradioRootModel):
root: List[Message]
@document()
class Chatbot(Component):
"""
@ -80,7 +132,6 @@ class Chatbot(Component):
"""
EVENTS = [Events.change, Events.select, Events.like]
data_model = ChatbotData
def __init__(
self,
@ -92,6 +143,7 @@ class Chatbot(Component):
| None
) = None,
*,
msg_format: Literal["messages", "tuples"] = "tuples",
label: str | None = None,
every: Timer | float | None = None,
inputs: Component | list[Component] | set[Component] | None = None,
@ -121,6 +173,7 @@ class Chatbot(Component):
"""
Parameters:
value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
msg_format: The format of the messages. If 'tuples', expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the 'tuples' format supports. The 'role' key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output.
label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
every: Continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
@ -148,6 +201,15 @@ class Chatbot(Component):
placeholder: a placeholder message to display in the chatbot when it is empty. Centered vertically and horizontally in the Chatbot. Supports Markdown and HTML. If None, no placeholder is displayed.
"""
self.likeable = likeable
if msg_format not in ["messages", "tuples"]:
raise ValueError(
f"msg_format must be 'messages' or 'tuples', received: {msg_format}"
)
self.msg_format: Literal["tuples", "messages"] = msg_format
if msg_format == "messages":
self.data_model = ChatbotDataMessages
else:
self.data_model = ChatbotDataTuples
self.height = height
self.rtl = rtl
if latex_delimiters is None:
@ -189,7 +251,27 @@ class Chatbot(Component):
]
self.placeholder = placeholder
def _preprocess_chat_messages(
@staticmethod
def _check_format(messages: list[Any], msg_format: Literal["messages", "tuples"]):
if msg_format == "messages":
all_dicts = all(
isinstance(message, dict) and "role" in message and "content" in message
for message in messages
)
all_msgs = all(isinstance(msg, ChatMessage) for msg in messages)
if not (all_dicts or all_msgs):
raise Error(
"Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
)
elif not all(
isinstance(message, (tuple, list)) and len(message) == 2
for message in messages
):
raise Error(
"Data incompatible with tuples format. Each message should be a list of length 2."
)
def _preprocess_content(
self,
chat_message: str | FileMessage | ComponentMessage | None,
) -> str | GradioComponent | tuple[str | None] | tuple[str | None, str] | None:
@ -228,18 +310,9 @@ class Chatbot(Component):
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
def preprocess(
self,
payload: ChatbotData | None,
) -> list[list[str | GradioComponent | tuple[str] | tuple[str, str] | None]] | None:
"""
Parameters:
payload: data as a ChatbotData object
Returns:
Passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed.
"""
if payload is None:
return payload
def _preprocess_messages_tuples(
self, payload: ChatbotDataTuples
) -> list[list[str | tuple[str] | tuple[str, str] | None]]:
processed_messages = []
for message_pair in payload.root:
if not isinstance(message_pair, (tuple, list)):
@ -252,29 +325,71 @@ class Chatbot(Component):
)
processed_messages.append(
[
self._preprocess_chat_messages(message_pair[0]),
self._preprocess_chat_messages(message_pair[1]),
self._preprocess_content(message_pair[0]),
self._preprocess_content(message_pair[1]),
]
)
return processed_messages
def _postprocess_chat_messages(
self, chat_message: str | tuple | list | GradioComponent | None
) -> str | FileMessage | ComponentMessage | None:
def create_file_message(chat_message, filepath):
mime_type = client_utils.get_mimetype(filepath)
return FileMessage(
file=FileData(path=filepath, mime_type=mime_type),
alt_text=(
chat_message[1]
if not isinstance(chat_message, GradioComponent)
and len(chat_message) > 1
else None
),
)
def preprocess(
self,
payload: ChatbotDataTuples | ChatbotDataMessages | None,
) -> (
list[list[str | tuple[str] | tuple[str, str] | None]] | list[MessageDict] | None
):
"""
Parameters:
payload: data as a ChatbotData object
Returns:
If msg_format is 'tuples', passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. If msg_format is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
"""
if payload is None:
return payload
if self.msg_format == "tuples":
if not isinstance(payload, ChatbotDataTuples):
raise Error("Data incompatible with the tuples format")
return self._preprocess_messages_tuples(cast(ChatbotDataTuples, payload))
if not isinstance(payload, ChatbotDataMessages):
raise Error("Data incompatible with the messages format")
message_dicts = []
for message in payload.root:
message_dict = cast(MessageDict, message.model_dump())
message_dict["content"] = self._preprocess_content(message.content)
message_dicts.append(message_dict)
return message_dicts
@staticmethod
def _get_alt_text(chat_message: dict | list | tuple | GradioComponent):
if isinstance(chat_message, dict):
return chat_message.get("alt_text")
elif not isinstance(chat_message, GradioComponent) and len(chat_message) > 1:
return chat_message[1]
@staticmethod
def _create_file_message(chat_message, filepath):
mime_type = client_utils.get_mimetype(filepath)
return FileMessage(
file=FileData(path=filepath, mime_type=mime_type),
alt_text=Chatbot._get_alt_text(chat_message),
)
def _postprocess_content(
self,
chat_message: str
| tuple
| list
| FileDataDict
| FileData
| GradioComponent
| None,
) -> str | FileMessage | ComponentMessage | None:
if chat_message is None:
return None
elif isinstance(chat_message, FileMessage):
return chat_message
elif isinstance(chat_message, FileData):
return FileMessage(file=chat_message)
elif isinstance(chat_message, GradioComponent):
component = import_component_and_data(type(chat_message).__name__)
if component:
@ -287,54 +402,110 @@ class Chatbot(Component):
constructor_args=chat_message.constructor_args,
props=config,
)
elif isinstance(chat_message, dict) and "path" in chat_message:
filepath = chat_message["path"]
return self._create_file_message(chat_message, filepath)
elif isinstance(chat_message, (tuple, list)):
filepath = str(chat_message[0])
return create_file_message(chat_message, filepath)
return self._create_file_message(chat_message, filepath)
elif isinstance(chat_message, str):
chat_message = inspect.cleandoc(chat_message)
return chat_message
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
def _postprocess_messages_tuples(self, value: TupleFormat) -> ChatbotDataTuples:
    """Postprocess a list of (user, bot) message pairs into the
    ChatbotDataTuples model, running each side through
    ``_postprocess_content``."""
    root = [
        [
            self._postprocess_content(pair[0]),
            self._postprocess_content(pair[1]),
        ]
        for pair in value
    ]
    return ChatbotDataTuples(root=root)
def _postprocess_message_messages(
    self, message: MessageDict | ChatMessage
) -> list[Message]:
    """Convert one message (dict or ChatMessage) into one or more Message
    models for the 'messages' format.

    The message content is first normalized via ``_postprocess_content``.
    Then, for string content, any whitespace-separated token that points to
    an existing file on disk is copied into the component cache and emitted
    as an extra FileMessage sharing the original role/metadata.

    Raises:
        Error: if *message* is neither a dict nor a ChatMessage.
    """
    if isinstance(message, dict):
        message["content"] = self._postprocess_content(message["content"])
        msg = Message(**message)  # type: ignore
    elif isinstance(message, ChatMessage):
        message.content = self._postprocess_content(message.content)  # type: ignore
        msg = Message(
            role=message.role,
            content=message.content,  # type: ignore
            metadata=message.metadata,  # type: ignore
        )
    else:
        # visible=False: surfaced to the API caller, not shown in the UI
        raise Error(
            f"Invalid message for Chatbot component: {message}", visible=False
        )
    # extract file path from message
    new_messages = []
    if isinstance(msg.content, str):
        for word in msg.content.split(" "):
            filepath = Path(word)
            try:
                is_file = filepath.is_file() and filepath.exists()
            except OSError:
                # e.g. token too long to be a valid path on this OS
                is_file = False
            if is_file:
                # cache the file so it is served by the block's file route
                filepath = cast(
                    str, move_resource_to_block_cache(filepath, block=self)
                )
                mime_type = client_utils.get_mimetype(filepath)
                new_messages.append(
                    Message(
                        role=msg.role,
                        metadata=msg.metadata,
                        content=FileMessage(
                            file=FileData(path=filepath, mime_type=mime_type)
                        ),
                    ),
                )
    # original message first, followed by any extracted file messages
    return [msg, *new_messages]
def postprocess(
self,
value: (
list[
list[str | GradioComponent | tuple[str] | tuple[str, str] | None]
| tuple
]
| None
),
) -> ChatbotData:
value: TupleFormat | list[MessageDict | Message] | None,
) -> ChatbotDataTuples | ChatbotDataMessages:
"""
Parameters:
value: expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed.
value: If msg_format is `tuples`, expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If msg_format is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
Returns:
an object of type ChatbotData
"""
data_model = cast(
Union[Type[ChatbotDataTuples], Type[ChatbotDataMessages]], self.data_model
)
if value is None:
return ChatbotData(root=[])
processed_messages = []
for message_pair in value:
if not isinstance(message_pair, (tuple, list)):
raise TypeError(
f"Expected a list of lists or list of tuples. Received: {message_pair}"
)
if len(message_pair) != 2:
raise TypeError(
f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
)
processed_messages.append(
[
self._postprocess_chat_messages(message_pair[0]),
self._postprocess_chat_messages(message_pair[1]),
]
)
return ChatbotData(root=processed_messages)
return data_model(root=[])
if self.msg_format == "tuples":
self._check_format(value, "tuples")
return self._postprocess_messages_tuples(cast(TupleFormat, value))
self._check_format(value, "messages")
processed_messages = [
msg
for message in value
for msg in self._postprocess_message_messages(cast(MessageDict, message))
]
return ChatbotDataMessages(root=processed_messages)
def example_payload(self) -> Any:
    """Return a sample payload matching the configured ``msg_format``."""
    if self.msg_format != "messages":
        return [["Hello!", None]]
    return [
        Message(role="user", content="Hello!").model_dump(),
        Message(role="assistant", content="How can I help you?").model_dump(),
    ]
def example_value(self) -> Any:
    """Return a sample value matching the configured ``msg_format``."""
    if self.msg_format == "messages":
        example = [
            Message(role="user", content="Hello!").model_dump(),
            Message(role="assistant", content="How can I help you?").model_dump(),
        ]
    else:
        example = [["Hello!", None]]
    return example

View File

@ -229,7 +229,7 @@ class Interface(Blocks):
state_output_index = state_output_indexes[0]
if inputs[state_input_index] == "state":
default = utils.get_default_args(fn)[state_input_index]
state_variable = State(value=default) # type: ignore
state_variable = State(value=default)
else:
state_variable = inputs[state_input_index]
@ -244,12 +244,10 @@ class Interface(Blocks):
self.cache_examples = False
self.main_input_components = [
get_component_instance(i, unrender=True)
for i in inputs # type: ignore
get_component_instance(i, unrender=True) for i in inputs
]
self.additional_input_components = [
get_component_instance(i, unrender=True)
for i in additional_inputs # type: ignore
get_component_instance(i, unrender=True) for i in additional_inputs
]
if additional_inputs_accordion is None:
self.additional_inputs_accordion_params = {

View File

@ -36,27 +36,133 @@ demo.launch()
<!-- Behavior -->
### Behavior
The data format accepted by the Chatbot is dictated by the `msg_format` parameter.
This parameter can take two values, `'tuples'` and `'messages'`.
If `msg_format` is `'tuples'`, then the data sent to/from the chatbot will be a list of tuples.
The first element of each tuple is the user message and the second element is the bot's response.
Each element can be a string (markdown/html is supported),
a tuple (in which case the first element is a filepath that will be displayed in the chatbot),
or a gradio component (see the Examples section for more details).
If the `msg_format` is `'messages'`, then the data sent to/from the chatbot will be a list of dictionaries
with `role` and `content` keys. This format is compliant with the format expected by most LLM APIs (HuggingChat, OpenAI, Claude).
The `role` key is either `'user'` or `'assistant'` and the `content` key can be a string (markdown/html supported),
a `FileDataDict` (to represent a file that is displayed in the chatbot - documented below), or a gradio component.
For convenience, you can use the `ChatMessage` dataclass so that your text editor can give you autocomplete hints and typechecks.
```python
from gradio import ChatMessage
def generate_response(history):
history.append(
ChatMessage(role="assistant",
content="How can I help you?")
)
return history
```
Additionally, when `msg_format` is `messages`, you can provide additional metadata regarding any tools used to generate the response.
This is useful for displaying the thought process of LLM agents. For example,
```python
def generate_response(history):
history.append(
ChatMessage(role="assistant",
content="The weather API says it is 20 degrees Celsius in New York.",
metadata={"title": "🛠️ Used tool Weather API"})
)
return history
```
Would be displayed as following:
<img src="https://github.com/freddyaboulton/freddyboulton/assets/41651716/a4bb2b0a-5f8a-4287-814b-4eab278e021e" alt="Gradio chatbot tool display">
All of the types expected by the messages format are documented below:
```python
class MetadataDict(TypedDict):
title: Union[str, None]
class FileDataDict(TypedDict):
path: str # server filepath
url: NotRequired[Optional[str]] # normalised server url
size: NotRequired[Optional[int]] # size in bytes
orig_name: NotRequired[Optional[str]] # original filename
mime_type: NotRequired[Optional[str]]
is_stream: NotRequired[bool]
meta: dict[Literal["_type"], Literal["gradio.FileData"]]
class MessageDict(TypedDict):
content: str | FileDataDict | Component
role: Literal["user", "assistant", "system"]
metadata: NotRequired[MetadataDict]
@dataclass
class Metadata:
title: Optional[str] = None
@dataclass
class ChatMessage:
role: Literal["user", "assistant", "system"]
content: str | FileData | Component | FileDataDict | tuple | list
metadata: MetadataDict | Metadata = field(default_factory=Metadata)
```
## **As input component**: {@html style_formatted_text(obj.preprocess.return_doc.doc)}
##### Your function should accept one of these types:
If `msg_format` is `tuples` -
```python
from gradio import Component
def predict(
value: list[list[str | tuple[str] | tuple[str, str] | None]] | None
)
value: list[list[str | tuple[str, str] | Component | None]] | None
):
...
```
If `msg_format` is `messages` -
```python
from gradio import MessageDict
def predict(value: list[MessageDict] | None):
...
```
<br>
## **As output component**: {@html style_formatted_text(obj.postprocess.parameter_doc[0].doc)}
##### Your function should return one of these types:
If `msg_format` is `tuples` -
```python
def predict(···) -> list[list[str | tuple[str] | tuple[str, str] | None] | tuple] | None
...
return value
```
If `msg_format` is `messages` -
```python
from gradio import ChatMessage, MessageDict

def predict(···) -> list[MessageDict] | list[ChatMessage]:
...
```
<!--- Initialization -->
### Initialization

View File

@ -0,0 +1,65 @@
// E2E checks that non-text bot responses (gallery, audio, video, image)
// render correctly in the chatbot, for both supported data formats.
import { test, expect, go_to_testcase } from "@gradio/tootils";

for (const msg_format of ["tuples", "messages"]) {
	test(`message format ${msg_format} - Gallery component properly displayed`, async ({
		page
	}) => {
		// The default testcase serves the "tuples" demo; navigate for "messages".
		if (msg_format === "messages") {
			await go_to_testcase(page, "messages");
		}
		await page.getByTestId("gallery-radio-label").click();
		await page.getByTestId("textbox").click();
		await page.getByTestId("textbox").fill("gallery");
		await page.keyboard.press("Enter");
		// Both gallery thumbnails should be rendered inside the bot message.
		await expect(
			page.getByLabel("Thumbnail 1 of 2").locator("img")
		).toBeVisible();
		await expect(
			page.getByLabel("Thumbnail 2 of 2").locator("img")
		).toBeVisible();
	});

	test(`message format ${msg_format} - Audio component properly displayed`, async ({
		page
	}) => {
		if (msg_format === "messages") {
			await go_to_testcase(page, "messages");
		}
		await page.getByTestId("audio-radio-label").click();
		await page.getByTestId("textbox").click();
		await page.getByTestId("textbox").fill("audio");
		await page.keyboard.press("Enter");
		// An <audio> element should be attached inside the bot message.
		await expect(
			page.getByTestId("unlabelled-audio").locator("audio")
		).toBeAttached();
	});

	test(`message format ${msg_format} - Video component properly displayed`, async ({
		page
	}) => {
		if (msg_format === "messages") {
			await go_to_testcase(page, "messages");
		}
		await page.getByTestId("video-radio-label").click();
		await page.getByTestId("textbox").click();
		await page.getByTestId("textbox").fill("video");
		await page.keyboard.press("Enter");
		// The player should be attached and carry a non-empty video src.
		await expect(page.getByTestId("test-player")).toBeAttached();
		await expect(
			page.getByTestId("test-player").getAttribute("src")
		).toBeTruthy();
	});

	test(`message format ${msg_format} - Image component properly displayed`, async ({
		page
	}) => {
		if (msg_format === "messages") {
			await go_to_testcase(page, "messages");
		}
		await page.getByTestId("image-radio-label").click();
		await page.getByTestId("textbox").click();
		await page.getByTestId("textbox").fill("image");
		await page.keyboard.press("Enter");
		await expect(page.getByTestId("bot").locator("img")).toBeAttached();
	});
}

View File

@ -1,221 +1,268 @@
import { test, expect } from "@gradio/tootils";
import { test, expect, go_to_testcase } from "@gradio/tootils";
test("text input by a user should be shown in the chatbot as a paragraph", async ({
page
}) => {
const textbox = await page.getByTestId("textbox");
await textbox.fill("Lorem ipsum");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.textContent();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toEqual("Lorem ipsum");
await expect(bot_message).toBeTruthy();
});
for (const msg_format of ["tuples", "messages"]) {
test(`message format ${msg_format} - text input by a user should be shown in the chatbot as a paragraph`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = await page.getByTestId("textbox");
await textbox.fill("Lorem ipsum");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.textContent();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toEqual("Lorem ipsum");
await expect(bot_message).toBeTruthy();
});
test("images uploaded by a user should be shown in the chat", async ({
page
}) => {
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("./test/files/cheetah1.jpg");
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");
test(`message format ${msg_format} - images uploaded by a user should be shown in the chat`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("./test/files/cheetah1.jpg");
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");
const user_message_locator = await page.getByTestId("user").first();
const user_message = await user_message_locator.elementHandle();
if (user_message) {
const imageContainer = await user_message.$("div.image-container");
const user_message_locator = await page.getByTestId("user").first();
const user_message = await user_message_locator.elementHandle();
if (user_message) {
const imageContainer = await user_message.$("div.image-container");
if (imageContainer) {
const imgElement = await imageContainer.$("img");
if (imgElement) {
const image_src = await imgElement.getAttribute("src");
expect(image_src).toBeTruthy();
if (imageContainer) {
const imgElement = await imageContainer.$("img");
if (imgElement) {
const image_src = await imgElement.getAttribute("src");
expect(image_src).toBeTruthy();
}
}
}
}
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
expect(bot_message).toBeTruthy();
});
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
test("Users can upload multiple images and they will be shown as thumbnails", async ({
page
}) => {
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles([
"./test/files/cheetah1.jpg",
"./test/files/cheetah1.jpg"
]);
expect
.poll(async () => await page.locator("thumbnail-image").count(), {
timeout: 5000
})
.toEqual(2);
});
expect(bot_message).toBeTruthy();
});
test("audio uploaded by a user should be shown in the chatbot", async ({
page
}) => {
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("../../test/test_files/audio_sample.wav");
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");
test(`message format ${msg_format} - audio uploaded by a user should be shown in the chatbot`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("../../test/test_files/audio_sample.wav");
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");
const user_message = await page.getByTestId("user").first().locator("audio");
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
const audio_data = await user_message.getAttribute("src");
await expect(audio_data).toBeTruthy();
await expect(bot_message).toBeTruthy();
});
const user_message = await page
.getByTestId("user")
.first()
.locator("audio");
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
const audio_data = await user_message.getAttribute("src");
await expect(audio_data).toBeTruthy();
await expect(bot_message).toBeTruthy();
});
test("videos uploaded by a user should be shown in the chatbot", async ({
page
}) => {
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("../../test/test_files/video_sample.mp4");
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");
test(`message format ${msg_format} - videos uploaded by a user should be shown in the chatbot`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles("../../test/test_files/video_sample.mp4");
await page.getByTestId("textbox").click();
await page.keyboard.press("Enter");
const user_message = await page.getByTestId("user").first().locator("video");
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
const video_data = await user_message.getAttribute("src");
await expect(video_data).toBeTruthy();
await expect(bot_message).toBeTruthy();
});
const user_message = await page
.getByTestId("user")
.first()
.locator("video");
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
const video_data = await user_message.getAttribute("src");
await expect(video_data).toBeTruthy();
await expect(bot_message).toBeTruthy();
});
test("markdown input by a user should be correctly formatted: bold, italics, links", async ({
page
}) => {
const textbox = await page.getByTestId("textbox");
await textbox.fill(
"This is **bold text**. This is *italic text*. This is a [link](https://gradio.app)."
);
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toEqual(
'This is <strong>bold text</strong>. This is <em>italic text</em>. This is a <a href="https://gradio.app" target="_blank" rel="noopener noreferrer">link</a>.'
);
await expect(bot_message).toBeTruthy();
});
test(`message format ${msg_format} - markdown input by a user should be correctly formatted: bold, italics, links`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = await page.getByTestId("textbox");
await textbox.fill(
"This is **bold text**. This is *italic text*. This is a [link](https://gradio.app)."
);
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toEqual(
'This is <strong>bold text</strong>. This is <em>italic text</em>. This is a <a href="https://gradio.app" target="_blank" rel="noopener noreferrer">link</a>.'
);
await expect(bot_message).toBeTruthy();
});
test("inline code markdown input by the user should be correctly formatted", async ({
page
}) => {
const textbox = await page.getByTestId("textbox");
await textbox.fill("This is `code`.");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toEqual("This is <code>code</code>.");
await expect(bot_message).toBeTruthy();
});
test(`message format ${msg_format} - inline code markdown input by the user should be correctly formatted`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = await page.getByTestId("textbox");
await textbox.fill("This is `code`.");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toEqual("This is <code>code</code>.");
await expect(bot_message).toBeTruthy();
});
test("markdown code blocks input by a user should be rendered correctly with the correct language tag", async ({
page
}) => {
const textbox = await page.getByTestId("textbox");
await textbox.fill("```python\nprint('Hello')\nprint('World!')\n```");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.locator("pre")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toContain("language-python");
await expect(bot_message).toBeTruthy();
});
test(`message format ${msg_format} - markdown code blocks input by a user should be rendered correctly with the correct language tag`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = await page.getByTestId("textbox");
await textbox.fill("```python\nprint('Hello')\nprint('World!')\n```");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.locator("pre")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toContain("language-python");
await expect(bot_message).toBeTruthy();
});
test("LaTeX input by a user should be rendered correctly", async ({ page }) => {
const textbox = await page.getByTestId("textbox");
await textbox.fill("This is LaTeX $$x^2$$");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toContain("katex-display");
await expect(bot_message).toBeTruthy();
});
test(`message format ${msg_format} - LaTeX input by a user should be rendered correctly`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = await page.getByTestId("textbox");
await textbox.fill("This is LaTeX $$x^2$$");
await page.keyboard.press("Enter");
const user_message = await page
.getByTestId("user")
.first()
.getByRole("paragraph")
.innerHTML();
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph")
.textContent();
await expect(user_message).toContain("katex-display");
await expect(bot_message).toBeTruthy();
});
test("when a new message is sent the chatbot should scroll to the latest message", async ({
page
}) => {
const textbox = await page.getByTestId("textbox");
const line_break = "<br>";
await textbox.fill(line_break.repeat(30));
await page.keyboard.press("Enter");
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph");
await expect(bot_message).toBeVisible();
const bot_message_text = bot_message.textContent();
await expect(bot_message_text).toBeTruthy();
});
test(`message format ${msg_format} - when a new message is sent the chatbot should scroll to the latest message`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = await page.getByTestId("textbox");
const line_break = "<br>";
await textbox.fill(line_break.repeat(30));
await page.keyboard.press("Enter");
const bot_message = await page
.getByTestId("bot")
.first()
.getByRole("paragraph");
await expect(bot_message).toBeVisible();
const bot_message_text = bot_message.textContent();
await expect(bot_message_text).toBeTruthy();
});
test("chatbot like and dislike functionality", async ({ page }) => {
await page.getByTestId("textbox").click();
await page.getByTestId("textbox").fill("hello");
await page.keyboard.press("Enter");
await page.getByLabel("like", { exact: true }).click();
await page.getByLabel("dislike").click();
test(`message format ${msg_format} - chatbot like and dislike functionality`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
await page.getByTestId("textbox").click();
await page.getByTestId("textbox").fill("hello");
await page.keyboard.press("Enter");
await page.getByLabel("like", { exact: true }).click();
await page.getByLabel("dislike").click();
expect(await page.getByLabel("clicked dislike").count()).toEqual(1);
expect(await page.getByLabel("clicked like").count()).toEqual(0);
});
expect(await page.getByLabel("clicked dislike").count()).toEqual(1);
expect(await page.getByLabel("clicked like").count()).toEqual(0);
});
test(`message format ${msg_format} - Users can upload multiple images and they will be shown as thumbnails`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const fileChooserPromise = page.waitForEvent("filechooser");
await page.getByTestId("upload-button").click();
const fileChooser = await fileChooserPromise;
await fileChooser.setFiles([
"./test/files/cheetah1.jpg",
"./test/files/cheetah1.jpg"
]);
expect
.poll(async () => await page.locator("thumbnail-image").count(), {
timeout: 5000
})
.toEqual(2);
});
}

View File

@ -0,0 +1,16 @@
// Verifies that messages carrying metadata (tool titles, errors) render
// in the chatbot, enabling agentic/tool-use demos.
import { test, expect } from "@gradio/tootils";

test("Chatbot can support agentic demos by displaying messages with metadata", async ({
	page
}) => {
	await page.getByRole("button", { name: "Get San Francisco Weather" }).click();
	// The error message is expected as the second matching button.
	await expect(
		await page.locator("button").filter({ hasText: "💥 Error" }).nth(1)
	).toBeVisible();
	// The tool-usage title from the message metadata should be shown.
	await expect(
		page.locator("span").filter({ hasText: "🛠️ Used tool" })
	).toBeVisible();
	// The final assistant answer should also be rendered.
	await expect(
		page.locator("button").filter({ hasText: "It's a sunny day in San" })
	).toBeVisible();
});

View File

@ -1,82 +1,95 @@
import { test, expect } from "@gradio/tootils";
import { test, expect, go_to_testcase } from "@gradio/tootils";
test("chatinterface works with streaming functions and all buttons behave as expected", async ({
page
}) => {
const submit_button = page.getByRole("button", { name: "Submit" });
const retry_button = page.getByRole("button", { name: "🔄 Retry" });
const undo_button = page.getByRole("button", { name: "↩️ Undo" });
const clear_button = page.getByRole("button", { name: "🗑️ Clear" });
const textbox = page.getByPlaceholder("Type a message...");
for (const msg_format of ["tuples", "messages"]) {
test(`msg format ${msg_format} chatinterface works with streaming functions and all buttons behave as expected`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const submit_button = page.getByRole("button", { name: "Submit" });
const retry_button = page.getByRole("button", { name: "🔄 Retry" });
const undo_button = page.getByRole("button", { name: "↩️ Undo" });
const clear_button = page.getByRole("button", { name: "🗑️ Clear" });
const textbox = page.getByPlaceholder("Type a message...");
await textbox.fill("hello");
await submit_button.click();
await textbox.fill("hello");
await submit_button.click();
await expect(textbox).toHaveValue("");
const expected_text_el_0 = page.locator(".bot p", {
hasText: "Run 1 - You typed: hello"
await expect(textbox).toHaveValue("");
const expected_text_el_0 = page.locator(".bot p", {
hasText: "Run 1 - You typed: hello"
});
await expect(expected_text_el_0).toBeVisible();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
.toBe(1);
await textbox.fill("hi");
await submit_button.click();
await expect(textbox).toHaveValue("");
const expected_text_el_1 = page.locator(".bot p", {
hasText: "Run 2 - You typed: hi"
});
await expect(expected_text_el_1).toBeVisible();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
.toBe(2);
await undo_button.click();
await expect
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
.toBe(1);
await expect(textbox).toHaveValue("hi");
await retry_button.click();
const expected_text_el_2 = page.locator(".bot p", {
hasText: ""
});
await expect(expected_text_el_2).toBeVisible();
await expect
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
.toBe(1);
await textbox.fill("hi");
await submit_button.click();
await expect(textbox).toHaveValue("");
const expected_text_el_3 = page.locator(".bot p", {
hasText: "Run 4 - You typed: hi"
});
await expect(expected_text_el_3).toBeVisible();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
.toBe(2);
await clear_button.click();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 5000 })
.toBe(0);
});
await expect(expected_text_el_0).toBeVisible();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
.toBe(1);
await textbox.fill("hi");
await submit_button.click();
await expect(textbox).toHaveValue("");
const expected_text_el_1 = page.locator(".bot p", {
hasText: "Run 2 - You typed: hi"
test(`msg format ${msg_format} the api recorder correctly records the api calls`, async ({
page
}) => {
if (msg_format === "messages") {
await go_to_testcase(page, "messages");
}
const textbox = page.getByPlaceholder("Type a message...");
const submit_button = page.getByRole("button", { name: "Submit" });
await textbox.fill("hi");
await page.getByRole("button", { name: "Use via API logo" }).click();
await page.locator("#start-api-recorder").click();
await submit_button.click();
await expect(textbox).toHaveValue("");
await expect(page.locator(".bot p").first()).toContainText(
/\- You typed: hi/
);
const api_recorder = await page.locator("#api-recorder");
await api_recorder.click();
await expect(page.locator("#num-recorded-api-calls")).toContainText(
"🪄 Recorded API Calls [5]"
);
});
await expect(expected_text_el_1).toBeVisible();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
.toBe(2);
await undo_button.click();
await expect
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
.toBe(1);
await expect(textbox).toHaveValue("hi");
await retry_button.click();
const expected_text_el_2 = page.locator(".bot p", {
hasText: ""
});
await expect(expected_text_el_2).toBeVisible();
await expect
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
.toBe(1);
await textbox.fill("hi");
await submit_button.click();
await expect(textbox).toHaveValue("");
const expected_text_el_3 = page.locator(".bot p", {
hasText: "Run 4 - You typed: hi"
});
await expect(expected_text_el_3).toBeVisible();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
.toBe(2);
await clear_button.click();
await expect
.poll(async () => page.locator(".bot.message").count(), { timeout: 5000 })
.toBe(0);
});
test("the api recorder correctly records the api calls", async ({ page }) => {
const textbox = page.getByPlaceholder("Type a message...");
const submit_button = page.getByRole("button", { name: "Submit" });
await textbox.fill("hi");
await page.getByRole("button", { name: "Use via API logo" }).click();
await page.locator("#start-api-recorder").click();
await submit_button.click();
await expect(textbox).toHaveValue("");
const api_recorder = await page.locator("#api-recorder");
await api_recorder.click();
const num_calls = await page.locator("#num-recorded-api-calls").innerText();
await expect(num_calls).toBe("🪄 Recorded API Calls [5]");
});
}

View File

@ -11,17 +11,19 @@
import { Chat } from "@gradio/icons";
import type { FileData } from "@gradio/client";
import { StatusTracker } from "@gradio/statustracker";
import type {
Message,
TupleFormat,
MessageRole,
NormalisedMessage
} from "./types";
import {
type messages,
type NormalisedMessage,
normalise_messages
} from "./shared/utils";
import { normalise_tuples, normalise_messages } from "./shared/utils";
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
export let value: messages = [];
export let value: TupleFormat | Message[] = [];
export let scale: number | null = null;
export let min_width: number | undefined = undefined;
export let label: string;
@ -35,6 +37,7 @@
export let sanitize_html = true;
export let bubble_full_width = true;
export let layout: "bubble" | "panel" = "bubble";
export let msg_format: "tuples" | "messages" = "tuples";
export let render_markdown = true;
export let line_breaks = true;
export let latex_delimiters: {
@ -52,9 +55,12 @@
}>;
export let avatar_images: [FileData | null, FileData | null] = [null, null];
let _value: [NormalisedMessage, NormalisedMessage][] | null = [];
let _value: NormalisedMessage[] | null = [];
$: _value = normalise_messages(value, root);
$: _value =
msg_format === "tuples"
? normalise_tuples(value as TupleFormat, root)
: normalise_messages(value as Message[], root);
export let loading_status: LoadingStatus | undefined = undefined;
export let height = 400;

View File

@ -1,5 +1,6 @@
<script lang="ts">
import { format_chat_for_sharing, type NormalisedMessage } from "./utils";
import { format_chat_for_sharing } from "./utils";
import type { NormalisedMessage } from "../types";
import { Gradio, copy } from "@gradio/utils";
import { dequal } from "dequal/lite";
@ -16,10 +17,16 @@
import { Clear } from "@gradio/icons";
import type { SelectData, LikeData } from "@gradio/utils";
import type { MessageRole, ComponentMessage, ComponentData } from "../types";
import { MarkdownCode as Markdown } from "@gradio/markdown";
import { type FileData, type Client } from "@gradio/client";
import type { FileData, Client } from "@gradio/client";
import type { I18nFormatter } from "js/app/src/gradio_helper";
import Pending from "./Pending.svelte";
import MessageBox from "./MessageBox.svelte";
export let value: NormalisedMessage[] | null = [];
let old_value: NormalisedMessage[] | null = null;
import Component from "./Component.svelte";
import LikeButtons from "./ButtonPanel.svelte";
import type { LoadedComponent } from "../../app/src/types";
@ -51,27 +58,19 @@
$: load_components(get_components_from_messages(value));
function get_components_from_messages(messages: typeof value): string[] {
function get_components_from_messages(
messages: NormalisedMessage[] | null
): string[] {
if (!messages) return [];
let components: Set<string> = new Set();
messages.forEach((message_pair) => {
message_pair.forEach((message) => {
if (
typeof message === "object" &&
message !== null &&
"component" in message
) {
components.add(message.component);
}
});
messages.forEach((message) => {
if (message.type === "component") {
components.add(message.content.component);
}
});
return Array.from(components);
}
export let value: [NormalisedMessage, NormalisedMessage][] | null = [];
let old_value: [NormalisedMessage, NormalisedMessage][] | null = null;
export let latex_delimiters: {
left: string;
right: string;
@ -173,45 +172,76 @@
}
}
function handle_select(
i: number,
j: number,
message: NormalisedMessage
): void {
function handle_select(i: number, message: NormalisedMessage): void {
dispatch("select", {
index: [i, j],
value: message
index: message.index,
value: message.content
});
}
function handle_like(
i: number,
j: number,
message: NormalisedMessage,
selected: string | null
): void {
dispatch("like", {
index: [i, j],
value: message,
index: message.index,
value: message.content,
liked: selected === "like"
});
}
function get_message_label_data(message: NormalisedMessage): string {
if (message.type === "text") {
return message.value;
} else if (message.type === "component") {
return `a component of type ${message.component}`;
} else if (message.type === "file") {
if (Array.isArray(message.file)) {
return `file of extension type: ${message.file[0].orig_name?.split(".").pop()}`;
return message.content;
} else if (
message.type === "component" &&
message.content.component === "file"
) {
if (Array.isArray(message.content.value)) {
return `file of extension type: ${message.content.value[0].orig_name?.split(".").pop()}`;
}
return (
`file of extension type: ${message.file?.orig_name?.split(".").pop()}` +
(message.file?.orig_name ?? "")
`file of extension type: ${message.content.value?.orig_name?.split(".").pop()}` +
(message.content.value?.orig_name ?? "")
);
}
return `a message of type ` + message.type ?? "unknown";
return `a component of type ${message.content.component ?? "unknown"}`;
}
function is_component_message(
message: NormalisedMessage
): message is ComponentMessage {
return message.type === "component";
}
function group_messages(
messages: NormalisedMessage[]
): NormalisedMessage[][] {
const groupedMessages: NormalisedMessage[][] = [];
let currentGroup: NormalisedMessage[] = [];
let currentRole: MessageRole | null = null;
for (const message of messages) {
if (!(message.role === "assistant" || message.role === "user")) {
continue;
}
if (message.role === currentRole) {
currentGroup.push(message);
} else {
if (currentGroup.length > 0) {
groupedMessages.push(currentGroup);
}
currentGroup = [message];
currentRole = message.role;
}
}
if (currentGroup.length > 0) {
groupedMessages.push(currentGroup);
}
return groupedMessages;
}
</script>
@ -237,132 +267,137 @@
>
<div class="message-wrap" use:copy>
{#if value !== null && value.length > 0}
{#each value as message_pair, i}
{#each message_pair as message, j}
{#if message.type !== "empty"}
{#if is_image_preview_open}
<div class="image-preview">
<img
src={image_preview_source}
alt={image_preview_source_alt}
/>
<button
class="image-preview-close-button"
on:click={() => {
is_image_preview_open = false;
}}><Clear /></button
>
</div>
{/if}
<div
class="message-row {layout} {j == 0 ? 'user-row' : 'bot-row'}"
class:with_avatar={avatar_images[j] !== null}
class:with_opposite_avatar={avatar_images[j === 0 ? 1 : 0] !==
null}
{@const groupedMessages = group_messages(value)}
{#each groupedMessages as messages, i}
{@const role = messages[0].role === "user" ? "user" : "bot"}
{@const avatar_img = avatar_images[role === "user" ? 0 : 1]}
{@const opposite_avatar_img = avatar_images[role === "user" ? 0 : 1]}
{#if is_image_preview_open}
<div class="image-preview">
<img src={image_preview_source} alt={image_preview_source_alt} />
<button
class="image-preview-close-button"
on:click={() => {
is_image_preview_open = false;
}}><Clear /></button
>
{#if avatar_images[j] !== null}
<div class="avatar-container">
<Image
class="avatar-image"
src={avatar_images[j]?.url}
alt="{j == 0 ? 'user' : 'bot'} avatar"
/>
</div>
{/if}
<div class="flex-wrap">
<div
class="message {j == 0 ? 'user' : 'bot'} {typeof message ===
'object' &&
message !== null &&
'component' in message
? message?.component
: ''}"
class:message-fit={!bubble_full_width}
class:panel-full-width={true}
</div>
{/if}
<div
class="message-row {layout} {role}-row"
class:with_avatar={avatar_img !== null}
class:with_opposite_avatar={opposite_avatar_img !== null}
>
{#if avatar_img !== null}
<div class="avatar-container">
<Image
class="avatar-image"
src={avatar_img?.url}
alt="{role} avatar"
/>
</div>
{/if}
<div class="flex-wrap">
{#each messages as message, thought_index}
{@const msg_type = messages[0].type}
<div
class="message {role} {is_component_message(message)
? message?.content.component
: ''}"
class:message-fit={!bubble_full_width}
class:panel-full-width={true}
class:message-markdown-disabled={!render_markdown}
style:text-align={rtl && role === "user" ? "left" : "right"}
class:component={msg_type === "component"}
class:html={is_component_message(message) &&
message.content.component === "html"}
class:thought={thought_index > 0}
>
<button
data-testid={role}
class:latest={i === value.length - 1}
class:message-markdown-disabled={!render_markdown}
style:text-align={rtl && j == 0 ? "left" : "right"}
class:component={typeof message === "object" &&
message !== null &&
"component" in message}
class:html={typeof message === "object" &&
message !== null &&
"component" in message &&
message.component === "html"}
style:user-select="text"
class:selectable
style:text-align={rtl ? "right" : "left"}
on:click={() => handle_select(i, message)}
on:keydown={(e) => {
if (e.key === "Enter") {
handle_select(i, message);
}
}}
dir={rtl ? "rtl" : "ltr"}
aria-label={role +
"'s message: " +
get_message_label_data(message)}
>
<button
data-testid={j == 0 ? "user" : "bot"}
class:latest={i === value.length - 1}
class:message-markdown-disabled={!render_markdown}
style:user-select="text"
class:selectable
style:text-align={rtl ? "right" : "left"}
on:click={() => handle_select(i, j, message)}
on:keydown={(e) => {
if (e.key === "Enter") {
handle_select(i, j, message);
}
}}
dir={rtl ? "rtl" : "ltr"}
aria-label={(j == 0 ? "user" : "bot") +
"'s message: " +
get_message_label_data(message)}
>
{#if message.type === "text"}
{#if message.type === "text"}
{#if message.metadata.title}
<MessageBox title={message.metadata.title}>
<Markdown
message={message.content}
{latex_delimiters}
{sanitize_html}
{render_markdown}
{line_breaks}
on:load={scroll}
/>
</MessageBox>
{:else}
<Markdown
message={message.value}
message={message.content}
{latex_delimiters}
{sanitize_html}
{render_markdown}
{line_breaks}
on:load={scroll}
/>
{:else if message.type === "component" && message.component in _components}
<Component
{target}
{theme_mode}
props={message.props}
type={message.component}
components={_components}
value={message.value}
{i18n}
{upload}
{_fetch}
on:load={scroll}
/>
{:else if message.type === "component" && message.component === "file"}
<a
data-testid="chatbot-file"
class="file-pil"
href={message.value.url}
target="_blank"
download={window.__is_colab__
? null
: message.value?.orig_name ||
message.value?.path.split("/").pop() ||
"file"}
>
{message.value?.orig_name ||
message.value?.path.split("/").pop() ||
"file"}
</a>
{/if}
</button>
</div>
<LikeButtons
show={j === 1 && (likeable || show_copy_button)}
handle_action={(selected) =>
handle_like(i, j, message, selected)}
{likeable}
{show_copy_button}
{message}
position={j === 0 ? "right" : "left"}
avatar={avatar_images[j]}
{layout}
/>
{:else if message.type === "component" && message.content.component in _components}
<Component
{target}
{theme_mode}
props={message.content.props}
type={message.content.component}
components={_components}
value={message.content.value}
{i18n}
{upload}
{_fetch}
on:load={scroll}
/>
{:else if message.type === "component" && message.content.component === "file"}
<a
data-testid="chatbot-file"
class="file-pil"
href={message.content.value.url}
target="_blank"
download={window.__is_colab__
? null
: message.content.value?.orig_name ||
message.content.value?.path.split("/").pop() ||
"file"}
>
{message.content.value?.orig_name ||
message.content.value?.path.split("/").pop() ||
"file"}
</a>
{/if}
</button>
</div>
</div>
{/if}
{/each}
<LikeButtons
show={role === "bot" && (likeable || show_copy_button)}
handle_action={(selected) => handle_like(i, message, selected)}
{likeable}
{show_copy_button}
{message}
position={role === "user" ? "right" : "left"}
avatar={avatar_img}
{layout}
/>
{/each}
</div>
</div>
{/each}
{#if pending_message}
<Pending {layout} />
@ -430,6 +465,10 @@
overflow-wrap: break-word;
}
.thought {
margin-top: var(--spacing-xxl);
}
.message :global(.prose) {
font-size: var(--chatbot-body-text-size);
}
@ -610,6 +649,37 @@
max-height: 200px;
}
.message-wrap .message :global(a) {
color: var(--color-text-link);
text-decoration: underline;
}
.message-wrap .bot :global(table),
.message-wrap .bot :global(tr),
.message-wrap .bot :global(td),
.message-wrap .bot :global(th) {
border: 1px solid var(--border-color-primary);
}
.message-wrap .user :global(table),
.message-wrap .user :global(tr),
.message-wrap .user :global(td),
.message-wrap .user :global(th) {
border: 1px solid var(--border-color-accent);
}
/* Lists */
.message-wrap :global(ol),
.message-wrap :global(ul) {
padding-inline-start: 2em;
}
/* KaTeX */
.message-wrap :global(span.katex) {
font-size: var(--text-lg);
direction: ltr;
}
/* Copy button */
.message-wrap :global(div[class*="code_wrap"] > button) {
position: absolute;

View File

@ -0,0 +1,55 @@
<script lang="ts">
let expanded = false;
export let title: string;
function toggleExpanded(): void {
expanded = !expanded;
}
</script>
<button class="box" on:click={toggleExpanded}>
<div class="title">
<span class="title-text">{title}</span>
<span
style:transform={expanded ? "rotate(0)" : "rotate(90deg)"}
class="arrow"
>
</span>
</div>
{#if expanded}
<div class="content">
<slot></slot>
</div>
{/if}
</button>
<style>
.box {
border-radius: 4px;
cursor: pointer;
max-width: max-content;
background: var(--color-accent-soft);
}
.title {
display: flex;
align-items: center;
padding: 8px;
color: var(--body-text-color);
opacity: 0.8;
}
.content {
padding: 8px;
}
.title-text {
padding-right: var(--spacing-lg);
}
.arrow {
margin-left: auto;
opacity: 0.8;
}
</style>

View File

@ -1,5 +1,13 @@
import type { FileData } from "@gradio/client";
import { uploadToHuggingFace } from "@gradio/utils";
import type {
TupleFormat,
ComponentMessage,
ComponentData,
TextMessage,
NormalisedMessage,
Message
} from "../types";
export const format_chat_for_sharing = async (
chat: [string | FileData | null, string | FileData | null][]
@ -56,51 +64,6 @@ export const format_chat_for_sharing = async (
.join("\n");
};
export interface ComponentMessage {
type: "component";
component: string;
value: any;
constructor_args: any;
props: any;
id: string;
}
export interface TextMessage {
type: "text";
value: string;
id: string;
}
export interface FileMessage {
type: "file";
file: FileData | FileData[];
alt_text: string | null;
id: string;
}
export interface EmptyMessage {
type: "empty";
value: null;
id: string;
}
export type NormalisedMessage =
| TextMessage
| FileMessage
| ComponentMessage
| EmptyMessage;
export type message_data =
| string
| { file: FileData | FileData[]; alt_text: string | null }
| { component: string; value: any; constructor_args: any; props: any }
| null;
export type messages = [message_data, message_data][] | null;
function make_id(): string {
return Math.random().toString(36).substring(7);
}
const redirect_src_url = (src: string, root: string): string =>
src.replace('src="/file', `src="${root}file`);
@ -114,37 +77,82 @@ function get_component_for_mime_type(
return "file";
}
function convert_file_message_to_component_message(
message: any
): ComponentData {
const _file = Array.isArray(message.file) ? message.file[0] : message.file;
return {
component: get_component_for_mime_type(_file?.mime_type),
value: message.file,
alt_text: message.alt_text,
constructor_args: {},
props: {}
} as ComponentData;
}
export function normalise_messages(
messages: messages,
messages: Message[] | null,
root: string
): [NormalisedMessage, NormalisedMessage][] | null {
if (messages === null) return null;
return messages.map((message_pair) => {
return message_pair.map((message) => {
if (message == null) return { value: null, id: make_id(), type: "empty" };
): NormalisedMessage[] | null {
if (messages === null) return messages;
return messages.map((message, i) => {
if (typeof message.content === "string") {
return {
role: message.role,
metadata: message.metadata,
content: redirect_src_url(message.content, root),
type: "text",
index: i
};
} else if ("file" in message.content) {
return {
content: convert_file_message_to_component_message(message.content),
metadata: message.metadata,
role: message.role,
type: "component",
index: i
};
}
return { type: "component", ...message } as ComponentMessage;
});
}
export function normalise_tuples(
messages: TupleFormat,
root: string
): NormalisedMessage[] | null {
if (messages === null) return messages;
const msg = messages.flatMap((message_pair, i) => {
return message_pair.map((message, index) => {
if (message == null) return null;
const role = index == 0 ? "user" : "assistant";
if (typeof message === "string") {
return {
role: role,
type: "text",
value: redirect_src_url(message, root),
id: make_id()
};
content: redirect_src_url(message, root),
metadata: { title: null },
index: [i, index]
} as TextMessage;
}
if ("file" in message) {
const _file = Array.isArray(message.file)
? message.file[0]
: message.file;
return {
content: convert_file_message_to_component_message(message),
role: role,
type: "component",
component: get_component_for_mime_type(_file?.mime_type),
value: message.file,
alt_text: message.alt_text,
id: make_id()
};
index: [i, index]
} as ComponentMessage;
}
return { ...message, type: "component", id: make_id() };
}) as [NormalisedMessage, NormalisedMessage];
return {
role: role,
content: message,
type: "component",
index: [i, index]
} as ComponentMessage;
});
});
return msg.filter((message) => message != null) as NormalisedMessage[];
}

42
js/chatbot/types.ts Normal file
View File

@ -0,0 +1,42 @@
import type { FileData } from "@gradio/client";
export type MessageRole = "system" | "user" | "assistant";
export interface Metadata {
title: string | null;
}
export interface ComponentData {
component: string;
constructor_args: any;
props: any;
value: any;
alt_text: string | null;
}
export interface Message {
role: MessageRole;
metadata: Metadata;
content: string | FileData | ComponentData;
index: [number, number] | number;
}
export interface TextMessage extends Message {
type: "text";
content: string;
}
export interface ComponentMessage extends Message {
type: "component";
content: ComponentData;
}
export type message_data =
| string
| { file: FileData | FileData[]; alt_text: string | null }
| { component: string; value: any; constructor_args: any; props: any }
| null;
export type TupleFormat = [message_data, message_data][] | null;
export type NormalisedMessage = TextMessage | ComponentMessage;

View File

@ -200,3 +200,11 @@ export const drag_and_drop_file = async (
await selector.dispatchEvent("drop", { dataTransfer });
}
};
export async function go_to_testcase(
page: Page,
test_case: string
): Promise<void> {
const url = page.url();
await page.goto(`${url.substring(0, url.length - 1)}_${test_case}_testcase`);
}

View File

@ -47,6 +47,7 @@ class TestChatbot:
"proxy_url": None,
"_selectable": False,
"key": None,
"msg_format": "tuples",
"latex_delimiters": [{"display": True, "left": "$$", "right": "$$"}],
"likeable": False,
"rtl": False,

View File

@ -247,35 +247,41 @@ class TestAPI:
assert len(api_info["unnamed_endpoints"]) == 0
assert "/chat" in api_info["named_endpoints"]
def test_streaming_api(self, connect):
chatbot = gr.ChatInterface(stream).queue()
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_streaming_api(self, msg_format, connect):
chatbot = gr.ChatInterface(stream, msg_format=msg_format).queue()
with connect(chatbot) as client:
job = client.submit("hello")
wait([job])
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
def test_streaming_api_async(self, connect):
chatbot = gr.ChatInterface(async_stream).queue()
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_streaming_api_async(self, msg_format, connect):
chatbot = gr.ChatInterface(async_stream, msg_format=msg_format).queue()
with connect(chatbot) as client:
job = client.submit("hello")
wait([job])
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
def test_non_streaming_api(self, connect):
chatbot = gr.ChatInterface(double)
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_non_streaming_api(self, msg_format, connect):
chatbot = gr.ChatInterface(double, msg_format=msg_format)
with connect(chatbot) as client:
result = client.predict("hello")
assert result == "hello hello"
def test_non_streaming_api_async(self, connect):
chatbot = gr.ChatInterface(async_greet)
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_non_streaming_api_async(self, msg_format, connect):
chatbot = gr.ChatInterface(async_greet, msg_format=msg_format)
with connect(chatbot) as client:
result = client.predict("gradio")
assert result == "hi, gradio"
def test_streaming_api_with_additional_inputs(self, connect):
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_streaming_api_with_additional_inputs(self, msg_format, connect):
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
msg_format=msg_format,
additional_inputs=["textbox", "slider"],
).queue()
with connect(chatbot) as client: