mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-13 11:57:29 +08:00
Support Message API for chatbot and chatinterface (#8422)
* first commit * Add code * Tests + code * lint * Add code * notebook * add changeset * type * Add client test * type * Add code * Chatbot type * Add code * test chatbot * fix e2e test * js tests * Consolidate Error and Tool message. Allow Messages in postprocess * Rename to messages * fix tests * notebook clean * More tests and messages * add changeset * notebook * client test * Fix issues * Chatbot docs * add changeset * Add image * Add img tag * Address comments * Add code * Revert chatinterface streaming change. Use title in metadata. Address pngwn comments * Add code * changelog highlight --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
936c7137a9
commit
4221290d84
93
.changeset/young-crabs-begin.md
Normal file
93
.changeset/young-crabs-begin.md
Normal file
@ -0,0 +1,93 @@
|
||||
---
|
||||
"@gradio/chatbot": minor
|
||||
"@gradio/tootils": minor
|
||||
"gradio": minor
|
||||
"website": minor
|
||||
---
|
||||
|
||||
highlight:
|
||||
|
||||
#### Support message format in chatbot 💬
|
||||
|
||||
`gr.Chatbot` and `gr.ChatInterface` now support the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api#messages-api), which is fully compatible with LLM API providers such as Hugging Face Text Generation Inference, OpenAI's chat completions API, and Llama.cpp server.
|
||||
|
||||
Building Gradio applications around these LLM solutions is now even easier!
|
||||
|
||||
`gr.Chatbot` and `gr.ChatInterface` now have a `msg_format` parameter that can accept two values - `'tuples'` and `'messages'`. If set to `'tuples'`, the default chatbot data format is expected. If set to `'messages'`, a list of dictionaries with `content` and `role` keys is expected. See below -
|
||||
|
||||
```python
|
||||
def chat_greeter(msg, history):
|
||||
history.append({"role": "assistant", "content": "Hello!"})
|
||||
return history
|
||||
```
|
||||
|
||||
Additionally, gradio now exposes a `gr.ChatMessage` dataclass you can use for IDE type hints and auto completion.
|
||||
|
||||
<img width="852" alt="image" src="https://github.com/freddyaboulton/freddyboulton/assets/41651716/d283e8f3-b194-466a-8194-c7e697dca9ad">
|
||||
|
||||
|
||||
#### Tool use in Chatbot 🛠️
|
||||
|
||||
The Gradio Chatbot can now natively display tool usage and intermediate thoughts common in Agent and chain-of-thought workflows!
|
||||
|
||||
If you are using the new "messages" format, simply add a `metadata` key with a dictionary containing a `title` key and `value`. This will display the assistant message in an expandable message box to show the result of a tool or intermediate step.
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from gradio import ChatMessage
|
||||
import time
|
||||
|
||||
def generate_response(history):
|
||||
history.append(ChatMessage(role="user", content="What is the weather in San Francisco right now?"))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="In order to find the current weather in San Francisco, I will need to use my weather tool.")
|
||||
)
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="API Error when connecting to weather service.",
|
||||
metadata={"title": "💥 Error using tool 'Weather'"})
|
||||
)
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="I will try again",
|
||||
))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="Weather 72 degrees Fahrenheit with 20% chance of rain.",
|
||||
metadata={"title": "🛠️ Used tool 'Weather'"}
|
||||
))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="Now that the API succeeded I can complete my task.",
|
||||
))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!",
|
||||
))
|
||||
yield history
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
chatbot = gr.Chatbot(msg_format="messages")
|
||||
button = gr.Button("Get San Francisco Weather")
|
||||
button.click(generate_response, chatbot, chatbot)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
|
||||
|
||||

|
@ -19,7 +19,10 @@ const test_files = readdirSync(TEST_FILES_PATH)
|
||||
!f.endsWith(".component.spec.ts") &&
|
||||
!f.endsWith(".reload.spec.ts")
|
||||
)
|
||||
.map((f) => basename(f, ".spec.ts"));
|
||||
.map((f) => ({
|
||||
module_name: `${basename(f, ".spec.ts")}.run`,
|
||||
dir_name: basename(f, ".spec.ts")
|
||||
}));
|
||||
|
||||
export default async function global_setup() {
|
||||
const verbose = process.env.GRADIO_TEST_VERBOSE;
|
||||
@ -29,7 +32,24 @@ export default async function global_setup() {
|
||||
|
||||
process.stdout.write(kl.yellow("\nCreating test gradio app.\n\n"));
|
||||
|
||||
const test_app = make_app(test_files, port);
|
||||
const test_cases = [];
|
||||
// check if there is a testcase file in the same directory as the test file
|
||||
// if there is, append that to the file
|
||||
test_files.forEach((value) => {
|
||||
const test_case_dir = join(ROOT, "demo", value.dir_name);
|
||||
|
||||
readdirSync(test_case_dir)
|
||||
.filter((f) => f.endsWith("_testcase.py"))
|
||||
.forEach((f) => {
|
||||
test_cases.push({
|
||||
module_name: `${value.dir_name}.${basename(f, ".py")}`,
|
||||
dir_name: `${value.dir_name}_${basename(f, ".py")}`
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
const all_test_files = test_files.concat(test_cases);
|
||||
const test_app = make_app(all_test_files, port);
|
||||
process.stdout.write(kl.yellow("App created. Starting test server.\n\n"));
|
||||
|
||||
process.stdout.write(kl.bgBlue(" =========================== \n"));
|
||||
@ -111,14 +131,14 @@ import uvicorn
|
||||
from fastapi import FastAPI
|
||||
import gradio as gr
|
||||
|
||||
${demos.map((d) => `from demo.${d}.run import demo as ${d}`).join("\n")}
|
||||
${demos.map((obj) => `from demo.${obj.module_name} import demo as ${obj.dir_name}`).join("\n")}
|
||||
|
||||
app = FastAPI()
|
||||
${demos
|
||||
.map(
|
||||
(d) =>
|
||||
`app = gr.mount_gradio_app(app, ${d}, path="/${d}", max_file_size=${
|
||||
d == "upload_file_limit_test" ? "'15kb'" : "None"
|
||||
(obj) =>
|
||||
`app = gr.mount_gradio_app(app, ${obj.dir_name}, path="/${obj.dir_name}", max_file_size=${
|
||||
obj.dir_name == "upload_file_limit_test" ? "'15kb'" : "None"
|
||||
})`
|
||||
)
|
||||
.join("\n")}
|
||||
|
@ -459,3 +459,26 @@ def max_file_size_demo():
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chatbot_message_format():
|
||||
with gr.Blocks() as demo:
|
||||
chatbot = gr.Chatbot(msg_format="messages")
|
||||
msg = gr.Textbox()
|
||||
|
||||
def respond(message, chat_history: list):
|
||||
bot_message = random.choice(
|
||||
["How are you?", "I love you", "I'm very hungry"]
|
||||
)
|
||||
chat_history.extend(
|
||||
[
|
||||
{"role": "user", "content": message},
|
||||
{"role": "assistant", "content": bot_message},
|
||||
]
|
||||
)
|
||||
return "", chat_history
|
||||
|
||||
msg.submit(respond, [msg, chatbot], [msg, chatbot], api_name="chat")
|
||||
|
||||
return demo
|
||||
|
@ -682,6 +682,25 @@ class TestClientPredictionsWithKwargs:
|
||||
):
|
||||
client.predict(num1=3, operation="add", api_name="/predict")
|
||||
|
||||
def test_chatbot_message_format(self, chatbot_message_format):
|
||||
with connect(chatbot_message_format) as client:
|
||||
_, history = client.predict("hello", [], api_name="/chat")
|
||||
assert history[1]["role"] == "assistant"
|
||||
assert history[1]["content"] in [
|
||||
"How are you?",
|
||||
"I love you",
|
||||
"I'm very hungry",
|
||||
]
|
||||
_, history = client.predict("hi", history, api_name="/chat")
|
||||
assert history[2]["role"] == "user"
|
||||
assert history[2]["content"] == "hi"
|
||||
assert history[3]["role"] == "assistant"
|
||||
assert history[3]["content"] in [
|
||||
"How are you?",
|
||||
"I love you",
|
||||
"I'm very hungry",
|
||||
]
|
||||
|
||||
|
||||
class TestStatusUpdates:
|
||||
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
|
||||
|
101
demo/chatbot_core_components_simple/messages_testcase.py
Normal file
101
demo/chatbot_core_components_simple/messages_testcase.py
Normal file
@ -0,0 +1,101 @@
|
||||
import gradio as gr
|
||||
import random
|
||||
|
||||
# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
|
||||
|
||||
|
||||
color_map = {
|
||||
"harmful": "crimson",
|
||||
"neutral": "gray",
|
||||
"beneficial": "green",
|
||||
}
|
||||
|
||||
def html_src(harm_level):
|
||||
return f"""
|
||||
<div style="display: flex; gap: 5px;padding: 2px 4px;margin-top: -40px">
|
||||
<div style="background-color: {color_map[harm_level]}; padding: 2px; border-radius: 5px;">
|
||||
{harm_level}
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
def print_like_dislike(x: gr.LikeData):
|
||||
print(x.index, x.value, x.liked)
|
||||
|
||||
def add_message(history, message):
|
||||
for x in message["files"]:
|
||||
history.append({"role": "user", "content": {"path": x}})
|
||||
if message["text"] is not None:
|
||||
history.append({"role": "user", "content": message['text']})
|
||||
return history, gr.MultimodalTextbox(value=None, interactive=False)
|
||||
|
||||
def bot(history, response_type):
|
||||
if response_type == "gallery":
|
||||
msg = {"role": "assistant", "content": gr.Gallery(
|
||||
["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"]
|
||||
)
|
||||
}
|
||||
elif response_type == "image":
|
||||
msg = {"role": "assistant",
|
||||
"content": gr.Image("https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png")
|
||||
}
|
||||
elif response_type == "video":
|
||||
msg = {"role": "assistant",
|
||||
"content": gr.Video("https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4")
|
||||
}
|
||||
elif response_type == "audio":
|
||||
msg = {"role": "assistant",
|
||||
"content": gr.Audio("https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav")
|
||||
}
|
||||
elif response_type == "html":
|
||||
msg = {"role": "assistant",
|
||||
"content": gr.HTML(
|
||||
html_src(random.choice(["harmful", "neutral", "beneficial"]))
|
||||
)
|
||||
}
|
||||
else:
|
||||
msg = {"role": "assistant", "content": "Cool!"}
|
||||
history.append(msg)
|
||||
return history
|
||||
|
||||
|
||||
with gr.Blocks(fill_height=True) as demo:
|
||||
chatbot = gr.Chatbot(
|
||||
elem_id="chatbot",
|
||||
bubble_full_width=False,
|
||||
scale=1,
|
||||
msg_format="messages"
|
||||
)
|
||||
response_type = gr.Radio(
|
||||
[
|
||||
"image",
|
||||
"text",
|
||||
"gallery",
|
||||
"video",
|
||||
"audio",
|
||||
"html",
|
||||
],
|
||||
value="text",
|
||||
label="Response Type",
|
||||
)
|
||||
|
||||
chat_input = gr.MultimodalTextbox(
|
||||
interactive=True,
|
||||
placeholder="Enter message or upload file...",
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
chat_msg = chat_input.submit(
|
||||
add_message, [chatbot, chat_input], [chatbot, chat_input]
|
||||
)
|
||||
bot_msg = chat_msg.then(
|
||||
bot, [chatbot, response_type], chatbot, api_name="bot_response"
|
||||
)
|
||||
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
|
||||
|
||||
chatbot.like(print_like_dislike, None, None)
|
||||
|
||||
demo.queue()
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_core_components_simple"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/audio.wav https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/audio.wav\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/avatar.png\n", "!wget -q -O files/sample.txt https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/sample.txt\n", "!wget -q -O files/world.mp4 https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/world.mp4"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import random\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "color_map = {\n", " \"harmful\": \"crimson\",\n", " \"neutral\": \"gray\",\n", " \"beneficial\": \"green\",\n", "}\n", "\n", "\n", "def html_src(harm_level):\n", " return f\"\"\"\n", "<div style=\"display: flex; gap: 5px;\">\n", " <div style=\"background-color: {color_map[harm_level]}; padding: 2px; border-radius: 5px;\">\n", " {harm_level}\n", " </div>\n", "</div>\n", "\"\"\"\n", "\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "\n", "def bot(history, response_type):\n", " if response_type == \"gallery\":\n", " history[-1][1] = gr.Gallery(\n", " [\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " ]\n", " )\n", " elif response_type == \"image\":\n", " history[-1][1] = gr.Image(\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n", " )\n", " elif response_type == \"video\":\n", " history[-1][1] = gr.Video(\n", " \"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4\"\n", " )\n", " elif response_type == \"audio\":\n", " history[-1][1] = gr.Audio(\n", " \"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav\"\n", " )\n", " elif response_type == \"html\":\n", " history[-1][1] = gr.HTML(\n", " html_src(random.choice([\"harmful\", \"neutral\", \"beneficial\"]))\n", " )\n", " else:\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " scale=1,\n", " )\n", " response_type = gr.Radio(\n", " [\n", " \"image\",\n", " \"text\",\n", " \"gallery\",\n", " \"video\",\n", " \"audio\",\n", " \"html\",\n", " ],\n", " value=\"text\",\n", " label=\"Response Type\",\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(\n", " interactive=True,\n", " placeholder=\"Enter message or upload file...\",\n", " show_label=False,\n", " )\n", "\n", " chat_msg = chat_input.submit(\n", " add_message, [chatbot, chat_input], [chatbot, chat_input]\n", " )\n", " bot_msg = chat_msg.then(\n", " bot, [chatbot, response_type], chatbot, api_name=\"bot_response\"\n", " )\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_core_components_simple"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/audio.wav https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/audio.wav\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/avatar.png\n", "!wget -q -O files/sample.txt https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/sample.txt\n", "!wget -q -O files/world.mp4 https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/files/world.mp4\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_core_components_simple/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import random\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "color_map = {\n", " \"harmful\": \"crimson\",\n", " \"neutral\": \"gray\",\n", " \"beneficial\": \"green\",\n", "}\n", "\n", "\n", "def html_src(harm_level):\n", " return f\"\"\"\n", "<div style=\"display: flex; gap: 5px;\">\n", " <div style=\"background-color: {color_map[harm_level]}; padding: 2px; border-radius: 5px;\">\n", " {harm_level}\n", " </div>\n", "</div>\n", "\"\"\"\n", "\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "\n", "def bot(history, response_type):\n", " if response_type == \"gallery\":\n", " history[-1][1] = gr.Gallery(\n", " [\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\",\n", " ]\n", " )\n", " elif response_type == \"image\":\n", " history[-1][1] = gr.Image(\n", " \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n", " )\n", " elif response_type == \"video\":\n", " history[-1][1] = gr.Video(\n", " \"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4\"\n", " )\n", " elif response_type == \"audio\":\n", " history[-1][1] = gr.Audio(\n", " \"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav\"\n", " )\n", " elif response_type == \"html\":\n", " history[-1][1] = gr.HTML(\n", " html_src(random.choice([\"harmful\", \"neutral\", \"beneficial\"]))\n", " )\n", " else:\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " scale=1,\n", " )\n", " response_type = gr.Radio(\n", " [\n", " \"image\",\n", " \"text\",\n", " \"gallery\",\n", " \"video\",\n", " \"audio\",\n", " \"html\",\n", " ],\n", " value=\"text\",\n", " label=\"Response Type\",\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(\n", " interactive=True,\n", " placeholder=\"Enter message or upload file...\",\n", " show_label=False,\n", " )\n", "\n", " chat_msg = chat_input.submit(\n", " add_message, [chatbot, chat_input], [chatbot, chat_input]\n", " )\n", " bot_msg = chat_msg.then(\n", " bot, [chatbot, response_type], chatbot, api_name=\"bot_response\"\n", " )\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
45
demo/chatbot_multimodal/messages_testcase.py
Normal file
45
demo/chatbot_multimodal/messages_testcase.py
Normal file
@ -0,0 +1,45 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
|
||||
|
||||
|
||||
def print_like_dislike(x: gr.LikeData):
|
||||
print(x.index, x.value, x.liked)
|
||||
|
||||
def add_message(history, message):
|
||||
for x in message["files"]:
|
||||
history.append({"role": "user", "content": {"path": x}})
|
||||
if message["text"] is not None:
|
||||
history.append({"role": "user", "content": message["text"]})
|
||||
return history, gr.MultimodalTextbox(value=None, interactive=False)
|
||||
|
||||
def bot(history: list):
|
||||
response = "**That's cool!**"
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
for character in response:
|
||||
history[-1]['content'] += character
|
||||
time.sleep(0.05)
|
||||
yield history
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
chatbot = gr.Chatbot(
|
||||
[],
|
||||
elem_id="chatbot",
|
||||
bubble_full_width=False,
|
||||
msg_format="messages"
|
||||
)
|
||||
|
||||
chat_input = gr.MultimodalTextbox(interactive=True,
|
||||
file_count="multiple",
|
||||
placeholder="Enter message or upload file...", show_label=False)
|
||||
|
||||
chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
|
||||
bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
|
||||
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
|
||||
|
||||
chatbot.like(print_like_dislike, None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio plotly"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/files/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import plotly.express as px\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "def random_plot():\n", " df = px.data.iris()\n", " fig = px.scatter(df, x=\"sepal_width\", y=\"sepal_length\", color=\"species\",\n", " size='petal_length', hover_data=['petal_width'])\n", " return fig\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "def bot(history):\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "fig = random_plot()\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " scale=1,\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(interactive=True,\n", " file_count=\"multiple\",\n", " placeholder=\"Enter message or upload file...\", show_label=False)\n", "\n", " chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])\n", " bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name=\"bot_response\")\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio plotly"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('files')\n", "!wget -q -O files/avatar.png https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/files/avatar.png\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import plotly.express as px\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "def random_plot():\n", " df = px.data.iris()\n", " fig = px.scatter(df, x=\"sepal_width\", y=\"sepal_length\", color=\"species\",\n", " size='petal_length', hover_data=['petal_width'])\n", " return fig\n", "\n", "def print_like_dislike(x: gr.LikeData):\n", " print(x.index, x.value, x.liked)\n", "\n", "def add_message(history, message):\n", " for x in message[\"files\"]:\n", " history.append(((x,), None))\n", " if message[\"text\"] is not None:\n", " history.append((message[\"text\"], None))\n", " return history, gr.MultimodalTextbox(value=None, interactive=False)\n", "\n", "def bot(history):\n", " history[-1][1] = \"Cool!\"\n", " return history\n", "\n", "fig = random_plot()\n", "\n", "with gr.Blocks(fill_height=True) as demo:\n", " chatbot = gr.Chatbot(\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " scale=1,\n", " )\n", "\n", " chat_input = gr.MultimodalTextbox(interactive=True,\n", " file_count=\"multiple\",\n", " placeholder=\"Enter message or upload file...\", show_label=False)\n", "\n", " chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])\n", " bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name=\"bot_response\")\n", " bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])\n", "\n", " chatbot.like(print_like_dislike, None, None)\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " def user(user_message, history):\n", " return \"\", history + [[user_message, None]]\n", "\n", " def bot(history):\n", " bot_message = random.choice([\"How are you?\", \"I love you\", \"I'm very hungry\"])\n", " history[-1][1] = \"\"\n", " for character in bot_message:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", " \n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_streaming/testcase_messages.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " def user(user_message, history):\n", " return \"\", history + [[user_message, None]]\n", "\n", " def bot(history):\n", " bot_message = random.choice([\"How are you?\", \"I love you\", \"I'm very hungry\"])\n", " history[-1][1] = \"\"\n", " for character in bot_message:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -22,7 +22,6 @@ with gr.Blocks() as demo:
|
||||
bot, chatbot, chatbot
|
||||
)
|
||||
clear.click(lambda: None, None, chatbot, queue=False)
|
||||
|
||||
demo.queue()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
28
demo/chatbot_streaming/testcase_messages.py
Normal file
28
demo/chatbot_streaming/testcase_messages.py
Normal file
@ -0,0 +1,28 @@
|
||||
import gradio as gr
|
||||
import random
|
||||
import time
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
chatbot = gr.Chatbot(msg_format="messages")
|
||||
msg = gr.Textbox()
|
||||
clear = gr.Button("Clear")
|
||||
|
||||
def user(user_message, history: list):
|
||||
return "", history + [{"role": "user", "content": user_message}]
|
||||
|
||||
def bot(history: list):
|
||||
bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
|
||||
history.append({"role": "assistant", "content": ""})
|
||||
for character in bot_message:
|
||||
history[-1]['content'] += character
|
||||
time.sleep(0.05)
|
||||
yield history
|
||||
|
||||
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
||||
bot, chatbot, chatbot
|
||||
)
|
||||
clear.click(lambda: None, None, chatbot, queue=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
1
demo/chatbot_with_tools/run.ipynb
Normal file
1
demo/chatbot_with_tools/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_with_tools"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "import time\n", "\n", "def generate_response(history):\n", " history.append(ChatMessage(role=\"user\", content=\"What is the weather in San Francisco right now?\"))\n", " yield history\n", " time.sleep(0.25)\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"In order to find the current weather in San Francisco, I will need to use my weather tool.\")\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"API Error when connecting to weather service.\",\n", " metadata={\"title\": \"\ud83d\udca5 Error using tool 'Weather'\"})\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"I will try again\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Weather 72 degrees Fahrenheit with 20% chance of rain.\",\n", " metadata={\"title\": \"\ud83d\udee0\ufe0f Used tool 'Weather'\"}\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Now that the API succeeded I can complete my task.\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!\",\n", " ))\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(msg_format=\"messages\")\n", " button = gr.Button(\"Get San Francisco Weather\")\n", " button.click(generate_response, chatbot, chatbot)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
53
demo/chatbot_with_tools/run.py
Normal file
53
demo/chatbot_with_tools/run.py
Normal file
@ -0,0 +1,53 @@
|
||||
import gradio as gr
|
||||
from gradio import ChatMessage
|
||||
import time
|
||||
|
||||
def generate_response(history):
|
||||
history.append(ChatMessage(role="user", content="What is the weather in San Francisco right now?"))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="In order to find the current weather in San Francisco, I will need to use my weather tool.")
|
||||
)
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="API Error when connecting to weather service.",
|
||||
metadata={"title": "💥 Error using tool 'Weather'"})
|
||||
)
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="I will try again",
|
||||
))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="Weather 72 degrees Fahrenheit with 20% chance of rain.",
|
||||
metadata={"title": "🛠️ Used tool 'Weather'"}
|
||||
))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="Now that the API succeeded I can complete my task.",
|
||||
))
|
||||
yield history
|
||||
time.sleep(0.25)
|
||||
|
||||
history.append(ChatMessage(role="assistant",
|
||||
content="It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!",
|
||||
))
|
||||
yield history
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
chatbot = gr.Chatbot(msg_format="messages")
|
||||
button = gr.Button("Get San Francisco Weather")
|
||||
button.click(generate_response, chatbot, chatbot)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
15
demo/chatinterface_streaming_echo/messages_testcase.py
Normal file
15
demo/chatinterface_streaming_echo/messages_testcase.py
Normal file
@ -0,0 +1,15 @@
|
||||
import time
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def slow_echo(message, history):
|
||||
for i in range(len(message)):
|
||||
time.sleep(0.05)
|
||||
yield "You typed: " + message[: i + 1]
|
||||
|
||||
|
||||
|
||||
demo = gr.ChatInterface(slow_echo, msg_format="messages")
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "\n", "def slow_echo(message, history):\n", " for i in range(len(message)):\n", " time.sleep(0.05)\n", " yield \"You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo).queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatinterface_streaming_echo/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import time\n", "import gradio as gr\n", "\n", "\n", "def slow_echo(message, history):\n", " for i in range(len(message)):\n", " time.sleep(0.05)\n", " yield \"You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -8,7 +8,7 @@ def slow_echo(message, history):
|
||||
yield "You typed: " + message[: i + 1]
|
||||
|
||||
|
||||
demo = gr.ChatInterface(slow_echo).queue()
|
||||
demo = gr.ChatInterface(slow_echo)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
17
demo/test_chatinterface_streaming_echo/messages_testcase.py
Normal file
17
demo/test_chatinterface_streaming_echo/messages_testcase.py
Normal file
@ -0,0 +1,17 @@
|
||||
import time
|
||||
import gradio as gr
|
||||
|
||||
runs = 0
|
||||
|
||||
def slow_echo(message, history):
|
||||
global runs # i didn't want to add state or anything to this demo
|
||||
runs = runs + 1
|
||||
for i in range(len(message)):
|
||||
time.sleep(0.05)
|
||||
yield f"Run {runs} - You typed: " + message[: i + 1]
|
||||
|
||||
|
||||
demo = gr.ChatInterface(slow_echo, msg_format="messages").queue()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo).queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_streaming_echo"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_streaming_echo/messages_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "runs = 0\n", "\n", "\n", "def slow_echo(message, history):\n", " global runs # i didn't want to add state or anything to this demo\n", " runs = runs + 1\n", " for i in range(len(message)):\n", " yield f\"Run {runs} - You typed: \" + message[: i + 1]\n", "\n", "\n", "demo = gr.ChatInterface(slow_echo).queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -16,6 +16,7 @@ from gradio.components import (
|
||||
BarPlot,
|
||||
Button,
|
||||
Chatbot,
|
||||
ChatMessage,
|
||||
Checkbox,
|
||||
CheckboxGroup,
|
||||
Checkboxgroup,
|
||||
@ -42,6 +43,7 @@ from gradio.components import (
|
||||
LoginButton,
|
||||
LogoutButton,
|
||||
Markdown,
|
||||
MessageDict,
|
||||
Model3D,
|
||||
MultimodalTextbox,
|
||||
Number,
|
||||
|
@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
||||
|
||||
import anyio
|
||||
@ -22,6 +23,8 @@ from gradio.components import (
|
||||
Textbox,
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.components.chatbot import FileDataDict, Message, MessageDict, TupleFormat
|
||||
from gradio.components.multimodal_textbox import MultimodalData
|
||||
from gradio.events import Dependency, on
|
||||
from gradio.helpers import create_examples as Examples # noqa: N812
|
||||
from gradio.helpers import special_args
|
||||
@ -56,6 +59,7 @@ class ChatInterface(Blocks):
|
||||
fn: Callable,
|
||||
*,
|
||||
multimodal: bool = False,
|
||||
msg_format: Literal["messages", "tuples"] = "tuples",
|
||||
chatbot: Chatbot | None = None,
|
||||
textbox: Textbox | MultimodalTextbox | None = None,
|
||||
additional_inputs: str | Component | list[str | Component] | None = None,
|
||||
@ -123,6 +127,7 @@ class ChatInterface(Blocks):
|
||||
fill_height=fill_height,
|
||||
delete_cache=delete_cache,
|
||||
)
|
||||
self.msg_format: Literal["messages", "tuples"] = msg_format
|
||||
self.multimodal = multimodal
|
||||
self.concurrency_limit = concurrency_limit
|
||||
self.fn = fn
|
||||
@ -182,10 +187,19 @@ class ChatInterface(Blocks):
|
||||
Markdown(description)
|
||||
|
||||
if chatbot:
|
||||
if self.msg_format != chatbot.msg_format:
|
||||
warnings.warn(
|
||||
"The msg_format of the chatbot does not match the msg_format of the chat interface. The msg_format of the chat interface will be used."
|
||||
"Recieved msg_format of chatbot: {chatbot.msg_format}, msg_format of chat interface: {self.msg_format}"
|
||||
)
|
||||
chatbot.msg_format = self.msg_format
|
||||
self.chatbot = get_component_instance(chatbot, render=True)
|
||||
else:
|
||||
self.chatbot = Chatbot(
|
||||
label="Chatbot", scale=1, height=200 if fill_height else None
|
||||
label="Chatbot",
|
||||
scale=1,
|
||||
height=200 if fill_height else None,
|
||||
msg_format=self.msg_format,
|
||||
)
|
||||
|
||||
with Row():
|
||||
@ -329,6 +343,7 @@ class ChatInterface(Blocks):
|
||||
[self.textbox, self.saved_input],
|
||||
show_api=False,
|
||||
queue=False,
|
||||
preprocess=False,
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
@ -383,6 +398,12 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
self._setup_stop_events([self.retry_btn.click], retry_event)
|
||||
|
||||
async def format_textbox(data: str | MultimodalData) -> str | dict:
|
||||
if isinstance(data, MultimodalData):
|
||||
return {"text": data.text, "files": [x.path for x in data.files]}
|
||||
else:
|
||||
return data
|
||||
|
||||
if self.undo_btn:
|
||||
self.undo_btn.click(
|
||||
self._delete_prev_fn,
|
||||
@ -391,7 +412,7 @@ class ChatInterface(Blocks):
|
||||
show_api=False,
|
||||
queue=False,
|
||||
).then(
|
||||
async_lambda(lambda x: x),
|
||||
format_textbox,
|
||||
[self.saved_input],
|
||||
[self.textbox],
|
||||
show_api=False,
|
||||
@ -498,54 +519,82 @@ class ChatInterface(Blocks):
|
||||
),
|
||||
)
|
||||
|
||||
def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
|
||||
def _clear_and_save_textbox(
|
||||
self, message: str | dict
|
||||
) -> tuple[str | dict, str | MultimodalData]:
|
||||
if self.multimodal:
|
||||
return {"text": "", "files": []}, message
|
||||
return {"text": "", "files": []}, MultimodalData(**cast(dict, message))
|
||||
else:
|
||||
return "", message
|
||||
return "", cast(str, message)
|
||||
|
||||
def _append_multimodal_history(
|
||||
self,
|
||||
message: dict[str, list],
|
||||
response: str | None,
|
||||
history: list[list[str | tuple | None]],
|
||||
message: MultimodalData,
|
||||
response: MessageDict | str | None,
|
||||
history: list[MessageDict] | TupleFormat,
|
||||
):
|
||||
for x in message["files"]:
|
||||
history.append([(x,), None])
|
||||
if message["text"] is None or not isinstance(message["text"], str):
|
||||
return
|
||||
elif message["text"] == "" and message["files"] != []:
|
||||
history.append([None, response])
|
||||
if self.msg_format == "tuples":
|
||||
for x in message.files:
|
||||
history.append([(x.path,), None]) # type: ignore
|
||||
if message.text is None or not isinstance(message.text, str):
|
||||
return
|
||||
elif message.text == "" and message.files != []:
|
||||
history.append([None, response]) # type: ignore
|
||||
else:
|
||||
history.append([message.text, cast(str, response)]) # type: ignore
|
||||
else:
|
||||
history.append([message["text"], response])
|
||||
for x in message.files:
|
||||
history.append(
|
||||
{"role": "user", "content": cast(FileDataDict, x.model_dump())} # type: ignore
|
||||
)
|
||||
if message.text is None or not isinstance(message.text, str):
|
||||
return
|
||||
else:
|
||||
history.append({"role": "user", "content": message.text}) # type: ignore
|
||||
if response:
|
||||
history.append(cast(MessageDict, response)) # type: ignore
|
||||
|
||||
async def _display_input(
|
||||
self, message: str | dict[str, list], history: list[list[str | tuple | None]]
|
||||
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
self, message: str | MultimodalData, history: TupleFormat | list[MessageDict]
|
||||
) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
|
||||
if self.multimodal and isinstance(message, MultimodalData):
|
||||
self._append_multimodal_history(message, None, history)
|
||||
elif isinstance(message, str):
|
||||
history.append([message, None])
|
||||
return history, history
|
||||
elif isinstance(message, str) and self.msg_format == "tuples":
|
||||
history.append([message, None]) # type: ignore
|
||||
elif isinstance(message, str) and self.msg_format == "messages":
|
||||
history.append({"role": "user", "content": message}) # type: ignore
|
||||
return history, history # type: ignore
|
||||
|
||||
def response_as_dict(self, response: MessageDict | Message | str) -> MessageDict:
|
||||
if isinstance(response, Message):
|
||||
new_response = response.model_dump()
|
||||
elif isinstance(response, str):
|
||||
return {"role": "assistant", "content": response}
|
||||
else:
|
||||
new_response = response
|
||||
return cast(MessageDict, new_response)
|
||||
|
||||
async def _submit_fn(
|
||||
self,
|
||||
message: str | dict[str, list],
|
||||
history_with_input: list[list[str | tuple | None]],
|
||||
message: str | MultimodalData,
|
||||
history_with_input: TupleFormat | list[MessageDict],
|
||||
request: Request,
|
||||
*args,
|
||||
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
|
||||
if self.multimodal and isinstance(message, MultimodalData):
|
||||
remove_input = (
|
||||
len(message["files"]) + 1
|
||||
if message["text"] is not None
|
||||
else len(message["files"])
|
||||
len(message.files) + 1
|
||||
if message.text is not None
|
||||
else len(message.files)
|
||||
)
|
||||
history = history_with_input[:-remove_input]
|
||||
message_serialized = message.model_dump()
|
||||
else:
|
||||
history = history_with_input[:-1]
|
||||
message_serialized = message
|
||||
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
self.fn, inputs=[message_serialized, history, *args], request=request
|
||||
)
|
||||
|
||||
if self.is_async:
|
||||
@ -555,24 +604,31 @@ class ChatInterface(Blocks):
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
self._append_multimodal_history(message, response, history)
|
||||
elif isinstance(message, str):
|
||||
history.append([message, response])
|
||||
return history, history
|
||||
if self.msg_format == "messages":
|
||||
new_response = self.response_as_dict(response)
|
||||
else:
|
||||
new_response = response
|
||||
|
||||
if self.multimodal and isinstance(message, MultimodalData):
|
||||
self._append_multimodal_history(message, new_response, history) # type: ignore
|
||||
elif isinstance(message, str) and self.msg_format == "tuples":
|
||||
history.append([message, new_response]) # type: ignore
|
||||
elif isinstance(message, str) and self.msg_format == "messages":
|
||||
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
|
||||
return history, history # type: ignore
|
||||
|
||||
async def _stream_fn(
|
||||
self,
|
||||
message: str | dict[str, list],
|
||||
history_with_input: list[list[str | tuple | None]],
|
||||
message: str | MultimodalData,
|
||||
history_with_input: TupleFormat | list[MessageDict],
|
||||
request: Request,
|
||||
*args,
|
||||
) -> AsyncGenerator:
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
if self.multimodal and isinstance(message, MultimodalData):
|
||||
remove_input = (
|
||||
len(message["files"]) + 1
|
||||
if message["text"] is not None
|
||||
else len(message["files"])
|
||||
len(message.files) + 1
|
||||
if message.text is not None
|
||||
else len(message.files)
|
||||
)
|
||||
history = history_with_input[:-remove_input]
|
||||
else:
|
||||
@ -590,30 +646,136 @@ class ChatInterface(Blocks):
|
||||
generator = SyncToAsyncIterator(generator, self.limiter)
|
||||
try:
|
||||
first_response = await async_iteration(generator)
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
for x in message["files"]:
|
||||
history.append([(x,), None])
|
||||
update = history + [[message["text"], first_response]]
|
||||
if self.msg_format == "messages":
|
||||
first_response = self.response_as_dict(first_response)
|
||||
if (
|
||||
self.multimodal
|
||||
and isinstance(message, MultimodalData)
|
||||
and self.msg_format == "tuples"
|
||||
):
|
||||
for x in message.files:
|
||||
history.append([(x,), None]) # type: ignore
|
||||
update = history + [[message.text, first_response]]
|
||||
yield update, update
|
||||
else:
|
||||
elif (
|
||||
self.multimodal
|
||||
and isinstance(message, MultimodalData)
|
||||
and self.msg_format == "messages"
|
||||
):
|
||||
for x in message.files:
|
||||
history.append(
|
||||
{"role": "user", "content": cast(FileDataDict, x.model_dump())} # type: ignore
|
||||
)
|
||||
update = history + [
|
||||
{"role": "user", "content": message.text},
|
||||
first_response,
|
||||
]
|
||||
yield update, update
|
||||
elif self.msg_format == "tuples":
|
||||
update = history + [[message, first_response]]
|
||||
yield update, update
|
||||
else:
|
||||
update = history + [
|
||||
{"role": "user", "content": message},
|
||||
first_response,
|
||||
]
|
||||
yield update, update
|
||||
except StopIteration:
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
if self.multimodal and isinstance(message, MultimodalData):
|
||||
self._append_multimodal_history(message, None, history)
|
||||
yield history, history
|
||||
else:
|
||||
update = history + [[message, None]]
|
||||
yield update, update
|
||||
async for response in generator:
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
update = history + [[message["text"], response]]
|
||||
if self.msg_format == "messages":
|
||||
response = self.response_as_dict(response)
|
||||
if (
|
||||
self.multimodal
|
||||
and isinstance(message, MultimodalData)
|
||||
and self.msg_format == "tuples"
|
||||
):
|
||||
update = history + [[message.text, response]]
|
||||
yield update, update
|
||||
else:
|
||||
elif (
|
||||
self.multimodal
|
||||
and isinstance(message, MultimodalData)
|
||||
and self.msg_format == "messages"
|
||||
):
|
||||
update = history + [
|
||||
{"role": "user", "content": message.text},
|
||||
response,
|
||||
]
|
||||
yield update, update
|
||||
elif self.msg_format == "tuples":
|
||||
update = history + [[message, response]]
|
||||
yield update, update
|
||||
else:
|
||||
update = history + [{"role": "user", "content": message}, response]
|
||||
yield update, update
|
||||
|
||||
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
||||
async def _api_submit_fn(
|
||||
self,
|
||||
message: str,
|
||||
history: TupleFormat | list[MessageDict],
|
||||
request: Request,
|
||||
*args,
|
||||
) -> tuple[str, TupleFormat | list[MessageDict]]:
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
)
|
||||
|
||||
if self.is_async:
|
||||
response = await self.fn(*inputs)
|
||||
else:
|
||||
response = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
if self.msg_format == "tuples":
|
||||
history.append([message, response]) # type: ignore
|
||||
else:
|
||||
new_response = self.response_as_dict(response)
|
||||
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
|
||||
return response, history
|
||||
|
||||
async def _api_stream_fn(
|
||||
self, message: str, history: list[list[str | None]], request: Request, *args
|
||||
) -> AsyncGenerator:
|
||||
inputs, _, _ = special_args(
|
||||
self.fn, inputs=[message, history, *args], request=request
|
||||
)
|
||||
if self.is_async:
|
||||
generator = self.fn(*inputs)
|
||||
else:
|
||||
generator = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
generator = SyncToAsyncIterator(generator, self.limiter)
|
||||
try:
|
||||
first_response = await async_iteration(generator)
|
||||
if self.msg_format == "tuples":
|
||||
yield first_response, history + [[message, first_response]]
|
||||
else:
|
||||
first_response = self.response_as_dict(first_response)
|
||||
yield (
|
||||
first_response,
|
||||
history + [{"role": "user", "content": message}, first_response],
|
||||
)
|
||||
except StopIteration:
|
||||
yield None, history + [[message, None]]
|
||||
async for response in generator:
|
||||
if self.msg_format == "tuples":
|
||||
yield response, history + [[message, response]]
|
||||
else:
|
||||
new_response = self.response_as_dict(response)
|
||||
yield (
|
||||
new_response,
|
||||
history + [{"role": "user", "content": message}, new_response],
|
||||
)
|
||||
|
||||
async def _examples_fn(
|
||||
self, message: str, *args
|
||||
) -> TupleFormat | list[MessageDict]:
|
||||
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
||||
|
||||
if self.is_async:
|
||||
@ -622,7 +784,10 @@ class ChatInterface(Blocks):
|
||||
response = await anyio.to_thread.run_sync(
|
||||
self.fn, *inputs, limiter=self.limiter
|
||||
)
|
||||
return [[message, response]]
|
||||
if self.msg_format == "tuples":
|
||||
return [[message, response]]
|
||||
else:
|
||||
return [{"role": "user", "content": message}, response]
|
||||
|
||||
async def _examples_stream_fn(
|
||||
self,
|
||||
@ -639,24 +804,29 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
generator = SyncToAsyncIterator(generator, self.limiter)
|
||||
async for response in generator:
|
||||
yield [[message, response]]
|
||||
if self.msg_format == "tuples":
|
||||
yield [[message, response]]
|
||||
else:
|
||||
new_response = self.response_as_dict(response)
|
||||
yield [{"role": "user", "content": message}, new_response]
|
||||
|
||||
async def _delete_prev_fn(
|
||||
self,
|
||||
message: str | dict[str, list],
|
||||
history: list[list[str | tuple | None]],
|
||||
message: str | MultimodalData | None,
|
||||
history: list[MessageDict] | TupleFormat,
|
||||
) -> tuple[
|
||||
list[list[str | tuple | None]],
|
||||
str | dict[str, list],
|
||||
list[list[str | tuple | None]],
|
||||
list[MessageDict] | TupleFormat,
|
||||
str | MultimodalData,
|
||||
list[MessageDict] | TupleFormat,
|
||||
]:
|
||||
if self.multimodal and isinstance(message, dict):
|
||||
extra = 1 if self.msg_format == "messages" else 0
|
||||
if self.multimodal and isinstance(message, MultimodalData):
|
||||
remove_input = (
|
||||
len(message["files"]) + 1
|
||||
if message["text"] is not None
|
||||
else len(message["files"])
|
||||
)
|
||||
len(message.files) + 1
|
||||
if message.text is not None
|
||||
else len(message.files)
|
||||
) + extra
|
||||
history = history[:-remove_input]
|
||||
else:
|
||||
history = history[:-1]
|
||||
history = history[: -(1 + extra)]
|
||||
return history, message or "", history
|
||||
|
@ -11,7 +11,7 @@ from gradio.components.base import (
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.components.button import Button
|
||||
from gradio.components.chatbot import Chatbot
|
||||
from gradio.components.chatbot import Chatbot, ChatMessage, MessageDict
|
||||
from gradio.components.checkbox import Checkbox
|
||||
from gradio.components.checkboxgroup import CheckboxGroup
|
||||
from gradio.components.clear_button import ClearButton
|
||||
@ -64,6 +64,7 @@ __all__ = [
|
||||
"BarPlot",
|
||||
"Button",
|
||||
"Chatbot",
|
||||
"ChatMessage",
|
||||
"ClearButton",
|
||||
"Component",
|
||||
"component",
|
||||
@ -92,6 +93,7 @@ __all__ = [
|
||||
"LoginButton",
|
||||
"LogoutButton",
|
||||
"Markdown",
|
||||
"MessageDict",
|
||||
"Textbox",
|
||||
"Dropdown",
|
||||
"Model3D",
|
||||
|
@ -12,7 +12,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Type
|
||||
|
||||
import gradio_client.utils as client_utils
|
||||
|
||||
@ -166,7 +166,7 @@ class Component(ComponentBase, Block):
|
||||
# This gets overridden when `select` is called
|
||||
self._selectable = False
|
||||
if not hasattr(self, "data_model"):
|
||||
self.data_model: type[GradioDataModel] | None = None
|
||||
self.data_model: Type[GradioDataModel] | None = None
|
||||
|
||||
Block.__init__(
|
||||
self,
|
||||
|
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -13,11 +14,16 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document
|
||||
from pydantic import Field
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from gradio import utils
|
||||
from gradio.component_meta import ComponentMeta
|
||||
@ -27,6 +33,31 @@ from gradio.components import (
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel
|
||||
from gradio.events import Events
|
||||
from gradio.exceptions import Error
|
||||
from gradio.processing_utils import move_resource_to_block_cache
|
||||
|
||||
|
||||
class MetadataDict(TypedDict):
|
||||
title: Union[str, None]
|
||||
|
||||
|
||||
class FileDataDict(TypedDict):
|
||||
path: str # server filepath
|
||||
url: NotRequired[Optional[str]] # normalised server url
|
||||
size: NotRequired[Optional[int]] # size in bytes
|
||||
orig_name: NotRequired[Optional[str]] # original filename
|
||||
mime_type: NotRequired[Optional[str]]
|
||||
is_stream: NotRequired[bool]
|
||||
meta: dict[Literal["_type"], Literal["gradio.FileData"]]
|
||||
|
||||
|
||||
class MessageDict(TypedDict):
|
||||
content: str | FileDataDict | tuple | Component
|
||||
role: Literal["user", "assistant", "system"]
|
||||
metadata: NotRequired[MetadataDict]
|
||||
|
||||
|
||||
TupleFormat = List[List[Union[str, Tuple[str], Tuple[str, str], None]]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Timer
|
||||
@ -59,7 +90,7 @@ class ComponentMessage(GradioModel):
|
||||
props: Dict[str, Any]
|
||||
|
||||
|
||||
class ChatbotData(GradioRootModel):
|
||||
class ChatbotDataTuples(GradioRootModel):
|
||||
root: List[
|
||||
Tuple[
|
||||
Union[str, FileMessage, ComponentMessage, None],
|
||||
@ -68,6 +99,27 @@ class ChatbotData(GradioRootModel):
|
||||
]
|
||||
|
||||
|
||||
class Metadata(GradioModel):
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class Message(GradioModel):
|
||||
role: str
|
||||
metadata: Metadata = Field(default_factory=Metadata)
|
||||
content: Union[str, FileMessage, ComponentMessage]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str | FileData | Component | FileDataDict | tuple | list
|
||||
metadata: MetadataDict | Metadata = field(default_factory=Metadata)
|
||||
|
||||
|
||||
class ChatbotDataMessages(GradioRootModel):
|
||||
root: List[Message]
|
||||
|
||||
|
||||
@document()
|
||||
class Chatbot(Component):
|
||||
"""
|
||||
@ -80,7 +132,6 @@ class Chatbot(Component):
|
||||
"""
|
||||
|
||||
EVENTS = [Events.change, Events.select, Events.like]
|
||||
data_model = ChatbotData
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -92,6 +143,7 @@ class Chatbot(Component):
|
||||
| None
|
||||
) = None,
|
||||
*,
|
||||
msg_format: Literal["messages", "tuples"] = "tuples",
|
||||
label: str | None = None,
|
||||
every: Timer | float | None = None,
|
||||
inputs: Component | list[Component] | set[Component] | None = None,
|
||||
@ -121,6 +173,7 @@ class Chatbot(Component):
|
||||
"""
|
||||
Parameters:
|
||||
value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
|
||||
msg_format: The format of the messages. If 'tuples', expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content' key's value supports everything the 'tuples' format supports. The 'role' key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output.
|
||||
label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
|
||||
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
|
||||
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
|
||||
@ -148,6 +201,15 @@ class Chatbot(Component):
|
||||
placeholder: a placeholder message to display in the chatbot when it is empty. Centered vertically and horizontally in the Chatbot. Supports Markdown and HTML. If None, no placeholder is displayed.
|
||||
"""
|
||||
self.likeable = likeable
|
||||
if msg_format not in ["messages", "tuples"]:
|
||||
raise ValueError(
|
||||
"msg_format must be 'messages' or 'tuples', received: {msg_format}"
|
||||
)
|
||||
self.msg_format: Literal["tuples", "messages"] = msg_format
|
||||
if msg_format == "messages":
|
||||
self.data_model = ChatbotDataMessages
|
||||
else:
|
||||
self.data_model = ChatbotDataTuples
|
||||
self.height = height
|
||||
self.rtl = rtl
|
||||
if latex_delimiters is None:
|
||||
@ -189,7 +251,27 @@ class Chatbot(Component):
|
||||
]
|
||||
self.placeholder = placeholder
|
||||
|
||||
def _preprocess_chat_messages(
|
||||
@staticmethod
|
||||
def _check_format(messages: list[Any], msg_format: Literal["messages", "tuples"]):
|
||||
if msg_format == "messages":
|
||||
all_dicts = all(
|
||||
isinstance(message, dict) and "role" in message and "content" in message
|
||||
for message in messages
|
||||
)
|
||||
all_msgs = all(isinstance(msg, ChatMessage) for msg in messages)
|
||||
if not (all_dicts or all_msgs):
|
||||
raise Error(
|
||||
"Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
|
||||
)
|
||||
elif not all(
|
||||
isinstance(message, (tuple, list)) and len(message) == 2
|
||||
for message in messages
|
||||
):
|
||||
raise Error(
|
||||
"Data incompatible with tuples format. Each message should be a list of length 2."
|
||||
)
|
||||
|
||||
def _preprocess_content(
|
||||
self,
|
||||
chat_message: str | FileMessage | ComponentMessage | None,
|
||||
) -> str | GradioComponent | tuple[str | None] | tuple[str | None, str] | None:
|
||||
@ -228,18 +310,9 @@ class Chatbot(Component):
|
||||
else:
|
||||
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
payload: ChatbotData | None,
|
||||
) -> list[list[str | GradioComponent | tuple[str] | tuple[str, str] | None]] | None:
|
||||
"""
|
||||
Parameters:
|
||||
payload: data as a ChatbotData object
|
||||
Returns:
|
||||
Passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed.
|
||||
"""
|
||||
if payload is None:
|
||||
return payload
|
||||
def _preprocess_messages_tuples(
|
||||
self, payload: ChatbotDataTuples
|
||||
) -> list[list[str | tuple[str] | tuple[str, str] | None]]:
|
||||
processed_messages = []
|
||||
for message_pair in payload.root:
|
||||
if not isinstance(message_pair, (tuple, list)):
|
||||
@ -252,29 +325,71 @@ class Chatbot(Component):
|
||||
)
|
||||
processed_messages.append(
|
||||
[
|
||||
self._preprocess_chat_messages(message_pair[0]),
|
||||
self._preprocess_chat_messages(message_pair[1]),
|
||||
self._preprocess_content(message_pair[0]),
|
||||
self._preprocess_content(message_pair[1]),
|
||||
]
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
def _postprocess_chat_messages(
|
||||
self, chat_message: str | tuple | list | GradioComponent | None
|
||||
) -> str | FileMessage | ComponentMessage | None:
|
||||
def create_file_message(chat_message, filepath):
|
||||
mime_type = client_utils.get_mimetype(filepath)
|
||||
return FileMessage(
|
||||
file=FileData(path=filepath, mime_type=mime_type),
|
||||
alt_text=(
|
||||
chat_message[1]
|
||||
if not isinstance(chat_message, GradioComponent)
|
||||
and len(chat_message) > 1
|
||||
else None
|
||||
),
|
||||
)
|
||||
def preprocess(
|
||||
self,
|
||||
payload: ChatbotDataTuples | ChatbotDataMessages | None,
|
||||
) -> (
|
||||
list[list[str | tuple[str] | tuple[str, str] | None]] | list[MessageDict] | None
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
payload: data as a ChatbotData object
|
||||
Returns:
|
||||
If msg_format is 'tuples', passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. If msg_format is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
|
||||
"""
|
||||
if payload is None:
|
||||
return payload
|
||||
if self.msg_format == "tuples":
|
||||
if not isinstance(payload, ChatbotDataTuples):
|
||||
raise Error("Data incompatible with the tuples format")
|
||||
return self._preprocess_messages_tuples(cast(ChatbotDataTuples, payload))
|
||||
if not isinstance(payload, ChatbotDataMessages):
|
||||
raise Error("Data incompatible with the messages format")
|
||||
message_dicts = []
|
||||
for message in payload.root:
|
||||
message_dict = cast(MessageDict, message.model_dump())
|
||||
message_dict["content"] = self._preprocess_content(message.content)
|
||||
message_dicts.append(message_dict)
|
||||
return message_dicts
|
||||
|
||||
@staticmethod
|
||||
def _get_alt_text(chat_message: dict | list | tuple | GradioComponent):
|
||||
if isinstance(chat_message, dict):
|
||||
return chat_message.get("alt_text")
|
||||
elif not isinstance(chat_message, GradioComponent) and len(chat_message) > 1:
|
||||
return chat_message[1]
|
||||
|
||||
@staticmethod
|
||||
def _create_file_message(chat_message, filepath):
|
||||
mime_type = client_utils.get_mimetype(filepath)
|
||||
|
||||
return FileMessage(
|
||||
file=FileData(path=filepath, mime_type=mime_type),
|
||||
alt_text=Chatbot._get_alt_text(chat_message),
|
||||
)
|
||||
|
||||
def _postprocess_content(
|
||||
self,
|
||||
chat_message: str
|
||||
| tuple
|
||||
| list
|
||||
| FileDataDict
|
||||
| FileData
|
||||
| GradioComponent
|
||||
| None,
|
||||
) -> str | FileMessage | ComponentMessage | None:
|
||||
if chat_message is None:
|
||||
return None
|
||||
elif isinstance(chat_message, FileMessage):
|
||||
return chat_message
|
||||
elif isinstance(chat_message, FileData):
|
||||
return FileMessage(file=chat_message)
|
||||
elif isinstance(chat_message, GradioComponent):
|
||||
component = import_component_and_data(type(chat_message).__name__)
|
||||
if component:
|
||||
@ -287,54 +402,110 @@ class Chatbot(Component):
|
||||
constructor_args=chat_message.constructor_args,
|
||||
props=config,
|
||||
)
|
||||
elif isinstance(chat_message, dict) and "path" in chat_message:
|
||||
filepath = chat_message["path"]
|
||||
return self._create_file_message(chat_message, filepath)
|
||||
elif isinstance(chat_message, (tuple, list)):
|
||||
filepath = str(chat_message[0])
|
||||
return create_file_message(chat_message, filepath)
|
||||
return self._create_file_message(chat_message, filepath)
|
||||
elif isinstance(chat_message, str):
|
||||
chat_message = inspect.cleandoc(chat_message)
|
||||
return chat_message
|
||||
else:
|
||||
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
||||
|
||||
def _postprocess_messages_tuples(self, value: TupleFormat) -> ChatbotDataTuples:
|
||||
processed_messages = []
|
||||
for message_pair in value:
|
||||
processed_messages.append(
|
||||
[
|
||||
self._postprocess_content(message_pair[0]),
|
||||
self._postprocess_content(message_pair[1]),
|
||||
]
|
||||
)
|
||||
return ChatbotDataTuples(root=processed_messages)
|
||||
|
||||
def _postprocess_message_messages(
|
||||
self, message: MessageDict | ChatMessage
|
||||
) -> list[Message]:
|
||||
if isinstance(message, dict):
|
||||
message["content"] = self._postprocess_content(message["content"])
|
||||
msg = Message(**message) # type: ignore
|
||||
elif isinstance(message, ChatMessage):
|
||||
message.content = self._postprocess_content(message.content) # type: ignore
|
||||
msg = Message(
|
||||
role=message.role,
|
||||
content=message.content, # type: ignore
|
||||
metadata=message.metadata, # type: ignore
|
||||
)
|
||||
else:
|
||||
raise Error(
|
||||
f"Invalid message for Chatbot component: {message}", visible=False
|
||||
)
|
||||
|
||||
# extract file path from message
|
||||
new_messages = []
|
||||
if isinstance(msg.content, str):
|
||||
for word in msg.content.split(" "):
|
||||
filepath = Path(word)
|
||||
try:
|
||||
is_file = filepath.is_file() and filepath.exists()
|
||||
except OSError:
|
||||
is_file = False
|
||||
if is_file:
|
||||
filepath = cast(
|
||||
str, move_resource_to_block_cache(filepath, block=self)
|
||||
)
|
||||
mime_type = client_utils.get_mimetype(filepath)
|
||||
new_messages.append(
|
||||
Message(
|
||||
role=msg.role,
|
||||
metadata=msg.metadata,
|
||||
content=FileMessage(
|
||||
file=FileData(path=filepath, mime_type=mime_type)
|
||||
),
|
||||
),
|
||||
)
|
||||
return [msg, *new_messages]
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
value: (
|
||||
list[
|
||||
list[str | GradioComponent | tuple[str] | tuple[str, str] | None]
|
||||
| tuple
|
||||
]
|
||||
| None
|
||||
),
|
||||
) -> ChatbotData:
|
||||
value: TupleFormat | list[MessageDict | Message] | None,
|
||||
) -> ChatbotDataTuples | ChatbotDataMessages:
|
||||
"""
|
||||
Parameters:
|
||||
value: expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed.
|
||||
value: If msg_format is `tuples`, expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If msg_format is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
|
||||
Returns:
|
||||
an object of type ChatbotData
|
||||
"""
|
||||
data_model = cast(
|
||||
Union[Type[ChatbotDataTuples], Type[ChatbotDataMessages]], self.data_model
|
||||
)
|
||||
if value is None:
|
||||
return ChatbotData(root=[])
|
||||
|
||||
processed_messages = []
|
||||
for message_pair in value:
|
||||
if not isinstance(message_pair, (tuple, list)):
|
||||
raise TypeError(
|
||||
f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
)
|
||||
if len(message_pair) != 2:
|
||||
raise TypeError(
|
||||
f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
)
|
||||
processed_messages.append(
|
||||
[
|
||||
self._postprocess_chat_messages(message_pair[0]),
|
||||
self._postprocess_chat_messages(message_pair[1]),
|
||||
]
|
||||
)
|
||||
return ChatbotData(root=processed_messages)
|
||||
return data_model(root=[])
|
||||
if self.msg_format == "tuples":
|
||||
self._check_format(value, "tuples")
|
||||
return self._postprocess_messages_tuples(cast(TupleFormat, value))
|
||||
self._check_format(value, "messages")
|
||||
processed_messages = [
|
||||
msg
|
||||
for message in value
|
||||
for msg in self._postprocess_message_messages(cast(MessageDict, message))
|
||||
]
|
||||
return ChatbotDataMessages(root=processed_messages)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
if self.msg_format == "messages":
|
||||
return [
|
||||
Message(role="user", content="Hello!").model_dump(),
|
||||
Message(role="assistant", content="How can I help you?").model_dump(),
|
||||
]
|
||||
return [["Hello!", None]]
|
||||
|
||||
def example_value(self) -> Any:
|
||||
if self.msg_format == "messages":
|
||||
return [
|
||||
Message(role="user", content="Hello!").model_dump(),
|
||||
Message(role="assistant", content="How can I help you?").model_dump(),
|
||||
]
|
||||
return [["Hello!", None]]
|
||||
|
@ -229,7 +229,7 @@ class Interface(Blocks):
|
||||
state_output_index = state_output_indexes[0]
|
||||
if inputs[state_input_index] == "state":
|
||||
default = utils.get_default_args(fn)[state_input_index]
|
||||
state_variable = State(value=default) # type: ignore
|
||||
state_variable = State(value=default)
|
||||
else:
|
||||
state_variable = inputs[state_input_index]
|
||||
|
||||
@ -244,12 +244,10 @@ class Interface(Blocks):
|
||||
self.cache_examples = False
|
||||
|
||||
self.main_input_components = [
|
||||
get_component_instance(i, unrender=True)
|
||||
for i in inputs # type: ignore
|
||||
get_component_instance(i, unrender=True) for i in inputs
|
||||
]
|
||||
self.additional_input_components = [
|
||||
get_component_instance(i, unrender=True)
|
||||
for i in additional_inputs # type: ignore
|
||||
get_component_instance(i, unrender=True) for i in additional_inputs
|
||||
]
|
||||
if additional_inputs_accordion is None:
|
||||
self.additional_inputs_accordion_params = {
|
||||
|
@ -36,27 +36,133 @@ demo.launch()
|
||||
|
||||
<!-- Behavior -->
|
||||
### Behavior
|
||||
|
||||
The data format accepted by the Chatbot is dictated by the `msg_format` parameter.
|
||||
This parameter can take two values, `'tuples'` and `'messages'`.
|
||||
|
||||
|
||||
If `msg_format` is `'tuples'`, then the data sent to/from the chatbot will be a list of tuples.
|
||||
The first element of each tuple is the user message and the second element is the bot's response.
|
||||
Each element can be a string (markdown/html is supported),
|
||||
a tuple (in which case the first element is a filepath that will be displayed in the chatbot),
|
||||
or a gradio component (see the Examples section for more details).
|
||||
|
||||
|
||||
If the `msg_format` is `'messages'`, then the data sent to/from the chatbot will be a list of dictionaries
|
||||
with `role` and `content` keys. This format is compliant with the format expected by most LLM APIs (HuggingChat, OpenAI, Claude).
|
||||
The `role` key is either `'user'` or `'`assistant'` and the `content` key can be a string (markdown/html supported),
|
||||
a `FileDataDict` (to represent a file that is displayed in the chatbot - documented below), or a gradio component.
|
||||
|
||||
|
||||
For convenience, you can use the `ChatMessage` dataclass so that your text editor can give you autocomplete hints and typechecks.
|
||||
|
||||
```python
|
||||
from gradio import ChatMessage
|
||||
|
||||
def generate_response(history):
|
||||
history.append(
|
||||
ChatMessage(role="assistant",
|
||||
content="How can I help you?")
|
||||
)
|
||||
return history
|
||||
```
|
||||
|
||||
Additionally, when `msg_format` is `messages`, you can provide additional metadata regarding any tools used to generate the response.
|
||||
This is useful for displaying the thought process of LLM agents. For example,
|
||||
|
||||
```python
|
||||
def generate_response(history):
|
||||
history.append(
|
||||
ChatMessage(role="assistant",
|
||||
content="The weather API says it is 20 degrees Celcius in New York.",
|
||||
metadata={"title": "🛠️ Used tool Weather API"})
|
||||
)
|
||||
return history
|
||||
```
|
||||
|
||||
Would be displayed as following:
|
||||
|
||||
<img src="https://github.com/freddyaboulton/freddyboulton/assets/41651716/a4bb2b0a-5f8a-4287-814b-4eab278e021e" alt="Gradio chatbot tool display">
|
||||
|
||||
|
||||
All of the types expected by the messages format are documented below:
|
||||
|
||||
```python
|
||||
class MetadataDict(TypedDict):
|
||||
title: Union[str, None]
|
||||
|
||||
class FileDataDict(TypedDict):
|
||||
path: str # server filepath
|
||||
url: NotRequired[Optional[str]] # normalised server url
|
||||
size: NotRequired[Optional[int]] # size in bytes
|
||||
orig_name: NotRequired[Optional[str]] # original filename
|
||||
mime_type: NotRequired[Optional[str]]
|
||||
is_stream: NotRequired[bool]
|
||||
meta: dict[Literal["_type"], Literal["gradio.FileData"]]
|
||||
|
||||
|
||||
class MessageDict(TypedDict):
|
||||
content: str | FileDataDict | Component
|
||||
role: Literal["user", "assistant", "system"]
|
||||
metadata: NotRequired[MetadataDict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Metadata:
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str | FileData | Component | FileDataDict | tuple | list
|
||||
metadata: MetadataDict | Metadata = field(default_factory=Metadata)
|
||||
```
|
||||
|
||||
|
||||
## **As input component**: {@html style_formatted_text(obj.preprocess.return_doc.doc)}
|
||||
##### Your function should accept one of these types:
|
||||
|
||||
If `msg_format` is `tuples` -
|
||||
|
||||
```python
|
||||
from gradio import Component
|
||||
|
||||
def predict(
|
||||
value: list[list[str | tuple[str] | tuple[str, str] | None]] | None
|
||||
)
|
||||
value: list[list[str | tuple[str, str] | Component | None]] | None
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
If `msg_format` is `messages` -
|
||||
|
||||
```python
|
||||
from gradio import MessageDict
|
||||
|
||||
def predict(value: list[MessageDict] | None):
|
||||
...
|
||||
```
|
||||
<br>
|
||||
|
||||
## **As output component**: {@html style_formatted_text(obj.postprocess.parameter_doc[0].doc)}
|
||||
##### Your function should return one of these types:
|
||||
|
||||
If `msg_format` is `tuples` -
|
||||
|
||||
```python
|
||||
def predict(···) -> list[list[str | tuple[str] | tuple[str, str] | None] | tuple] | None
|
||||
...
|
||||
return value
|
||||
```
|
||||
|
||||
If `msg_format` is `messages` -
|
||||
|
||||
from gradio import ChatMessage, MessageDict
|
||||
|
||||
```python
|
||||
def predict(···) - > list[MessageDict] | list[ChatMessage]:
|
||||
...
|
||||
```
|
||||
|
||||
<!--- Initialization -->
|
||||
### Initialization
|
||||
|
65
js/app/test/chatbot_core_components_simple.spec.ts
Normal file
65
js/app/test/chatbot_core_components_simple.spec.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import { test, expect, go_to_testcase } from "@gradio/tootils";
|
||||
|
||||
for (const msg_format of ["tuples", "messages"]) {
|
||||
test(`message format ${msg_format} - Gallery component properly displayed`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
await page.getByTestId("gallery-radio-label").click();
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.getByTestId("textbox").fill("gallery");
|
||||
await page.keyboard.press("Enter");
|
||||
await expect(
|
||||
page.getByLabel("Thumbnail 1 of 2").locator("img")
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
page.getByLabel("Thumbnail 2 of 2").locator("img")
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test(`message format ${msg_format} - Audio component properly displayed`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
await page.getByTestId("audio-radio-label").click();
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.getByTestId("textbox").fill("audio");
|
||||
await page.keyboard.press("Enter");
|
||||
await expect(
|
||||
page.getByTestId("unlabelled-audio").locator("audio")
|
||||
).toBeAttached();
|
||||
});
|
||||
|
||||
test(`message format ${msg_format} - Video component properly displayed`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
await page.getByTestId("video-radio-label").click();
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.getByTestId("textbox").fill("video");
|
||||
await page.keyboard.press("Enter");
|
||||
await expect(page.getByTestId("test-player")).toBeAttached();
|
||||
await expect(
|
||||
page.getByTestId("test-player").getAttribute("src")
|
||||
).toBeTruthy();
|
||||
});
|
||||
|
||||
test(`message format ${msg_format} - Image component properly displayed`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
await page.getByTestId("image-radio-label").click();
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.getByTestId("textbox").fill("image");
|
||||
await page.keyboard.press("Enter");
|
||||
await expect(page.getByTestId("bot").locator("img")).toBeAttached();
|
||||
});
|
||||
}
|
@ -1,221 +1,268 @@
|
||||
import { test, expect } from "@gradio/tootils";
|
||||
import { test, expect, go_to_testcase } from "@gradio/tootils";
|
||||
|
||||
test("text input by a user should be shown in the chatbot as a paragraph", async ({
|
||||
page
|
||||
}) => {
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("Lorem ipsum");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toEqual("Lorem ipsum");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
for (const msg_format of ["tuples", "messages"]) {
|
||||
test(`message format ${msg_format} - text input by a user should be shown in the chatbot as a paragraph`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("Lorem ipsum");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toEqual("Lorem ipsum");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("images uploaded by a user should be shown in the chat", async ({
|
||||
page
|
||||
}) => {
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles("./test/files/cheetah1.jpg");
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.keyboard.press("Enter");
|
||||
test(`message format ${msg_format} - images uploaded by a user should be shown in the chat`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles("./test/files/cheetah1.jpg");
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.keyboard.press("Enter");
|
||||
|
||||
const user_message_locator = await page.getByTestId("user").first();
|
||||
const user_message = await user_message_locator.elementHandle();
|
||||
if (user_message) {
|
||||
const imageContainer = await user_message.$("div.image-container");
|
||||
const user_message_locator = await page.getByTestId("user").first();
|
||||
const user_message = await user_message_locator.elementHandle();
|
||||
if (user_message) {
|
||||
const imageContainer = await user_message.$("div.image-container");
|
||||
|
||||
if (imageContainer) {
|
||||
const imgElement = await imageContainer.$("img");
|
||||
if (imgElement) {
|
||||
const image_src = await imgElement.getAttribute("src");
|
||||
expect(image_src).toBeTruthy();
|
||||
if (imageContainer) {
|
||||
const imgElement = await imageContainer.$("img");
|
||||
if (imgElement) {
|
||||
const image_src = await imgElement.getAttribute("src");
|
||||
expect(image_src).toBeTruthy();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
expect(bot_message).toBeTruthy();
|
||||
});
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
|
||||
test("Users can upload multiple images and they will be shown as thumbnails", async ({
|
||||
page
|
||||
}) => {
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles([
|
||||
"./test/files/cheetah1.jpg",
|
||||
"./test/files/cheetah1.jpg"
|
||||
]);
|
||||
expect
|
||||
.poll(async () => await page.locator("thumbnail-image").count(), {
|
||||
timeout: 5000
|
||||
})
|
||||
.toEqual(2);
|
||||
});
|
||||
expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("audio uploaded by a user should be shown in the chatbot", async ({
|
||||
page
|
||||
}) => {
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles("../../test/test_files/audio_sample.wav");
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.keyboard.press("Enter");
|
||||
test(`message format ${msg_format} - audio uploaded by a user should be shown in the chatbot`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles("../../test/test_files/audio_sample.wav");
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.keyboard.press("Enter");
|
||||
|
||||
const user_message = await page.getByTestId("user").first().locator("audio");
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
const audio_data = await user_message.getAttribute("src");
|
||||
await expect(audio_data).toBeTruthy();
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.locator("audio");
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
const audio_data = await user_message.getAttribute("src");
|
||||
await expect(audio_data).toBeTruthy();
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("videos uploaded by a user should be shown in the chatbot", async ({
|
||||
page
|
||||
}) => {
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles("../../test/test_files/video_sample.mp4");
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.keyboard.press("Enter");
|
||||
test(`message format ${msg_format} - videos uploaded by a user should be shown in the chatbot`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles("../../test/test_files/video_sample.mp4");
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.keyboard.press("Enter");
|
||||
|
||||
const user_message = await page.getByTestId("user").first().locator("video");
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
const video_data = await user_message.getAttribute("src");
|
||||
await expect(video_data).toBeTruthy();
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.locator("video");
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
const video_data = await user_message.getAttribute("src");
|
||||
await expect(video_data).toBeTruthy();
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("markdown input by a user should be correctly formatted: bold, italics, links", async ({
|
||||
page
|
||||
}) => {
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill(
|
||||
"This is **bold text**. This is *italic text*. This is a [link](https://gradio.app)."
|
||||
);
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toEqual(
|
||||
'This is <strong>bold text</strong>. This is <em>italic text</em>. This is a <a href="https://gradio.app" target="_blank" rel="noopener noreferrer">link</a>.'
|
||||
);
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
test(`message format ${msg_format} - markdown input by a user should be correctly formatted: bold, italics, links`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill(
|
||||
"This is **bold text**. This is *italic text*. This is a [link](https://gradio.app)."
|
||||
);
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toEqual(
|
||||
'This is <strong>bold text</strong>. This is <em>italic text</em>. This is a <a href="https://gradio.app" target="_blank" rel="noopener noreferrer">link</a>.'
|
||||
);
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("inline code markdown input by the user should be correctly formatted", async ({
|
||||
page
|
||||
}) => {
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("This is `code`.");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toEqual("This is <code>code</code>.");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
test(`message format ${msg_format} - inline code markdown input by the user should be correctly formatted`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("This is `code`.");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toEqual("This is <code>code</code>.");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("markdown code blocks input by a user should be rendered correctly with the correct language tag", async ({
|
||||
page
|
||||
}) => {
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("```python\nprint('Hello')\nprint('World!')\n```");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.locator("pre")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toContain("language-python");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
test(`message format ${msg_format} - markdown code blocks input by a user should be rendered correctly with the correct language tag`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("```python\nprint('Hello')\nprint('World!')\n```");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.locator("pre")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toContain("language-python");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("LaTeX input by a user should be rendered correctly", async ({ page }) => {
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("This is LaTeX $$x^2$$");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toContain("katex-display");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
test(`message format ${msg_format} - LaTeX input by a user should be rendered correctly`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
await textbox.fill("This is LaTeX $$x^2$$");
|
||||
await page.keyboard.press("Enter");
|
||||
const user_message = await page
|
||||
.getByTestId("user")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.innerHTML();
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph")
|
||||
.textContent();
|
||||
await expect(user_message).toContain("katex-display");
|
||||
await expect(bot_message).toBeTruthy();
|
||||
});
|
||||
|
||||
test("when a new message is sent the chatbot should scroll to the latest message", async ({
|
||||
page
|
||||
}) => {
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
const line_break = "<br>";
|
||||
await textbox.fill(line_break.repeat(30));
|
||||
await page.keyboard.press("Enter");
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph");
|
||||
await expect(bot_message).toBeVisible();
|
||||
const bot_message_text = bot_message.textContent();
|
||||
await expect(bot_message_text).toBeTruthy();
|
||||
});
|
||||
test(`message format ${msg_format} - when a new message is sent the chatbot should scroll to the latest message`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = await page.getByTestId("textbox");
|
||||
const line_break = "<br>";
|
||||
await textbox.fill(line_break.repeat(30));
|
||||
await page.keyboard.press("Enter");
|
||||
const bot_message = await page
|
||||
.getByTestId("bot")
|
||||
.first()
|
||||
.getByRole("paragraph");
|
||||
await expect(bot_message).toBeVisible();
|
||||
const bot_message_text = bot_message.textContent();
|
||||
await expect(bot_message_text).toBeTruthy();
|
||||
});
|
||||
|
||||
test("chatbot like and dislike functionality", async ({ page }) => {
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.getByTestId("textbox").fill("hello");
|
||||
await page.keyboard.press("Enter");
|
||||
await page.getByLabel("like", { exact: true }).click();
|
||||
await page.getByLabel("dislike").click();
|
||||
test(`message format ${msg_format} - chatbot like and dislike functionality`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
await page.getByTestId("textbox").click();
|
||||
await page.getByTestId("textbox").fill("hello");
|
||||
await page.keyboard.press("Enter");
|
||||
await page.getByLabel("like", { exact: true }).click();
|
||||
await page.getByLabel("dislike").click();
|
||||
|
||||
expect(await page.getByLabel("clicked dislike").count()).toEqual(1);
|
||||
expect(await page.getByLabel("clicked like").count()).toEqual(0);
|
||||
});
|
||||
expect(await page.getByLabel("clicked dislike").count()).toEqual(1);
|
||||
expect(await page.getByLabel("clicked like").count()).toEqual(0);
|
||||
});
|
||||
|
||||
test(`message format ${msg_format} - Users can upload multiple images and they will be shown as thumbnails`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
|
||||
const fileChooserPromise = page.waitForEvent("filechooser");
|
||||
await page.getByTestId("upload-button").click();
|
||||
const fileChooser = await fileChooserPromise;
|
||||
await fileChooser.setFiles([
|
||||
"./test/files/cheetah1.jpg",
|
||||
"./test/files/cheetah1.jpg"
|
||||
]);
|
||||
expect
|
||||
.poll(async () => await page.locator("thumbnail-image").count(), {
|
||||
timeout: 5000
|
||||
})
|
||||
.toEqual(2);
|
||||
});
|
||||
}
|
||||
|
16
js/app/test/chatbot_with_tools.spec.ts
Normal file
16
js/app/test/chatbot_with_tools.spec.ts
Normal file
@ -0,0 +1,16 @@
|
||||
import { test, expect } from "@gradio/tootils";
|
||||
|
||||
test("Chatbot can support agentic demos by displaying messages with metadata", async ({
|
||||
page
|
||||
}) => {
|
||||
await page.getByRole("button", { name: "Get San Francisco Weather" }).click();
|
||||
await expect(
|
||||
await page.locator("button").filter({ hasText: "💥 Error" }).nth(1)
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
page.locator("span").filter({ hasText: "🛠️ Used tool" })
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
page.locator("button").filter({ hasText: "It's a sunny day in San" })
|
||||
).toBeVisible();
|
||||
});
|
@ -1,82 +1,95 @@
|
||||
import { test, expect } from "@gradio/tootils";
|
||||
import { test, expect, go_to_testcase } from "@gradio/tootils";
|
||||
|
||||
test("chatinterface works with streaming functions and all buttons behave as expected", async ({
|
||||
page
|
||||
}) => {
|
||||
const submit_button = page.getByRole("button", { name: "Submit" });
|
||||
const retry_button = page.getByRole("button", { name: "🔄 Retry" });
|
||||
const undo_button = page.getByRole("button", { name: "↩️ Undo" });
|
||||
const clear_button = page.getByRole("button", { name: "🗑️ Clear" });
|
||||
const textbox = page.getByPlaceholder("Type a message...");
|
||||
for (const msg_format of ["tuples", "messages"]) {
|
||||
test(`msg format ${msg_format} chatinterface works with streaming functions and all buttons behave as expected`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const submit_button = page.getByRole("button", { name: "Submit" });
|
||||
const retry_button = page.getByRole("button", { name: "🔄 Retry" });
|
||||
const undo_button = page.getByRole("button", { name: "↩️ Undo" });
|
||||
const clear_button = page.getByRole("button", { name: "🗑️ Clear" });
|
||||
const textbox = page.getByPlaceholder("Type a message...");
|
||||
|
||||
await textbox.fill("hello");
|
||||
await submit_button.click();
|
||||
await textbox.fill("hello");
|
||||
await submit_button.click();
|
||||
|
||||
await expect(textbox).toHaveValue("");
|
||||
const expected_text_el_0 = page.locator(".bot p", {
|
||||
hasText: "Run 1 - You typed: hello"
|
||||
await expect(textbox).toHaveValue("");
|
||||
const expected_text_el_0 = page.locator(".bot p", {
|
||||
hasText: "Run 1 - You typed: hello"
|
||||
});
|
||||
await expect(expected_text_el_0).toBeVisible();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
|
||||
.toBe(1);
|
||||
|
||||
await textbox.fill("hi");
|
||||
await submit_button.click();
|
||||
await expect(textbox).toHaveValue("");
|
||||
const expected_text_el_1 = page.locator(".bot p", {
|
||||
hasText: "Run 2 - You typed: hi"
|
||||
});
|
||||
await expect(expected_text_el_1).toBeVisible();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
|
||||
.toBe(2);
|
||||
|
||||
await undo_button.click();
|
||||
await expect
|
||||
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
|
||||
.toBe(1);
|
||||
await expect(textbox).toHaveValue("hi");
|
||||
|
||||
await retry_button.click();
|
||||
const expected_text_el_2 = page.locator(".bot p", {
|
||||
hasText: ""
|
||||
});
|
||||
await expect(expected_text_el_2).toBeVisible();
|
||||
|
||||
await expect
|
||||
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
|
||||
.toBe(1);
|
||||
|
||||
await textbox.fill("hi");
|
||||
await submit_button.click();
|
||||
await expect(textbox).toHaveValue("");
|
||||
const expected_text_el_3 = page.locator(".bot p", {
|
||||
hasText: "Run 4 - You typed: hi"
|
||||
});
|
||||
await expect(expected_text_el_3).toBeVisible();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
|
||||
.toBe(2);
|
||||
|
||||
await clear_button.click();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 5000 })
|
||||
.toBe(0);
|
||||
});
|
||||
await expect(expected_text_el_0).toBeVisible();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
|
||||
.toBe(1);
|
||||
|
||||
await textbox.fill("hi");
|
||||
await submit_button.click();
|
||||
await expect(textbox).toHaveValue("");
|
||||
const expected_text_el_1 = page.locator(".bot p", {
|
||||
hasText: "Run 2 - You typed: hi"
|
||||
test(`msg format ${msg_format} the api recorder correctly records the api calls`, async ({
|
||||
page
|
||||
}) => {
|
||||
if (msg_format === "messages") {
|
||||
await go_to_testcase(page, "messages");
|
||||
}
|
||||
const textbox = page.getByPlaceholder("Type a message...");
|
||||
const submit_button = page.getByRole("button", { name: "Submit" });
|
||||
await textbox.fill("hi");
|
||||
|
||||
await page.getByRole("button", { name: "Use via API logo" }).click();
|
||||
await page.locator("#start-api-recorder").click();
|
||||
await submit_button.click();
|
||||
await expect(textbox).toHaveValue("");
|
||||
await expect(page.locator(".bot p").first()).toContainText(
|
||||
/\- You typed: hi/
|
||||
);
|
||||
const api_recorder = await page.locator("#api-recorder");
|
||||
await api_recorder.click();
|
||||
await expect(page.locator("#num-recorded-api-calls")).toContainText(
|
||||
"🪄 Recorded API Calls [5]"
|
||||
);
|
||||
});
|
||||
await expect(expected_text_el_1).toBeVisible();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
|
||||
.toBe(2);
|
||||
|
||||
await undo_button.click();
|
||||
await expect
|
||||
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
|
||||
.toBe(1);
|
||||
await expect(textbox).toHaveValue("hi");
|
||||
|
||||
await retry_button.click();
|
||||
const expected_text_el_2 = page.locator(".bot p", {
|
||||
hasText: ""
|
||||
});
|
||||
await expect(expected_text_el_2).toBeVisible();
|
||||
|
||||
await expect
|
||||
.poll(async () => page.locator(".message.bot").count(), { timeout: 5000 })
|
||||
.toBe(1);
|
||||
|
||||
await textbox.fill("hi");
|
||||
await submit_button.click();
|
||||
await expect(textbox).toHaveValue("");
|
||||
const expected_text_el_3 = page.locator(".bot p", {
|
||||
hasText: "Run 4 - You typed: hi"
|
||||
});
|
||||
await expect(expected_text_el_3).toBeVisible();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 2000 })
|
||||
.toBe(2);
|
||||
|
||||
await clear_button.click();
|
||||
await expect
|
||||
.poll(async () => page.locator(".bot.message").count(), { timeout: 5000 })
|
||||
.toBe(0);
|
||||
});
|
||||
|
||||
test("the api recorder correctly records the api calls", async ({ page }) => {
|
||||
const textbox = page.getByPlaceholder("Type a message...");
|
||||
const submit_button = page.getByRole("button", { name: "Submit" });
|
||||
await textbox.fill("hi");
|
||||
|
||||
await page.getByRole("button", { name: "Use via API logo" }).click();
|
||||
await page.locator("#start-api-recorder").click();
|
||||
await submit_button.click();
|
||||
await expect(textbox).toHaveValue("");
|
||||
const api_recorder = await page.locator("#api-recorder");
|
||||
await api_recorder.click();
|
||||
|
||||
const num_calls = await page.locator("#num-recorded-api-calls").innerText();
|
||||
await expect(num_calls).toBe("🪄 Recorded API Calls [5]");
|
||||
});
|
||||
}
|
||||
|
@ -11,17 +11,19 @@
|
||||
import { Chat } from "@gradio/icons";
|
||||
import type { FileData } from "@gradio/client";
|
||||
import { StatusTracker } from "@gradio/statustracker";
|
||||
import type {
|
||||
Message,
|
||||
TupleFormat,
|
||||
MessageRole,
|
||||
NormalisedMessage
|
||||
} from "./types";
|
||||
|
||||
import {
|
||||
type messages,
|
||||
type NormalisedMessage,
|
||||
normalise_messages
|
||||
} from "./shared/utils";
|
||||
import { normalise_tuples, normalise_messages } from "./shared/utils";
|
||||
|
||||
export let elem_id = "";
|
||||
export let elem_classes: string[] = [];
|
||||
export let visible = true;
|
||||
export let value: messages = [];
|
||||
export let value: TupleFormat | Message[] = [];
|
||||
export let scale: number | null = null;
|
||||
export let min_width: number | undefined = undefined;
|
||||
export let label: string;
|
||||
@ -35,6 +37,7 @@
|
||||
export let sanitize_html = true;
|
||||
export let bubble_full_width = true;
|
||||
export let layout: "bubble" | "panel" = "bubble";
|
||||
export let msg_format: "tuples" | "messages" = "tuples";
|
||||
export let render_markdown = true;
|
||||
export let line_breaks = true;
|
||||
export let latex_delimiters: {
|
||||
@ -52,9 +55,12 @@
|
||||
}>;
|
||||
export let avatar_images: [FileData | null, FileData | null] = [null, null];
|
||||
|
||||
let _value: [NormalisedMessage, NormalisedMessage][] | null = [];
|
||||
let _value: NormalisedMessage[] | null = [];
|
||||
|
||||
$: _value = normalise_messages(value, root);
|
||||
$: _value =
|
||||
msg_format === "tuples"
|
||||
? normalise_tuples(value as TupleFormat, root)
|
||||
: normalise_messages(value as Message[], root);
|
||||
|
||||
export let loading_status: LoadingStatus | undefined = undefined;
|
||||
export let height = 400;
|
||||
|
@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import { format_chat_for_sharing, type NormalisedMessage } from "./utils";
|
||||
import { format_chat_for_sharing } from "./utils";
|
||||
import type { NormalisedMessage } from "../types";
|
||||
import { Gradio, copy } from "@gradio/utils";
|
||||
|
||||
import { dequal } from "dequal/lite";
|
||||
@ -16,10 +17,16 @@
|
||||
|
||||
import { Clear } from "@gradio/icons";
|
||||
import type { SelectData, LikeData } from "@gradio/utils";
|
||||
import type { MessageRole, ComponentMessage, ComponentData } from "../types";
|
||||
import { MarkdownCode as Markdown } from "@gradio/markdown";
|
||||
import { type FileData, type Client } from "@gradio/client";
|
||||
import type { FileData, Client } from "@gradio/client";
|
||||
import type { I18nFormatter } from "js/app/src/gradio_helper";
|
||||
import Pending from "./Pending.svelte";
|
||||
import MessageBox from "./MessageBox.svelte";
|
||||
|
||||
export let value: NormalisedMessage[] | null = [];
|
||||
let old_value: NormalisedMessage[] | null = null;
|
||||
|
||||
import Component from "./Component.svelte";
|
||||
import LikeButtons from "./ButtonPanel.svelte";
|
||||
import type { LoadedComponent } from "../../app/src/types";
|
||||
@ -51,27 +58,19 @@
|
||||
|
||||
$: load_components(get_components_from_messages(value));
|
||||
|
||||
function get_components_from_messages(messages: typeof value): string[] {
|
||||
function get_components_from_messages(
|
||||
messages: NormalisedMessage[] | null
|
||||
): string[] {
|
||||
if (!messages) return [];
|
||||
let components: Set<string> = new Set();
|
||||
messages.forEach((message_pair) => {
|
||||
message_pair.forEach((message) => {
|
||||
if (
|
||||
typeof message === "object" &&
|
||||
message !== null &&
|
||||
"component" in message
|
||||
) {
|
||||
components.add(message.component);
|
||||
}
|
||||
});
|
||||
messages.forEach((message) => {
|
||||
if (message.type === "component") {
|
||||
components.add(message.content.component);
|
||||
}
|
||||
});
|
||||
|
||||
return Array.from(components);
|
||||
}
|
||||
|
||||
export let value: [NormalisedMessage, NormalisedMessage][] | null = [];
|
||||
let old_value: [NormalisedMessage, NormalisedMessage][] | null = null;
|
||||
|
||||
export let latex_delimiters: {
|
||||
left: string;
|
||||
right: string;
|
||||
@ -173,45 +172,76 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handle_select(
|
||||
i: number,
|
||||
j: number,
|
||||
message: NormalisedMessage
|
||||
): void {
|
||||
function handle_select(i: number, message: NormalisedMessage): void {
|
||||
dispatch("select", {
|
||||
index: [i, j],
|
||||
value: message
|
||||
index: message.index,
|
||||
value: message.content
|
||||
});
|
||||
}
|
||||
|
||||
function handle_like(
|
||||
i: number,
|
||||
j: number,
|
||||
message: NormalisedMessage,
|
||||
selected: string | null
|
||||
): void {
|
||||
dispatch("like", {
|
||||
index: [i, j],
|
||||
value: message,
|
||||
index: message.index,
|
||||
value: message.content,
|
||||
liked: selected === "like"
|
||||
});
|
||||
}
|
||||
|
||||
function get_message_label_data(message: NormalisedMessage): string {
|
||||
if (message.type === "text") {
|
||||
return message.value;
|
||||
} else if (message.type === "component") {
|
||||
return `a component of type ${message.component}`;
|
||||
} else if (message.type === "file") {
|
||||
if (Array.isArray(message.file)) {
|
||||
return `file of extension type: ${message.file[0].orig_name?.split(".").pop()}`;
|
||||
return message.content;
|
||||
} else if (
|
||||
message.type === "component" &&
|
||||
message.content.component === "file"
|
||||
) {
|
||||
if (Array.isArray(message.content.value)) {
|
||||
return `file of extension type: ${message.content.value[0].orig_name?.split(".").pop()}`;
|
||||
}
|
||||
return (
|
||||
`file of extension type: ${message.file?.orig_name?.split(".").pop()}` +
|
||||
(message.file?.orig_name ?? "")
|
||||
`file of extension type: ${message.content.value?.orig_name?.split(".").pop()}` +
|
||||
(message.content.value?.orig_name ?? "")
|
||||
);
|
||||
}
|
||||
return `a message of type ` + message.type ?? "unknown";
|
||||
return `a component of type ${message.content.component ?? "unknown"}`;
|
||||
}
|
||||
|
||||
function is_component_message(
|
||||
message: NormalisedMessage
|
||||
): message is ComponentMessage {
|
||||
return message.type === "component";
|
||||
}
|
||||
|
||||
function group_messages(
|
||||
messages: NormalisedMessage[]
|
||||
): NormalisedMessage[][] {
|
||||
const groupedMessages: NormalisedMessage[][] = [];
|
||||
let currentGroup: NormalisedMessage[] = [];
|
||||
let currentRole: MessageRole | null = null;
|
||||
|
||||
for (const message of messages) {
|
||||
if (!(message.role === "assistant" || message.role === "user")) {
|
||||
continue;
|
||||
}
|
||||
if (message.role === currentRole) {
|
||||
currentGroup.push(message);
|
||||
} else {
|
||||
if (currentGroup.length > 0) {
|
||||
groupedMessages.push(currentGroup);
|
||||
}
|
||||
currentGroup = [message];
|
||||
currentRole = message.role;
|
||||
}
|
||||
}
|
||||
|
||||
if (currentGroup.length > 0) {
|
||||
groupedMessages.push(currentGroup);
|
||||
}
|
||||
|
||||
return groupedMessages;
|
||||
}
|
||||
</script>
|
||||
|
||||
@ -237,132 +267,137 @@
|
||||
>
|
||||
<div class="message-wrap" use:copy>
|
||||
{#if value !== null && value.length > 0}
|
||||
{#each value as message_pair, i}
|
||||
{#each message_pair as message, j}
|
||||
{#if message.type !== "empty"}
|
||||
{#if is_image_preview_open}
|
||||
<div class="image-preview">
|
||||
<img
|
||||
src={image_preview_source}
|
||||
alt={image_preview_source_alt}
|
||||
/>
|
||||
<button
|
||||
class="image-preview-close-button"
|
||||
on:click={() => {
|
||||
is_image_preview_open = false;
|
||||
}}><Clear /></button
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
<div
|
||||
class="message-row {layout} {j == 0 ? 'user-row' : 'bot-row'}"
|
||||
class:with_avatar={avatar_images[j] !== null}
|
||||
class:with_opposite_avatar={avatar_images[j === 0 ? 1 : 0] !==
|
||||
null}
|
||||
{@const groupedMessages = group_messages(value)}
|
||||
{#each groupedMessages as messages, i}
|
||||
{@const role = messages[0].role === "user" ? "user" : "bot"}
|
||||
{@const avatar_img = avatar_images[role === "user" ? 0 : 1]}
|
||||
{@const opposite_avatar_img = avatar_images[role === "user" ? 0 : 1]}
|
||||
{#if is_image_preview_open}
|
||||
<div class="image-preview">
|
||||
<img src={image_preview_source} alt={image_preview_source_alt} />
|
||||
<button
|
||||
class="image-preview-close-button"
|
||||
on:click={() => {
|
||||
is_image_preview_open = false;
|
||||
}}><Clear /></button
|
||||
>
|
||||
{#if avatar_images[j] !== null}
|
||||
<div class="avatar-container">
|
||||
<Image
|
||||
class="avatar-image"
|
||||
src={avatar_images[j]?.url}
|
||||
alt="{j == 0 ? 'user' : 'bot'} avatar"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
<div class="flex-wrap">
|
||||
<div
|
||||
class="message {j == 0 ? 'user' : 'bot'} {typeof message ===
|
||||
'object' &&
|
||||
message !== null &&
|
||||
'component' in message
|
||||
? message?.component
|
||||
: ''}"
|
||||
class:message-fit={!bubble_full_width}
|
||||
class:panel-full-width={true}
|
||||
</div>
|
||||
{/if}
|
||||
<div
|
||||
class="message-row {layout} {role}-row"
|
||||
class:with_avatar={avatar_img !== null}
|
||||
class:with_opposite_avatar={opposite_avatar_img !== null}
|
||||
>
|
||||
{#if avatar_img !== null}
|
||||
<div class="avatar-container">
|
||||
<Image
|
||||
class="avatar-image"
|
||||
src={avatar_img?.url}
|
||||
alt="{role} avatar"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
<div class="flex-wrap">
|
||||
{#each messages as message, thought_index}
|
||||
{@const msg_type = messages[0].type}
|
||||
<div
|
||||
class="message {role} {is_component_message(message)
|
||||
? message?.content.component
|
||||
: ''}"
|
||||
class:message-fit={!bubble_full_width}
|
||||
class:panel-full-width={true}
|
||||
class:message-markdown-disabled={!render_markdown}
|
||||
style:text-align={rtl && role === "user" ? "left" : "right"}
|
||||
class:component={msg_type === "component"}
|
||||
class:html={is_component_message(message) &&
|
||||
message.content.component === "html"}
|
||||
class:thought={thought_index > 0}
|
||||
>
|
||||
<button
|
||||
data-testid={role}
|
||||
class:latest={i === value.length - 1}
|
||||
class:message-markdown-disabled={!render_markdown}
|
||||
style:text-align={rtl && j == 0 ? "left" : "right"}
|
||||
class:component={typeof message === "object" &&
|
||||
message !== null &&
|
||||
"component" in message}
|
||||
class:html={typeof message === "object" &&
|
||||
message !== null &&
|
||||
"component" in message &&
|
||||
message.component === "html"}
|
||||
style:user-select="text"
|
||||
class:selectable
|
||||
style:text-align={rtl ? "right" : "left"}
|
||||
on:click={() => handle_select(i, message)}
|
||||
on:keydown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
handle_select(i, message);
|
||||
}
|
||||
}}
|
||||
dir={rtl ? "rtl" : "ltr"}
|
||||
aria-label={role +
|
||||
"'s message: " +
|
||||
get_message_label_data(message)}
|
||||
>
|
||||
<button
|
||||
data-testid={j == 0 ? "user" : "bot"}
|
||||
class:latest={i === value.length - 1}
|
||||
class:message-markdown-disabled={!render_markdown}
|
||||
style:user-select="text"
|
||||
class:selectable
|
||||
style:text-align={rtl ? "right" : "left"}
|
||||
on:click={() => handle_select(i, j, message)}
|
||||
on:keydown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
handle_select(i, j, message);
|
||||
}
|
||||
}}
|
||||
dir={rtl ? "rtl" : "ltr"}
|
||||
aria-label={(j == 0 ? "user" : "bot") +
|
||||
"'s message: " +
|
||||
get_message_label_data(message)}
|
||||
>
|
||||
{#if message.type === "text"}
|
||||
{#if message.type === "text"}
|
||||
{#if message.metadata.title}
|
||||
<MessageBox title={message.metadata.title}>
|
||||
<Markdown
|
||||
message={message.content}
|
||||
{latex_delimiters}
|
||||
{sanitize_html}
|
||||
{render_markdown}
|
||||
{line_breaks}
|
||||
on:load={scroll}
|
||||
/>
|
||||
</MessageBox>
|
||||
{:else}
|
||||
<Markdown
|
||||
message={message.value}
|
||||
message={message.content}
|
||||
{latex_delimiters}
|
||||
{sanitize_html}
|
||||
{render_markdown}
|
||||
{line_breaks}
|
||||
on:load={scroll}
|
||||
/>
|
||||
{:else if message.type === "component" && message.component in _components}
|
||||
<Component
|
||||
{target}
|
||||
{theme_mode}
|
||||
props={message.props}
|
||||
type={message.component}
|
||||
components={_components}
|
||||
value={message.value}
|
||||
{i18n}
|
||||
{upload}
|
||||
{_fetch}
|
||||
on:load={scroll}
|
||||
/>
|
||||
{:else if message.type === "component" && message.component === "file"}
|
||||
<a
|
||||
data-testid="chatbot-file"
|
||||
class="file-pil"
|
||||
href={message.value.url}
|
||||
target="_blank"
|
||||
download={window.__is_colab__
|
||||
? null
|
||||
: message.value?.orig_name ||
|
||||
message.value?.path.split("/").pop() ||
|
||||
"file"}
|
||||
>
|
||||
{message.value?.orig_name ||
|
||||
message.value?.path.split("/").pop() ||
|
||||
"file"}
|
||||
</a>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
<LikeButtons
|
||||
show={j === 1 && (likeable || show_copy_button)}
|
||||
handle_action={(selected) =>
|
||||
handle_like(i, j, message, selected)}
|
||||
{likeable}
|
||||
{show_copy_button}
|
||||
{message}
|
||||
position={j === 0 ? "right" : "left"}
|
||||
avatar={avatar_images[j]}
|
||||
{layout}
|
||||
/>
|
||||
{:else if message.type === "component" && message.content.component in _components}
|
||||
<Component
|
||||
{target}
|
||||
{theme_mode}
|
||||
props={message.content.props}
|
||||
type={message.content.component}
|
||||
components={_components}
|
||||
value={message.content.value}
|
||||
{i18n}
|
||||
{upload}
|
||||
{_fetch}
|
||||
on:load={scroll}
|
||||
/>
|
||||
{:else if message.type === "component" && message.content.component === "file"}
|
||||
<a
|
||||
data-testid="chatbot-file"
|
||||
class="file-pil"
|
||||
href={message.content.value.url}
|
||||
target="_blank"
|
||||
download={window.__is_colab__
|
||||
? null
|
||||
: message.content.value?.orig_name ||
|
||||
message.content.value?.path.split("/").pop() ||
|
||||
"file"}
|
||||
>
|
||||
{message.content.value?.orig_name ||
|
||||
message.content.value?.path.split("/").pop() ||
|
||||
"file"}
|
||||
</a>
|
||||
{/if}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
<LikeButtons
|
||||
show={role === "bot" && (likeable || show_copy_button)}
|
||||
handle_action={(selected) => handle_like(i, message, selected)}
|
||||
{likeable}
|
||||
{show_copy_button}
|
||||
{message}
|
||||
position={role === "user" ? "right" : "left"}
|
||||
avatar={avatar_img}
|
||||
{layout}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
{#if pending_message}
|
||||
<Pending {layout} />
|
||||
@ -430,6 +465,10 @@
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
|
||||
.thought {
|
||||
margin-top: var(--spacing-xxl);
|
||||
}
|
||||
|
||||
.message :global(.prose) {
|
||||
font-size: var(--chatbot-body-text-size);
|
||||
}
|
||||
@ -610,6 +649,37 @@
|
||||
max-height: 200px;
|
||||
}
|
||||
|
||||
.message-wrap .message :global(a) {
|
||||
color: var(--color-text-link);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.message-wrap .bot :global(table),
|
||||
.message-wrap .bot :global(tr),
|
||||
.message-wrap .bot :global(td),
|
||||
.message-wrap .bot :global(th) {
|
||||
border: 1px solid var(--border-color-primary);
|
||||
}
|
||||
|
||||
.message-wrap .user :global(table),
|
||||
.message-wrap .user :global(tr),
|
||||
.message-wrap .user :global(td),
|
||||
.message-wrap .user :global(th) {
|
||||
border: 1px solid var(--border-color-accent);
|
||||
}
|
||||
|
||||
/* Lists */
|
||||
.message-wrap :global(ol),
|
||||
.message-wrap :global(ul) {
|
||||
padding-inline-start: 2em;
|
||||
}
|
||||
|
||||
/* KaTeX */
|
||||
.message-wrap :global(span.katex) {
|
||||
font-size: var(--text-lg);
|
||||
direction: ltr;
|
||||
}
|
||||
|
||||
/* Copy button */
|
||||
.message-wrap :global(div[class*="code_wrap"] > button) {
|
||||
position: absolute;
|
||||
|
55
js/chatbot/shared/MessageBox.svelte
Normal file
55
js/chatbot/shared/MessageBox.svelte
Normal file
@ -0,0 +1,55 @@
|
||||
<script lang="ts">
|
||||
let expanded = false;
|
||||
export let title: string;
|
||||
|
||||
function toggleExpanded(): void {
|
||||
expanded = !expanded;
|
||||
}
|
||||
</script>
|
||||
|
||||
<button class="box" on:click={toggleExpanded}>
|
||||
<div class="title">
|
||||
<span class="title-text">{title}</span>
|
||||
<span
|
||||
style:transform={expanded ? "rotate(0)" : "rotate(90deg)"}
|
||||
class="arrow"
|
||||
>
|
||||
▼
|
||||
</span>
|
||||
</div>
|
||||
{#if expanded}
|
||||
<div class="content">
|
||||
<slot></slot>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<style>
|
||||
.box {
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
max-width: max-content;
|
||||
background: var(--color-accent-soft);
|
||||
}
|
||||
|
||||
.title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 8px;
|
||||
color: var(--body-text-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.content {
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
.title-text {
|
||||
padding-right: var(--spacing-lg);
|
||||
}
|
||||
|
||||
.arrow {
|
||||
margin-left: auto;
|
||||
opacity: 0.8;
|
||||
}
|
||||
</style>
|
@ -1,5 +1,13 @@
|
||||
import type { FileData } from "@gradio/client";
|
||||
import { uploadToHuggingFace } from "@gradio/utils";
|
||||
import type {
|
||||
TupleFormat,
|
||||
ComponentMessage,
|
||||
ComponentData,
|
||||
TextMessage,
|
||||
NormalisedMessage,
|
||||
Message
|
||||
} from "../types";
|
||||
|
||||
export const format_chat_for_sharing = async (
|
||||
chat: [string | FileData | null, string | FileData | null][]
|
||||
@ -56,51 +64,6 @@ export const format_chat_for_sharing = async (
|
||||
.join("\n");
|
||||
};
|
||||
|
||||
export interface ComponentMessage {
|
||||
type: "component";
|
||||
component: string;
|
||||
value: any;
|
||||
constructor_args: any;
|
||||
props: any;
|
||||
id: string;
|
||||
}
|
||||
export interface TextMessage {
|
||||
type: "text";
|
||||
value: string;
|
||||
id: string;
|
||||
}
|
||||
|
||||
export interface FileMessage {
|
||||
type: "file";
|
||||
file: FileData | FileData[];
|
||||
alt_text: string | null;
|
||||
id: string;
|
||||
}
|
||||
|
||||
export interface EmptyMessage {
|
||||
type: "empty";
|
||||
value: null;
|
||||
id: string;
|
||||
}
|
||||
|
||||
export type NormalisedMessage =
|
||||
| TextMessage
|
||||
| FileMessage
|
||||
| ComponentMessage
|
||||
| EmptyMessage;
|
||||
|
||||
export type message_data =
|
||||
| string
|
||||
| { file: FileData | FileData[]; alt_text: string | null }
|
||||
| { component: string; value: any; constructor_args: any; props: any }
|
||||
| null;
|
||||
|
||||
export type messages = [message_data, message_data][] | null;
|
||||
|
||||
function make_id(): string {
|
||||
return Math.random().toString(36).substring(7);
|
||||
}
|
||||
|
||||
const redirect_src_url = (src: string, root: string): string =>
|
||||
src.replace('src="/file', `src="${root}file`);
|
||||
|
||||
@ -114,37 +77,82 @@ function get_component_for_mime_type(
|
||||
return "file";
|
||||
}
|
||||
|
||||
function convert_file_message_to_component_message(
|
||||
message: any
|
||||
): ComponentData {
|
||||
const _file = Array.isArray(message.file) ? message.file[0] : message.file;
|
||||
return {
|
||||
component: get_component_for_mime_type(_file?.mime_type),
|
||||
value: message.file,
|
||||
alt_text: message.alt_text,
|
||||
constructor_args: {},
|
||||
props: {}
|
||||
} as ComponentData;
|
||||
}
|
||||
|
||||
export function normalise_messages(
|
||||
messages: messages,
|
||||
messages: Message[] | null,
|
||||
root: string
|
||||
): [NormalisedMessage, NormalisedMessage][] | null {
|
||||
if (messages === null) return null;
|
||||
return messages.map((message_pair) => {
|
||||
return message_pair.map((message) => {
|
||||
if (message == null) return { value: null, id: make_id(), type: "empty" };
|
||||
): NormalisedMessage[] | null {
|
||||
if (messages === null) return messages;
|
||||
return messages.map((message, i) => {
|
||||
if (typeof message.content === "string") {
|
||||
return {
|
||||
role: message.role,
|
||||
metadata: message.metadata,
|
||||
content: redirect_src_url(message.content, root),
|
||||
type: "text",
|
||||
index: i
|
||||
};
|
||||
} else if ("file" in message.content) {
|
||||
return {
|
||||
content: convert_file_message_to_component_message(message.content),
|
||||
metadata: message.metadata,
|
||||
role: message.role,
|
||||
type: "component",
|
||||
index: i
|
||||
};
|
||||
}
|
||||
return { type: "component", ...message } as ComponentMessage;
|
||||
});
|
||||
}
|
||||
|
||||
export function normalise_tuples(
|
||||
messages: TupleFormat,
|
||||
root: string
|
||||
): NormalisedMessage[] | null {
|
||||
if (messages === null) return messages;
|
||||
const msg = messages.flatMap((message_pair, i) => {
|
||||
return message_pair.map((message, index) => {
|
||||
if (message == null) return null;
|
||||
const role = index == 0 ? "user" : "assistant";
|
||||
|
||||
if (typeof message === "string") {
|
||||
return {
|
||||
role: role,
|
||||
type: "text",
|
||||
value: redirect_src_url(message, root),
|
||||
id: make_id()
|
||||
};
|
||||
content: redirect_src_url(message, root),
|
||||
metadata: { title: null },
|
||||
index: [i, index]
|
||||
} as TextMessage;
|
||||
}
|
||||
|
||||
if ("file" in message) {
|
||||
const _file = Array.isArray(message.file)
|
||||
? message.file[0]
|
||||
: message.file;
|
||||
return {
|
||||
content: convert_file_message_to_component_message(message),
|
||||
role: role,
|
||||
type: "component",
|
||||
component: get_component_for_mime_type(_file?.mime_type),
|
||||
value: message.file,
|
||||
alt_text: message.alt_text,
|
||||
id: make_id()
|
||||
};
|
||||
index: [i, index]
|
||||
} as ComponentMessage;
|
||||
}
|
||||
|
||||
return { ...message, type: "component", id: make_id() };
|
||||
}) as [NormalisedMessage, NormalisedMessage];
|
||||
return {
|
||||
role: role,
|
||||
content: message,
|
||||
type: "component",
|
||||
index: [i, index]
|
||||
} as ComponentMessage;
|
||||
});
|
||||
});
|
||||
return msg.filter((message) => message != null) as NormalisedMessage[];
|
||||
}
|
||||
|
42
js/chatbot/types.ts
Normal file
42
js/chatbot/types.ts
Normal file
@ -0,0 +1,42 @@
|
||||
import type { FileData } from "@gradio/client";
|
||||
|
||||
export type MessageRole = "system" | "user" | "assistant";
|
||||
|
||||
export interface Metadata {
|
||||
title: string | null;
|
||||
}
|
||||
|
||||
export interface ComponentData {
|
||||
component: string;
|
||||
constructor_args: any;
|
||||
props: any;
|
||||
value: any;
|
||||
alt_text: string | null;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
role: MessageRole;
|
||||
metadata: Metadata;
|
||||
content: string | FileData | ComponentData;
|
||||
index: [number, number] | number;
|
||||
}
|
||||
|
||||
export interface TextMessage extends Message {
|
||||
type: "text";
|
||||
content: string;
|
||||
}
|
||||
|
||||
export interface ComponentMessage extends Message {
|
||||
type: "component";
|
||||
content: ComponentData;
|
||||
}
|
||||
|
||||
export type message_data =
|
||||
| string
|
||||
| { file: FileData | FileData[]; alt_text: string | null }
|
||||
| { component: string; value: any; constructor_args: any; props: any }
|
||||
| null;
|
||||
|
||||
export type TupleFormat = [message_data, message_data][] | null;
|
||||
|
||||
export type NormalisedMessage = TextMessage | ComponentMessage;
|
@ -200,3 +200,11 @@ export const drag_and_drop_file = async (
|
||||
await selector.dispatchEvent("drop", { dataTransfer });
|
||||
}
|
||||
};
|
||||
|
||||
export async function go_to_testcase(
|
||||
page: Page,
|
||||
test_case: string
|
||||
): Promise<void> {
|
||||
const url = page.url();
|
||||
await page.goto(`${url.substring(0, url.length - 1)}_${test_case}_testcase`);
|
||||
}
|
||||
|
@ -47,6 +47,7 @@ class TestChatbot:
|
||||
"proxy_url": None,
|
||||
"_selectable": False,
|
||||
"key": None,
|
||||
"msg_format": "tuples",
|
||||
"latex_delimiters": [{"display": True, "left": "$$", "right": "$$"}],
|
||||
"likeable": False,
|
||||
"rtl": False,
|
||||
|
@ -247,35 +247,41 @@ class TestAPI:
|
||||
assert len(api_info["unnamed_endpoints"]) == 0
|
||||
assert "/chat" in api_info["named_endpoints"]
|
||||
|
||||
def test_streaming_api(self, connect):
|
||||
chatbot = gr.ChatInterface(stream).queue()
|
||||
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
|
||||
def test_streaming_api(self, msg_format, connect):
|
||||
chatbot = gr.ChatInterface(stream, msg_format=msg_format).queue()
|
||||
with connect(chatbot) as client:
|
||||
job = client.submit("hello")
|
||||
wait([job])
|
||||
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
|
||||
|
||||
def test_streaming_api_async(self, connect):
|
||||
chatbot = gr.ChatInterface(async_stream).queue()
|
||||
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
|
||||
def test_streaming_api_async(self, msg_format, connect):
|
||||
chatbot = gr.ChatInterface(async_stream, msg_format=msg_format).queue()
|
||||
with connect(chatbot) as client:
|
||||
job = client.submit("hello")
|
||||
wait([job])
|
||||
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
|
||||
|
||||
def test_non_streaming_api(self, connect):
|
||||
chatbot = gr.ChatInterface(double)
|
||||
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
|
||||
def test_non_streaming_api(self, msg_format, connect):
|
||||
chatbot = gr.ChatInterface(double, msg_format=msg_format)
|
||||
with connect(chatbot) as client:
|
||||
result = client.predict("hello")
|
||||
assert result == "hello hello"
|
||||
|
||||
def test_non_streaming_api_async(self, connect):
|
||||
chatbot = gr.ChatInterface(async_greet)
|
||||
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
|
||||
def test_non_streaming_api_async(self, msg_format, connect):
|
||||
chatbot = gr.ChatInterface(async_greet, msg_format=msg_format)
|
||||
with connect(chatbot) as client:
|
||||
result = client.predict("gradio")
|
||||
assert result == "hi, gradio"
|
||||
|
||||
def test_streaming_api_with_additional_inputs(self, connect):
|
||||
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
|
||||
def test_streaming_api_with_additional_inputs(self, msg_format, connect):
|
||||
chatbot = gr.ChatInterface(
|
||||
echo_system_prompt_plus_message,
|
||||
msg_format=msg_format,
|
||||
additional_inputs=["textbox", "slider"],
|
||||
).queue()
|
||||
with connect(chatbot) as client:
|
||||
|
Loading…
x
Reference in New Issue
Block a user