
Add guides for msg format and llm agents ()

* Add guides

* add changeset

* Add code

* Add code

* Add notebook

* rename msg_format to type

* Fix docs

* notebooks

* missing link

* Update guides

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Freddy Boulton 2024-07-12 13:53:28 +02:00 committed by GitHub
parent d21d8ee0cf
commit 5e36144232
37 changed files with 713 additions and 375 deletions

@ -1,6 +0,0 @@
---
"@gradio/app": patch
"gradio": patch
---
fix:Refactoring <gradio-lite /> component making the code simpler and fixing a Playground mode bug

@ -1,6 +0,0 @@
---
"@gradio/plot": patch
"gradio": patch
---
fix:Fixes Plotly is rendered smaller in a hidden `gr.Tab`

@ -1,6 +0,0 @@
---
"@gradio/app": minor
"gradio": minor
---
feat:Allow app to fill width

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Add guides for msg format and llm agents

@ -13,7 +13,7 @@ highlight:
Building Gradio applications around these LLM solutions is now even easier!
`gr.Chatbot` and `gr.ChatInterface` now have a `msg_format` parameter that can accept two values - `'tuples'` and `'messages'`. If set to `'tuples'`, the default chatbot data format is expected. If set to `'messages'`, a list of dictionaries with `content` and `role` keys is expected. See below -
`gr.Chatbot` and `gr.ChatInterface` now have a `type` parameter that can accept two values - `'tuples'` and `'messages'`. If set to `'tuples'`, the default chatbot data format is expected. If set to `'messages'`, a list of dictionaries with `content` and `role` keys is expected. See below -
```python
def chat_greeter(msg, history):
@ -80,7 +80,7 @@ def generate_response(history):
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
chatbot = gr.Chatbot(type="messages")
button = gr.Button("Get San Francisco Weather")
button.click(generate_response, chatbot, chatbot)

@ -464,7 +464,7 @@ def max_file_size_demo():
@pytest.fixture
def chatbot_message_format():
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
chatbot = gr.Chatbot(type="messages")
msg = gr.Textbox()
def respond(message, chat_history: list):

@ -0,0 +1 @@
git+https://github.com/huggingface/transformers.git#egg=transformers[agents]

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: agent_chatbot"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio git+https://github.com/huggingface/transformers.git#egg=transformers[agents]"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/agent_chatbot/utils.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "from transformers import load_tool, ReactCodeAgent, HfEngine\n", "from utils import stream_from_transformers_agent\n", "\n", "# Import tool from Hub\n", "image_generation_tool = load_tool(\"m-ric/text-to-image\")\n", "\n", "\n", "llm_engine = HfEngine(\"meta-llama/Meta-Llama-3-70B-Instruct\")\n", "# Initialize the agent with both tools\n", "agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)\n", "\n", "\n", "def interact_with_agent(prompt, messages):\n", " messages.append(ChatMessage(role=\"user\", content=prompt))\n", " yield messages\n", " for msg in stream_from_transformers_agent(agent, prompt):\n", " messages.append(msg)\n", " yield messages\n", " yield messages\n", "\n", "\n", "with gr.Blocks() as demo:\n", " stored_message = gr.State([])\n", " chatbot = gr.Chatbot(label=\"Agent\",\n", " type=\"messages\",\n", " avatar_images=(None, \"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png\"))\n", " text_input = gr.Textbox(lines=1, label=\"Chat Message\")\n", " text_input.submit(lambda s: (s, \"\"), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

demo/agent_chatbot/run.py (new file, 34 lines)

@ -0,0 +1,34 @@
import gradio as gr
from gradio import ChatMessage
from transformers import load_tool, ReactCodeAgent, HfEngine
from utils import stream_from_transformers_agent
# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
# Initialize the agent with both tools
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
def interact_with_agent(prompt, messages):
messages.append(ChatMessage(role="user", content=prompt))
yield messages
for msg in stream_from_transformers_agent(agent, prompt):
messages.append(msg)
yield messages
yield messages
with gr.Blocks() as demo:
stored_message = gr.State([])
chatbot = gr.Chatbot(label="Agent",
type="messages",
avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
if __name__ == "__main__":
demo.launch()

@ -0,0 +1,63 @@
from gradio import ChatMessage
from transformers.agents import ReactCodeAgent, agent_types
from typing import Generator
def pull_message(step_log: dict):
if step_log.get("rationale"):
yield ChatMessage(
role="assistant", content=step_log["rationale"]
)
if step_log.get("tool_call"):
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
content = step_log["tool_call"]["tool_arguments"]
if used_code:
content = f"```py\n{content}\n```"
yield ChatMessage(
role="assistant",
metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
content=content,
)
if step_log.get("observation"):
yield ChatMessage(
role="assistant", content=f"```\n{step_log['observation']}\n```"
)
if step_log.get("error"):
yield ChatMessage(
role="assistant",
content=str(step_log["error"]),
metadata={"title": "💥 Error"},
)
def stream_from_transformers_agent(
agent: ReactCodeAgent, prompt: str
) -> Generator[ChatMessage, None, ChatMessage | None]:
"""Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
class Output:
output: agent_types.AgentType | str = None
for step_log in agent.run(prompt, stream=True):
if isinstance(step_log, dict):
for message in pull_message(step_log):
print("message", message)
yield message
Output.output = step_log
if isinstance(Output.output, agent_types.AgentText):
yield ChatMessage(
role="assistant", content=f"**Final answer:**\n```\n{Output.output.to_string()}\n```")
elif isinstance(Output.output, agent_types.AgentImage):
yield ChatMessage(
role="assistant",
content={"path": Output.output.to_string(), "mime_type": "image/png"},
)
elif isinstance(Output.output, agent_types.AgentAudio):
yield ChatMessage(
role="assistant",
content={"path": Output.output.to_string(), "mime_type": "audio/wav"},
)
else:
return ChatMessage(role="assistant", content=Output.output)

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks(fill_width=True) as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

@ -10,7 +10,7 @@ def xray_model(diseases, img):
def ct_model(diseases, img):
return [{disease: 0.1 for disease in diseases}]
with gr.Blocks(fill_width=True) as demo:
with gr.Blocks() as demo:
gr.Markdown(
"""
# Detect Disease From Scan

@ -65,7 +65,7 @@ with gr.Blocks(fill_height=True) as demo:
elem_id="chatbot",
bubble_full_width=False,
scale=1,
msg_format="messages"
type="messages"
)
response_type = gr.Radio(
[

@ -27,7 +27,7 @@ with gr.Blocks() as demo:
[],
elem_id="chatbot",
bubble_full_width=False,
msg_format="messages"
type="messages"
)
chat_input = gr.MultimodalTextbox(interactive=True,

@ -3,7 +3,7 @@ import random
import time
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
chatbot = gr.Chatbot(type="messages")
msg = gr.Textbox()
clear = gr.Button("Clear")

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_with_tools"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "import time\n", "\n", "def generate_response(history):\n", " history.append(ChatMessage(role=\"user\", content=\"What is the weather in San Francisco right now?\"))\n", " yield history\n", " time.sleep(0.25)\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"In order to find the current weather in San Francisco, I will need to use my weather tool.\")\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"API Error when connecting to weather service.\",\n", " metadata={\"title\": \"\ud83d\udca5 Error using tool 'Weather'\"})\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"I will try again\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Weather 72 degrees Fahrenheit with 20% chance of rain.\",\n", " metadata={\"title\": \"\ud83d\udee0\ufe0f Used tool 'Weather'\"}\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Now that the API succeeded I can complete my task.\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!\",\n", " ))\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(msg_format=\"messages\")\n", " button = gr.Button(\"Get San Francisco Weather\")\n", " button.click(generate_response, chatbot, chatbot)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_with_tools"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "import time\n", "\n", "def generate_response(history):\n", " history.append(ChatMessage(role=\"user\", content=\"What is the weather in San Francisco right now?\"))\n", " yield history\n", " time.sleep(0.25)\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"In order to find the current weather in San Francisco, I will need to use my weather tool.\")\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"API Error when connecting to weather service.\",\n", " metadata={\"title\": \"\ud83d\udca5 Error using tool 'Weather'\"})\n", " )\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"I will try again\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Weather 72 degrees Fahrenheit with 20% chance of rain.\",\n", " metadata={\"title\": \"\ud83d\udee0\ufe0f Used tool 'Weather'\"}\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"Now that the API succeeded I can complete my task.\",\n", " ))\n", " yield history\n", " time.sleep(0.25)\n", "\n", " history.append(ChatMessage(role=\"assistant\",\n", " content=\"It's a sunny day in San Francisco with a current temperature of 72 degrees Fahrenheit and a 20% chance of rain. Enjoy the weather!\",\n", " ))\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(type=\"messages\")\n", " button = gr.Button(\"Get San Francisco Weather\")\n", " button.click(generate_response, chatbot, chatbot)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

@ -45,7 +45,7 @@ def generate_response(history):
with gr.Blocks() as demo:
chatbot = gr.Chatbot(msg_format="messages")
chatbot = gr.Chatbot(type="messages")
button = gr.Button("Get San Francisco Weather")
button.click(generate_response, chatbot, chatbot)

@ -9,7 +9,7 @@ def slow_echo(message, history):
demo = gr.ChatInterface(slow_echo, msg_format="messages")
demo = gr.ChatInterface(slow_echo, type="messages")
if __name__ == "__main__":
demo.launch()

@ -11,7 +11,7 @@ def slow_echo(message, history):
yield f"Run {runs} - You typed: " + message[: i + 1]
demo = gr.ChatInterface(slow_echo, msg_format="messages").queue()
demo = gr.ChatInterface(slow_echo, type="messages").queue()
if __name__ == "__main__":
demo.launch()

@ -925,7 +925,6 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
js: str | None = None,
head: str | None = None,
fill_height: bool = False,
fill_width: bool = False,
delete_cache: tuple[int, int] | None = None,
**kwargs,
):
@ -939,7 +938,6 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
js: Custom js as a string or path to a js file. The custom js should be in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
self.limiter = None
@ -973,7 +971,6 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.show_error = True
self.head = head
self.fill_height = fill_height
self.fill_width = fill_width
self.delete_cache = delete_cache
if css is not None and os.path.exists(css):
with open(css, encoding="utf-8") as css_file:
@ -2016,7 +2013,6 @@ Received outputs:
),
},
"fill_height": self.fill_height,
"fill_width": self.fill_width,
"theme_hash": self.theme_hash,
}
config.update(self.default_config.get_config()) # type: ignore

@ -4,6 +4,7 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
from __future__ import annotations
import builtins
import functools
import inspect
import warnings
@ -59,7 +60,7 @@ class ChatInterface(Blocks):
fn: Callable,
*,
multimodal: bool = False,
msg_format: Literal["messages", "tuples"] = "tuples",
type: Literal["messages", "tuples"] = "tuples",
chatbot: Chatbot | None = None,
textbox: Textbox | MultimodalTextbox | None = None,
additional_inputs: str | Component | list[str | Component] | None = None,
@ -85,7 +86,6 @@ class ChatInterface(Blocks):
fill_height: bool = True,
delete_cache: tuple[int, int] | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "minimal",
fill_width: bool = False,
):
"""
Parameters:
@ -116,7 +116,6 @@ class ChatInterface(Blocks):
fill_height: If True, the chat interface will expand to the height of window.
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
show_progress: whether to show progress animation while running.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -127,10 +126,9 @@ class ChatInterface(Blocks):
js=js,
head=head,
fill_height=fill_height,
fill_width=fill_width,
delete_cache=delete_cache,
)
self.msg_format: Literal["messages", "tuples"] = msg_format
self.type: Literal["messages", "tuples"] = type
self.multimodal = multimodal
self.concurrency_limit = concurrency_limit
self.fn = fn
@ -178,7 +176,7 @@ class ChatInterface(Blocks):
)
else:
raise ValueError(
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {builtins.type(additional_inputs_accordion)}"
)
with self:
@ -190,19 +188,19 @@ class ChatInterface(Blocks):
Markdown(description)
if chatbot:
if self.msg_format != chatbot.msg_format:
if self.type != chatbot.type:
warnings.warn(
"The msg_format of the chatbot does not match the msg_format of the chat interface. The msg_format of the chat interface will be used."
"Recieved msg_format of chatbot: {chatbot.msg_format}, msg_format of chat interface: {self.msg_format}"
"The type of the chatbot does not match the type of the chat interface. The type of the chat interface will be used."
"Recieved type of chatbot: {chatbot.type}, type of chat interface: {self.type}"
)
chatbot.msg_format = self.msg_format
chatbot.type = self.type
self.chatbot = get_component_instance(chatbot, render=True)
else:
self.chatbot = Chatbot(
label="Chatbot",
scale=1,
height=200 if fill_height else None,
msg_format=self.msg_format,
type=self.type,
)
with Row():
@ -216,7 +214,7 @@ class ChatInterface(Blocks):
)
else:
raise ValueError(
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
f"All the _btn parameters must be a gr.Button, string, or None, not {builtins.type(btn)}"
)
self.buttons.append(btn) # type: ignore
@ -231,7 +229,7 @@ class ChatInterface(Blocks):
textbox_ = get_component_instance(textbox, render=True)
if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
raise TypeError(
f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {builtins.type(textbox_)}"
)
self.textbox = textbox_
elif self.multimodal:
@ -264,7 +262,7 @@ class ChatInterface(Blocks):
)
else:
raise ValueError(
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
f"The submit_btn parameter must be a gr.Button, string, or None, not {builtins.type(submit_btn)}"
)
if stop_btn is not None:
if isinstance(stop_btn, Button):
@ -280,7 +278,7 @@ class ChatInterface(Blocks):
)
else:
raise ValueError(
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
f"The stop_btn parameter must be a gr.Button, string, or None, not {builtins.type(stop_btn)}"
)
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
@ -536,7 +534,7 @@ class ChatInterface(Blocks):
response: MessageDict | str | None,
history: list[MessageDict] | TupleFormat,
):
if self.msg_format == "tuples":
if self.type == "tuples":
for x in message.files:
history.append([(x.path,), None]) # type: ignore
if message.text is None or not isinstance(message.text, str):
@ -562,9 +560,9 @@ class ChatInterface(Blocks):
) -> tuple[TupleFormat, TupleFormat] | tuple[list[MessageDict], list[MessageDict]]:
if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, None, history)
elif isinstance(message, str) and self.msg_format == "tuples":
elif isinstance(message, str) and self.type == "tuples":
history.append([message, None]) # type: ignore
elif isinstance(message, str) and self.msg_format == "messages":
elif isinstance(message, str) and self.type == "messages":
history.append({"role": "user", "content": message}) # type: ignore
return history, history # type: ignore
@ -607,16 +605,16 @@ class ChatInterface(Blocks):
self.fn, *inputs, limiter=self.limiter
)
if self.msg_format == "messages":
if self.type == "messages":
new_response = self.response_as_dict(response)
else:
new_response = response
if self.multimodal and isinstance(message, MultimodalData):
self._append_multimodal_history(message, new_response, history) # type: ignore
elif isinstance(message, str) and self.msg_format == "tuples":
elif isinstance(message, str) and self.type == "tuples":
history.append([message, new_response]) # type: ignore
elif isinstance(message, str) and self.msg_format == "messages":
elif isinstance(message, str) and self.type == "messages":
history.extend([{"role": "user", "content": message}, new_response]) # type: ignore
return history, history # type: ignore
@ -649,12 +647,12 @@ class ChatInterface(Blocks):
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
if self.msg_format == "messages":
if self.type == "messages":
first_response = self.response_as_dict(first_response)
if (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "tuples"
and self.type == "tuples"
):
for x in message.files:
history.append([(x,), None]) # type: ignore
@ -663,7 +661,7 @@ class ChatInterface(Blocks):
elif (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "messages"
and self.type == "messages"
):
for x in message.files:
history.append(
@ -674,7 +672,7 @@ class ChatInterface(Blocks):
first_response,
]
yield update, update
elif self.msg_format == "tuples":
elif self.type == "tuples":
update = history + [[message, first_response]]
yield update, update
else:
@ -691,26 +689,26 @@ class ChatInterface(Blocks):
update = history + [[message, None]]
yield update, update
async for response in generator:
if self.msg_format == "messages":
if self.type == "messages":
response = self.response_as_dict(response)
if (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "tuples"
and self.type == "tuples"
):
update = history + [[message.text, response]]
yield update, update
elif (
self.multimodal
and isinstance(message, MultimodalData)
and self.msg_format == "messages"
and self.type == "messages"
):
update = history + [
{"role": "user", "content": message.text},
response,
]
yield update, update
elif self.msg_format == "tuples":
elif self.type == "tuples":
update = history + [[message, response]]
yield update, update
else:
@ -734,7 +732,7 @@ class ChatInterface(Blocks):
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
if self.msg_format == "tuples":
if self.type == "tuples":
history.append([message, response]) # type: ignore
else:
new_response = self.response_as_dict(response)
@ -756,7 +754,7 @@ class ChatInterface(Blocks):
generator = SyncToAsyncIterator(generator, self.limiter)
try:
first_response = await async_iteration(generator)
if self.msg_format == "tuples":
if self.type == "tuples":
yield first_response, history + [[message, first_response]]
else:
first_response = self.response_as_dict(first_response)
@ -767,7 +765,7 @@ class ChatInterface(Blocks):
except StopIteration:
yield None, history + [[message, None]]
async for response in generator:
if self.msg_format == "tuples":
if self.type == "tuples":
yield response, history + [[message, response]]
else:
new_response = self.response_as_dict(response)
@ -787,7 +785,7 @@ class ChatInterface(Blocks):
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
if self.msg_format == "tuples":
if self.type == "tuples":
return [[message, response]]
else:
return [{"role": "user", "content": message}, response]
@ -807,7 +805,7 @@ class ChatInterface(Blocks):
)
generator = SyncToAsyncIterator(generator, self.limiter)
async for response in generator:
if self.msg_format == "tuples":
if self.type == "tuples":
yield [[message, response]]
else:
new_response = self.response_as_dict(response)
@ -822,7 +820,7 @@ class ChatInterface(Blocks):
str | MultimodalData,
list[MessageDict] | TupleFormat,
]:
extra = 1 if self.msg_format == "messages" else 0
extra = 1 if self.type == "messages" else 0
if self.multimodal and isinstance(message, MultimodalData):
remove_input = (
len(message.files) + 1

@ -143,7 +143,7 @@ class Chatbot(Component):
| None
) = None,
*,
msg_format: Literal["messages", "tuples"] = "tuples",
type: Literal["messages", "tuples"] = "tuples",
label: str | None = None,
every: Timer | float | None = None,
inputs: Component | list[Component] | set[Component] | None = None,
@ -173,7 +173,7 @@ class Chatbot(Component):
"""
Parameters:
value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
msg_format: The format of the messages. If 'tuples', expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content' key's value supports everything the 'tuples' format supports. The 'role' key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output.
type: The format of the messages. If 'tuples', expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the 'tuples' format supports. The 'role' key should be one of 'user' or 'assistant'. Any other roles will not be displayed in the output.
label: The label for this component. Appears above the component and is also used as the header if there are a table of examples for this component. If None and used in a `gr.Interface`, the label will be the name of the parameter this component is assigned to.
every: Continuously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
@ -201,12 +201,10 @@ class Chatbot(Component):
placeholder: a placeholder message to display in the chatbot when it is empty. Centered vertically and horizontally in the Chatbot. Supports Markdown and HTML. If None, no placeholder is displayed.
"""
self.likeable = likeable
if msg_format not in ["messages", "tuples"]:
raise ValueError(
"msg_format must be 'messages' or 'tuples', received: {msg_format}"
)
self.msg_format: Literal["tuples", "messages"] = msg_format
if msg_format == "messages":
if type not in ["messages", "tuples"]:
raise ValueError("type must be 'messages' or 'tuples', received: {type}")
self.type: Literal["tuples", "messages"] = type
if type == "messages":
self.data_model = ChatbotDataMessages
else:
self.data_model = ChatbotDataTuples
@ -252,14 +250,16 @@ class Chatbot(Component):
self.placeholder = placeholder
@staticmethod
def _check_format(messages: list[Any], msg_format: Literal["messages", "tuples"]):
if msg_format == "messages":
all_dicts = all(
isinstance(message, dict) and "role" in message and "content" in message
def _check_format(messages: list[Any], type: Literal["messages", "tuples"]):
if type == "messages":
all_valid = all(
isinstance(message, dict)
and "role" in message
and "content" in message
or isinstance(message, ChatMessage)
for message in messages
)
all_msgs = all(isinstance(msg, ChatMessage) for msg in messages)
if not (all_dicts or all_msgs):
if not all_valid:
raise Error(
"Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
)
@ -341,11 +341,11 @@ class Chatbot(Component):
Parameters:
payload: data as a ChatbotData object
Returns:
If msg_format is 'tuples', passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. If msg_format is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
If type is 'tuples', passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. If type is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
"""
if payload is None:
return payload
if self.msg_format == "tuples":
if self.type == "tuples":
if not isinstance(payload, ChatbotDataTuples):
raise Error("Data incompatible with the tuples format")
return self._preprocess_messages_tuples(cast(ChatbotDataTuples, payload))
@ -474,7 +474,7 @@ class Chatbot(Component):
) -> ChatbotDataTuples | ChatbotDataMessages:
"""
Parameters:
value: If msg_format is `tuples`, expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If msg_format is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
value: If type is `tuples`, expects a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list should have 2 elements: the user message and the response message. The individual messages can be (1) strings in valid Markdown, (2) tuples if sending files: (a filepath or URL to a file, [optional string alt text]) -- if the file is image/video/audio, it is displayed in the Chatbot, or (3) None, in which case the message is not displayed. If type is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
Returns:
an object of type ChatbotData
"""
@ -483,7 +483,7 @@ class Chatbot(Component):
)
if value is None:
return data_model(root=[])
if self.msg_format == "tuples":
if self.type == "tuples":
self._check_format(value, "tuples")
return self._postprocess_messages_tuples(cast(TupleFormat, value))
self._check_format(value, "messages")
@ -495,7 +495,7 @@ class Chatbot(Component):
return ChatbotDataMessages(root=processed_messages)
def example_payload(self) -> Any:
if self.msg_format == "messages":
if self.type == "messages":
return [
Message(role="user", content="Hello!").model_dump(),
Message(role="assistant", content="How can I help you?").model_dump(),
@ -503,7 +503,7 @@ class Chatbot(Component):
return [["Hello!", None]]
def example_value(self) -> Any:
if self.msg_format == "messages":
if self.type == "messages":
return [
Message(role="user", content="Hello!").model_dump(),
Message(role="assistant", content="How can I help you?").model_dump(),

@ -326,7 +326,6 @@ class BlocksConfigDict(TypedDict):
protocol: Literal["ws", "sse", "sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
body_css: BodyCSS
fill_height: bool
fill_width: bool
theme_hash: str
layout: NotRequired[Layout]
dependencies: NotRequired[list[dict[str, Any]]]

@ -133,7 +133,6 @@ class Interface(Blocks):
delete_cache: tuple[int, int] | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "full",
example_labels: list[str] | None = None,
fill_width: bool = False,
**kwargs,
):
"""
@ -171,7 +170,6 @@ class Interface(Blocks):
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
show_progress: whether to show progress animation while running. Has no effect if the interface is `live`.
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -182,7 +180,6 @@ class Interface(Blocks):
js=js,
head=head,
delete_cache=delete_cache,
fill_width=fill_width,
**kwargs,
)
self.api_name: str | Literal[False] | None = api_name

@ -0,0 +1,97 @@
# Using the Messages data format
In the previous guides, we built chatbots where the conversation history was stored in a list of tuple pairs.
It is also possible to use the more flexible [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api#messages-api), which is fully compatible with LLM API providers such as Hugging Face Text Generation Inference, Llama.cpp server, and OpenAI's chat completions API.
To use this format, set the `type` parameter of `gr.Chatbot` or `gr.ChatInterface` to `'messages'`. This expects a list of dictionaries with `content` and `role` keys.
The `role` key should be `'assistant'` for the bot/LLM and `'user'` for the human.
The `content` key can be one of three things:
1. A string (markdown supported) to display a simple text message
2. A dictionary (or `gr.FileData`) to display a file. At minimum, this dictionary should contain a `path` key corresponding to the path to the file. Full documentation of this dictionary is in the appendix of this guide.
3. A gradio component - at present `gr.Plot`, `gr.Image`, `gr.Gallery`, `gr.Video`, `gr.Audio` are supported.
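For illustration, here is a minimal sketch of a history that uses each kind of `content` value (the file path is a hypothetical placeholder and would need to exist on the server):
```python
import gradio as gr

history = [
    # 1. A plain Markdown string
    {"role": "user", "content": "Please show me an example image."},
    # 2. A file, given as a dictionary with at least a "path" key (placeholder path)
    {"role": "assistant", "content": {"path": "files/example.png", "mime_type": "image/png"}},
    # 3. A Gradio component (again using a placeholder path)
    {"role": "assistant", "content": gr.Image("files/example.png")},
]
```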
For better type hinting and auto-completion in your IDE, you can use the `gr.ChatMessage` dataclass:
```python
from gradio import ChatMessage
def chat_function(message, history):
history.append(ChatMessage(role="user", content=message))
history.append(ChatMessage(role="assistant", content="Hello, how can I help you?"))
return history
```
## Examples
The following chatbot will always greet the user with "Hello"
```python
import gradio as gr
def chat_greeter(msg, history):
history.append({"role": "assistant", "content": "Hello!"})
return history
with gr.Blocks() as demo:
chatbot = gr.Chatbot(type="messages")
msg = gr.Textbox()
clear = gr.ClearButton([msg, chatbot])
msg.submit(chat_greeter, [msg, chatbot], [chatbot])
demo.launch()
```
The messages format lets us seamlessly stream from the Hugging Face Inference API -
```python
import gradio as gr
from huggingface_hub import InferenceClient
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def respond(message, history: list[dict]):
messages = history + [{"role": "user", "content": message}]
print(messages)
response = {"role": "assistant", "content": ""}
for message in client.chat_completion(
messages,
max_tokens=512,
stream=True,
temperature=0.7,
top_p=0.95,
):
token = message.choices[0].delta.content
response['content'] += token
yield response
demo = gr.ChatInterface(respond, type="messages")
if __name__ == "__main__":
demo.launch()
```
### Appendix
The full contents of the dictionary format for files is documented here
```python
class FileDataDict(TypedDict):
path: str # server filepath
url: NotRequired[Optional[str]] # normalised server url
size: NotRequired[Optional[int]] # size in bytes
orig_name: NotRequired[Optional[str]] # original filename
mime_type: NotRequired[Optional[str]]
is_stream: NotRequired[bool]
meta: dict[Literal["_type"], Literal["gradio.FileData"]]
```
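For example, a single message that displays a file could look like the following sketch; the path and metadata values are placeholders, and per the note above only the `path` key is strictly required:
```python
file_message = {
    "role": "assistant",
    "content": {
        "path": "outputs/report.pdf",        # server filepath (placeholder)
        "orig_name": "report.pdf",           # original filename
        "mime_type": "application/pdf",
        "meta": {"_type": "gradio.FileData"},
    },
}
```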

@ -0,0 +1,156 @@
# Building a UI for an LLM Agent
Tags: LLM, AGENTS, CHAT
Related spaces: https://huggingface.co/spaces/gradio/agent_chatbot, https://huggingface.co/spaces/gradio/langchain-agent
The Gradio Chatbot can natively display intermediate thoughts and tool usage. This makes it perfect for creating UIs for LLM agents. This guide will show you how. Before we begin, familiarize yourself with the `messages` chatbot data format documented in this [guide](./messages-format).
## The metadata key
In addition to the `content` and `role` keys, the messages dictionary accepts a `metadata` key. At present, the `metadata` key accepts a dictionary with a single key called `title`.
If you specify a `title` for the message, it will be displayed in a collapsible box.
Here is an example, where we display the agent's thought to use a weather API tool to answer the user's query.
```python
with gr.Blocks() as demo:
chatbot = gr.Chatbot(type="messages",
value=[{"role": "user", "content": "What is the weather in San Francisco?"},
{"role": "assistant", "content": "I need to use the weather API tool",
"metadata": {"title": "🧠 Thinking"}}]
)
```
![simple-metadata-chatbot](https://github.com/freddyaboulton/freddyboulton/assets/41651716/3941783f-6835-4e5e-89a6-03f850d9abde)
## A real example using transformers.agents
We'll create a Gradio application for a simple agent that has access to a text-to-image tool.
Tip: Make sure you read the transformers agent [documentation](https://huggingface.co/docs/transformers/en/agents) first
We'll start by importing the necessary classes from transformers and gradio.
```python
import gradio as gr
from gradio import ChatMessage
from transformers import load_tool, ReactCodeAgent, HfEngine
from utils import stream_from_transformers_agent
# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
# Initialize the agent with both tools
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
```
Then we'll build the UI. The bulk of the logic is handled by `stream_from_transformers_agent`. We won't cover it in this guide because it will soon be merged into transformers, but you can see its source code [here](https://huggingface.co/spaces/gradio/agent_chatbot/blob/main/utils.py).
```python
def interact_with_agent(prompt, messages):
messages.append(ChatMessage(role="user", content=prompt))
yield messages
for msg in stream_from_transformers_agent(agent, prompt):
messages.append(msg)
yield messages
yield messages
with gr.Blocks() as demo:
stored_message = gr.State([])
chatbot = gr.Chatbot(label="Agent",
type="messages",
avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
```
You can see the full demo code [here](https://huggingface.co/spaces/gradio/agent_chatbot/blob/main/app.py).
![transformers_agent_code](https://github.com/freddyaboulton/freddyboulton/assets/41651716/c8d21336-e0e6-4878-88ea-e6fcfef3552d)
## A real example using langchain agents
We'll create a UI for a LangChain agent that has access to a search engine.
We'll begin with the imports and set up the LangChain agent. Note that you'll need a `.env` file with
the following environment variables set -
```
SERPAPI_API_KEY=
HF_TOKEN=
OPENAI_API_KEY=
```
```python
from langchain import hub
from langchain.agents import AgentExecutor, create_openai_tools_agent, load_tools
from langchain_openai import ChatOpenAI
from gradio import ChatMessage
import gradio as gr
from dotenv import load_dotenv
load_dotenv()
model = ChatOpenAI(temperature=0, streaming=True)
tools = load_tools(["serpapi"])
# Get the prompt to use - you can modify this!
prompt = hub.pull("hwchase17/openai-tools-agent")
# print(prompt.messages) -- to see the prompt
agent = create_openai_tools_agent(
model.with_config({"tags": ["agent_llm"]}), tools, prompt
)
agent_executor = AgentExecutor(agent=agent, tools=tools).with_config(
{"run_name": "Agent"}
)
```
Then we'll create the Gradio UI
```python
async def interact_with_langchain_agent(prompt, messages):
messages.append(ChatMessage(role="user", content=prompt))
yield messages
async for chunk in agent_executor.astream(
{"input": prompt}
):
if "steps" in chunk:
for step in chunk["steps"]:
messages.append(ChatMessage(role="assistant", content=step.action.log,
metadata={"title": f"🛠️ Used tool {step.action.tool}"}))
yield messages
if "output" in chunk:
messages.append(ChatMessage(role="assistant", content=chunk["output"]))
yield messages
with gr.Blocks() as demo:
gr.Markdown("# Chat with a LangChain Agent 🦜⛓️ and see its thoughts 💭")
chatbot = gr.Chatbot(
type="messages",
label="Agent",
avatar_images=(
None,
"https://em-content.zobj.net/source/twitter/141/parrot_1f99c.png",
),
)
input = gr.Textbox(lines=1, label="Chat Message")
input.submit(interact_with_langchain_agent, [input, chatbot], [chatbot])
demo.launch()
```
![langchain_agent_code](https://github.com/freddyaboulton/freddyboulton/assets/41651716/762283e5-3937-47e5-89e0-79657279ea67)
That's it! See our finished langchain demo [here](https://huggingface.co/spaces/gradio/langchain-agent).

@ -37,18 +37,18 @@ demo.launch()
<!-- Behavior -->
### Behavior
The data format accepted by the Chatbot is dictated by the `msg_format` parameter.
The data format accepted by the Chatbot is dictated by the `type` parameter.
This parameter can take two values, `'tuples'` and `'messages'`.
If `msg_format` is `'tuples'`, then the data sent to/from the chatbot will be a list of tuples.
If `type` is `'tuples'`, then the data sent to/from the chatbot will be a list of tuples.
The first element of each tuple is the user message and the second element is the bot's response.
Each element can be a string (markdown/html is supported),
a tuple (in which case the first element is a filepath that will be displayed in the chatbot),
or a gradio component (see the Examples section for more details).
If the `msg_format` is `'messages'`, then the data sent to/from the chatbot will be a list of dictionaries
If the `type` is `'messages'`, then the data sent to/from the chatbot will be a list of dictionaries
with `role` and `content` keys. This format is compliant with the format expected by most LLM APIs (HuggingChat, OpenAI, Claude).
The `role` key is either `'user'` or `'assistant'`, and the `content` key can be a string (markdown/html supported),
a `FileDataDict` (to represent a file that is displayed in the chatbot - documented below), or a gradio component.
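As an illustrative sketch, the same short exchange could be represented in either format:
```python
# type="tuples": a list of [user_message, bot_response] pairs
tuples_history = [["Hello!", "Hi there! How can I help?"]]

# type="messages": a list of dictionaries with "role" and "content" keys
messages_history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there! How can I help?"},
]
```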
@ -67,7 +67,7 @@ def generate_response(history):
return history
```
Additionally, when `msg_format` is `messages`, you can provide additional metadata regarding any tools used to generate the response.
Additionally, when `type` is `messages`, you can provide additional metadata regarding any tools used to generate the response.
This is useful for displaying the thought process of LLM agents. For example,
```python
@ -123,7 +123,7 @@ class ChatMessage:
## **As input component**: {@html style_formatted_text(obj.preprocess.return_doc.doc)}
##### Your function should accept one of these types:
If `msg_format` is `tuples` -
If `type` is `tuples` -
```python
from gradio import Component
@ -134,7 +134,7 @@ def predict(
...
```
If `msg_format` is `messages` -
If `type` is `messages` -
```python
from gradio import MessageDict
@ -147,7 +147,7 @@ def predict(value: list[MessageDict] | None):
## **As output component**: {@html style_formatted_text(obj.postprocess.parameter_doc[0].doc)}
##### Your function should return one of these types:
If `msg_format` is `tuples` -
If `type` is `tuples` -
```python
def predict(···) -> list[list[str | tuple[str] | tuple[str, str] | None] | tuple] | None
@ -155,7 +155,7 @@ def predict(···) -> list[list[str | tuple[str] | tuple[str, str] | None] | tu
return value
```
If `msg_format` is `messages` -
If `type` is `messages` -
from gradio import ChatMessage, MessageDict

@ -4,7 +4,6 @@
export let wrapper: HTMLDivElement;
export let version: string;
export let initial_height: string;
export let fill_width: boolean;
export let is_embed: boolean;
export let space: string | null;
@ -16,7 +15,6 @@
<div
bind:this={wrapper}
class:app={!display && !is_embed}
class:fill_width
class:embed-container={display}
class:with-info={info}
class="gradio-container gradio-container-{version}"
@ -89,27 +87,27 @@
}
@media (--screen-sm) {
.app:not(.fill_width) {
.app {
max-width: 640px;
}
}
@media (--screen-md) {
.app:not(.fill_width) {
.app {
max-width: 768px;
}
}
@media (--screen-lg) {
.app:not(.fill_width) {
.app {
max-width: 1024px;
}
}
@media (--screen-xl) {
.app:not(.fill_width) {
.app {
max-width: 1280px;
}
}
@media (--screen-xxl) {
.app:not(.fill_width) {
.app {
max-width: 1536px;
}
}

@ -29,7 +29,6 @@
path: string;
app_id?: string;
fill_height?: boolean;
fill_width?: boolean;
theme_hash?: number;
username: string | null;
}
@ -413,7 +412,6 @@
{initial_height}
{space}
loaded={loader_status === "complete"}
fill_width={config?.fill_width || false}
bind:wrapper
>
{#if (loader_status === "pending" || loader_status === "error") && !(config && config?.auth_required)}

@ -1,201 +0,0 @@
<script lang="ts">
import "@gradio/theme/src/reset.css";
import "@gradio/theme/src/global.css";
import "@gradio/theme/src/pollen.css";
import "@gradio/theme/src/typography.css";
import { onDestroy, SvelteComponent } from "svelte";
import Index from "../Index.svelte";
import Playground from "./Playground.svelte";
import ErrorDisplay from "./ErrorDisplay.svelte";
import type { ThemeMode } from "../types";
import { WorkerProxy, type WorkerProxyOptions } from "@gradio/wasm";
import { Client } from "@gradio/client";
import { wasm_proxied_fetch } from "./fetch";
import { wasm_proxied_stream_factory } from "./sse";
import { wasm_proxied_mount_css, mount_prebuilt_css } from "./css";
import type { mount_css } from "../css";
// These imports are aliased at built time with Vite. See the `resolve.alias` config in `vite.config.ts`.
import gradioWheel from "gradio.whl";
import gradioClientWheel from "gradio_client.whl";
export let info: boolean;
export let container: boolean;
export let is_embed: boolean;
export let initial_height: string;
export let eager: boolean;
export let version: string;
export let theme_mode: ThemeMode | null;
export let autoscroll: boolean;
export let control_page_title: boolean;
export let app_mode: boolean;
// For Wasm mode
export let files: WorkerProxyOptions["files"] | undefined;
export let requirements: WorkerProxyOptions["requirements"] | undefined;
export let code: string | undefined;
export let entrypoint: string | undefined;
export let sharedWorkerMode: boolean | undefined;
// For playground
export let playground = false;
export let layout: string | null;
const worker_proxy = new WorkerProxy({
gradioWheelUrl: new URL(gradioWheel, import.meta.url).href,
gradioClientWheelUrl: new URL(gradioClientWheel, import.meta.url).href,
files: files ?? {},
requirements: requirements ?? [],
sharedWorkerMode: sharedWorkerMode ?? false
});
onDestroy(() => {
worker_proxy.terminate();
});
let error: Error | null = null;
const wrapFunctionWithAppLogic = <TArgs extends any[], TRet extends any>(
func: (...args: TArgs) => Promise<TRet>
): ((...args: TArgs) => Promise<TRet>) => {
return async (...args: TArgs) => {
try {
error = null;
const retval = await func(...args);
refresh_index_component();
return retval;
} catch (err) {
error = err as Error;
throw err;
}
};
};
worker_proxy.runPythonCode = wrapFunctionWithAppLogic(
worker_proxy.runPythonCode.bind(worker_proxy)
);
worker_proxy.runPythonFile = wrapFunctionWithAppLogic(
worker_proxy.runPythonFile.bind(worker_proxy)
);
worker_proxy.writeFile = wrapFunctionWithAppLogic(
worker_proxy.writeFile.bind(worker_proxy)
);
worker_proxy.renameFile = wrapFunctionWithAppLogic(
worker_proxy.renameFile.bind(worker_proxy)
);
worker_proxy.unlink = wrapFunctionWithAppLogic(
worker_proxy.unlink.bind(worker_proxy)
);
worker_proxy.install = wrapFunctionWithAppLogic(
worker_proxy.install.bind(worker_proxy)
);
worker_proxy.addEventListener("initialization-error", (event) => {
error = (event as CustomEvent).detail;
});
// Internally, the execution of `runPythonCode()` or `runPythonFile()` is queued
// and its promise will be resolved after the Pyodide is loaded and the worker initialization is done
// (see the await in the `onmessage` callback in the webworker code)
// So we don't await this promise because we want to mount the `Index` immediately and start the app initialization asynchronously.
if (code != null) {
worker_proxy.runPythonCode(code);
} else if (entrypoint != null) {
worker_proxy.runPythonFile(entrypoint);
} else {
throw new Error("Either code or entrypoint must be provided.");
}
mount_prebuilt_css(document.head);
class LiteClient extends Client {
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
return wasm_proxied_fetch(worker_proxy, input, init);
}
stream(url: URL): EventSource {
return wasm_proxied_stream_factory(worker_proxy, url);
}
}
const overridden_mount_css: typeof mount_css = async (url, target) => {
return wasm_proxied_mount_css(worker_proxy, url, target);
};
let index_component_key = 0;
function refresh_index_component(): void {
index_component_key += 1;
}
let playground_component: SvelteComponent | null = null;
$: playground_component?.$on("code", (event) => {
const { code } = event.detail;
worker_proxy.runPythonCode(code);
});
export const run_code = worker_proxy.runPythonCode;
export const run_file = worker_proxy.runPythonFile;
export const write = worker_proxy.writeFile;
export const rename = worker_proxy.renameFile;
export const unlink = worker_proxy.unlink;
export const install = worker_proxy.install;
</script>
{#if playground}
<Playground
bind:this={playground_component}
{worker_proxy}
{layout}
{code}
{is_embed}
>
{#key index_component_key}
{#if error}
<ErrorDisplay {error} is_embed />
{:else}
<Index
space={null}
src={null}
host={null}
{info}
{container}
{is_embed}
{initial_height}
{eager}
{version}
{theme_mode}
{autoscroll}
{control_page_title}
{app_mode}
{worker_proxy}
Client={LiteClient}
mount_css={overridden_mount_css}
/>
{/if}
{/key}
</Playground>
{:else}
{#key index_component_key}
{#if error}
<ErrorDisplay {error} {is_embed} />
{:else}
<Index
space={null}
src={null}
host={null}
{info}
{container}
{is_embed}
{initial_height}
{eager}
{version}
{theme_mode}
{autoscroll}
{control_page_title}
{app_mode}
{worker_proxy}
Client={LiteClient}
mount_css={overridden_mount_css}
/>
{/if}
{/key}
{/if}

@ -1,16 +1,35 @@
<script lang="ts">
import Index from "../Index.svelte";
import type { ThemeMode } from "../types";
import { mount_css as default_mount_css } from "../css";
import type { Client as ClientType } from "@gradio/client";
import type { WorkerProxy } from "@gradio/wasm";
import { createEventDispatcher, onMount } from "svelte";
import { SvelteComponent, createEventDispatcher, onMount } from "svelte";
import Code from "@gradio/code";
import ErrorDisplay from "./ErrorDisplay.svelte";
import lightning from "../images/lightning.svg";
import type { LoadingStatus } from "js/statustracker";
export let autoscroll: boolean;
export let version: string;
export let initial_height: string;
export let app_mode: boolean;
export let is_embed: boolean;
export let theme_mode: ThemeMode | null = "system";
export let control_page_title: boolean;
export let container: boolean;
export let info: boolean;
export let eager: boolean;
export let mount_css: typeof default_mount_css = default_mount_css;
export let Client: typeof ClientType;
export let worker_proxy: WorkerProxy | undefined = undefined;
export let space: string | null;
export let host: string | null;
export let src: string | null;
export let code: string | undefined;
export let error_display: { is_embed: boolean; error: Error } | null;
export let layout: string | null = null;
const dispatch = createEventDispatcher();
@ -180,7 +199,33 @@
</div>
{#if loaded}
<div class="preview">
<slot></slot>
<div style="flex-grow: 1;">
{#if !error_display}
<Index
{autoscroll}
{version}
{initial_height}
{app_mode}
{is_embed}
{theme_mode}
{control_page_title}
{container}
{info}
{eager}
{mount_css}
{Client}
bind:worker_proxy
{space}
{host}
{src}
/>
{:else}
<ErrorDisplay
is_embed={error_display.is_embed}
error={error_display.error}
/>
{/if}
</div>
</div>
{/if}
</div>

@ -1,10 +1,25 @@
import { type WorkerProxyOptions } from "@gradio/wasm";
import "@gradio/theme/src/reset.css";
import "@gradio/theme/src/global.css";
import "@gradio/theme/src/pollen.css";
import "@gradio/theme/src/typography.css";
import type { SvelteComponent } from "svelte";
import { WorkerProxy, type WorkerProxyOptions } from "@gradio/wasm";
import { Client } from "@gradio/client";
import { wasm_proxied_fetch } from "./fetch";
import { wasm_proxied_stream_factory } from "./sse";
import { wasm_proxied_mount_css, mount_prebuilt_css } from "./css";
import type { mount_css } from "../css";
import Index from "../Index.svelte";
import Playground from "./Playground.svelte";
import ErrorDisplay from "./ErrorDisplay.svelte";
import type { ThemeMode } from "../types";
import { bootstrap_custom_element } from "./custom-element";
declare let GRADIO_VERSION: string;
// These imports are aliased at built time with Vite. See the `resolve.alias` config in `vite.config.ts`.
import gradioWheel from "gradio.whl";
import gradioClientWheel from "gradio_client.whl";
import LiteIndex from "./LiteIndex.svelte";
declare let GRADIO_VERSION: string;
// NOTE: The following line has been copied from `main.ts`.
// In `main.ts`, which is the normal Gradio app entry point,
@ -63,40 +78,217 @@ export function create(options: Options): GradioAppController {
observer.observe(options.target, { childList: true });
const app = new LiteIndex({
target: options.target,
props: {
const worker_proxy = new WorkerProxy({
gradioWheelUrl: new URL(gradioWheel, import.meta.url).href,
gradioClientWheelUrl: new URL(gradioClientWheel, import.meta.url).href,
files: options.files ?? {},
requirements: options.requirements ?? [],
sharedWorkerMode: options.sharedWorkerMode ?? false
});
worker_proxy.addEventListener("initialization-error", (event) => {
showError((event as CustomEvent).detail);
});
// Internally, the execution of `runPythonCode()` or `runPythonFile()` is queued
// and its promise will be resolved after the Pyodide is loaded and the worker initialization is done
// (see the await in the `onmessage` callback in the webworker code)
// So we don't await this promise because we want to mount the `Index` immediately and start the app initialization asynchronously.
if (options.code != null) {
worker_proxy.runPythonCode(options.code).catch(showError);
} else if (options.entrypoint != null) {
worker_proxy.runPythonFile(options.entrypoint).catch(showError);
} else {
throw new Error("Either code or entrypoint must be provided.");
}
mount_prebuilt_css(document.head);
class LiteClient extends Client {
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
return wasm_proxied_fetch(worker_proxy, input, init);
}
stream(url: URL): EventSource {
return wasm_proxied_stream_factory(worker_proxy, url);
}
}
const overridden_mount_css: typeof mount_css = async (url, target) => {
return wasm_proxied_mount_css(worker_proxy, url, target);
};
let app: SvelteComponent;
let app_props: any;
let loaded = false;
function showError(error: Error): void {
if (app != null) {
app.$destroy();
}
if (options.playground) {
app = new Playground({
target: options.target,
props: {
...app_props,
code: options.code,
error_display: {
is_embed: !options.isEmbed,
error
},
loaded: true
}
});
app.$on("code", (event) => {
options.code = event.detail.code;
loaded = true;
worker_proxy
.runPythonCode(event.detail.code)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
});
} else {
app = new ErrorDisplay({
target: options.target,
props: {
is_embed: !options.isEmbed,
error
}
});
}
}
function launchNewApp(): Promise<void> {
if (app != null) {
app.$destroy();
}
app_props = {
// embed source
space: null,
src: null,
host: null,
// embed info
info: options.info,
container: options.container,
is_embed: options.isEmbed,
initial_height: options.initialHeight ?? "300px",
initial_height: options.initialHeight ?? "300px", // default: 300px
eager: options.eager,
// gradio meta info
version: GRADIO_VERSION,
theme_mode: options.themeMode,
// misc global behaviour
autoscroll: options.autoScroll,
control_page_title: options.controlPageTitle,
// for gradio docs
// TODO: Remove -- I think this is just for autoscroll behaviour, app vs embeds
app_mode: options.appMode,
// For Wasm mode
files: options.files,
requirements: options.requirements,
code: options.code,
entrypoint: options.entrypoint,
sharedWorkerMode: options.sharedWorkerMode,
worker_proxy,
Client: LiteClient,
mount_css: overridden_mount_css,
// For playground
playground: options.playground,
layout: options.layout
};
if (options.playground) {
app = new Playground({
target: options.target,
props: {
...app_props,
code: options.code,
error_display: null,
loaded: loaded
}
});
app.$on("code", (event) => {
options.code = event.detail.code;
loaded = true;
worker_proxy
.runPythonCode(event.detail.code)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
});
} else {
app = new Index({
target: options.target,
props: app_props
});
}
});
return new Promise((resolve) => {
app.$on("loaded", () => {
resolve();
});
});
}
launchNewApp();
return {
run_code: app.run_code,
run_file: app.run_file,
write: app.write,
rename: app.rename,
unlink: app.unlink,
install: app.install,
run_code: (code: string) => {
return worker_proxy
.runPythonCode(code)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
},
run_file: (path: string) => {
return worker_proxy
.runPythonFile(path)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
},
write: (path, data, opts) => {
return worker_proxy
.writeFile(path, data, opts)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
},
rename: (old_path, new_path) => {
return worker_proxy
.renameFile(old_path, new_path)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
},
unlink: (path) => {
return worker_proxy
.unlink(path)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
},
install: (requirements) => {
return worker_proxy
.install(requirements)
.then(launchNewApp)
.catch((e) => {
showError(e);
throw e;
});
},
unmount() {
app.$destroy();
worker_proxy.terminate();
}
};
}
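For orientation, here is a minimal sketch of how an embedding page might drive the controller that `create()` now returns. The method names and their relaunch-on-resolve behaviour mirror this diff; the import path, mount-point selector, and the exact required `Options` fields are assumptions, not part of this commit.

```ts
// Sketch only: controller methods follow the diff above; the import path and
// the required option fields are assumptions.
import { create } from "./index"; // hypothetical entry module exporting create()

async function demo(): Promise<void> {
	const controller = create({
		target: document.querySelector("#lite-root") as HTMLElement, // hypothetical mount point
		code: "import gradio as gr\n\ngr.Interface(lambda x: x, 'textbox', 'textbox').launch()",
		playground: false
	} as any); // the full Options type is not shown in this diff

	// Each call proxies to the Pyodide worker and relaunches the app once it resolves.
	await controller.install(["numpy"]);
	await controller.run_code("import numpy as np\nprint(np.arange(3))");

	// Destroys the mounted Svelte app and terminates the worker.
	controller.unmount();
}

demo().catch(console.error);
```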

@ -37,7 +37,7 @@
export let sanitize_html = true;
export let bubble_full_width = true;
export let layout: "bubble" | "panel" = "bubble";
export let msg_format: "tuples" | "messages" = "tuples";
export let type: "tuples" | "messages" = "tuples";
export let render_markdown = true;
export let line_breaks = true;
export let latex_delimiters: {
@ -58,7 +58,7 @@
let _value: NormalisedMessage[] | null = [];
$: _value =
msg_format === "tuples"
type === "tuples"
? normalise_tuples(value as TupleFormat, root)
: normalise_messages(value as Message[], root);
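As a reminder of what the reactive block above normalises, a sketch of the two accepted `value` shapes follows; the exact `TupleFormat` and `Message` type definitions live in the chatbot's shared types and are only inferred here from `normalise_tuples` / `normalise_messages`.

```ts
// Sketch only: shapes inferred from the normalisers referenced above.
const tuples_value = [
	["What is 2 + 2?", "4"] // [user message, assistant message] pairs when type="tuples"
];

const messages_value = [
	{ role: "user", content: "What is 2 + 2?" },
	{ role: "assistant", content: "4" } // role/content records when type="messages"
];
```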

@ -1,12 +1,7 @@
<script lang="ts">
//@ts-nocheck
import Plotly from "plotly.js-dist-min";
import {
afterUpdate,
onMount,
onDestroy,
createEventDispatcher
} from "svelte";
import { afterUpdate, createEventDispatcher } from "svelte";
export let value;
export let target;
@ -15,7 +10,6 @@
let plot_div;
let plotly_global_style;
let resizeObserver;
const dispatch = createEventDispatcher<{ load: undefined }>();
@ -30,23 +24,6 @@
}
}
onMount(() => {
resizeObserver = new ResizeObserver(() => {
Plotly.Plots.resize(plot_div);
});
if (plot_div && plot_div.parentElement) {
resizeObserver.observe(plot_div.parentElement);
}
});
onDestroy(() => {
window.removeEventListener("resize", updatePlot);
if (resizeObserver && plot_div && plot_div.parentElement) {
resizeObserver.unobserve(plot_div.parentElement);
}
});
afterUpdate(async () => {
load_plotly_css();

@ -47,7 +47,7 @@ class TestChatbot:
"proxy_url": None,
"_selectable": False,
"key": None,
"msg_format": "tuples",
"type": "tuples",
"latex_delimiters": [{"display": True, "left": "$$", "right": "$$"}],
"likeable": False,
"rtl": False,

@ -247,41 +247,41 @@ class TestAPI:
assert len(api_info["unnamed_endpoints"]) == 0
assert "/chat" in api_info["named_endpoints"]
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_streaming_api(self, msg_format, connect):
chatbot = gr.ChatInterface(stream, msg_format=msg_format).queue()
@pytest.mark.parametrize("type", ["tuples", "messages"])
def test_streaming_api(self, type, connect):
chatbot = gr.ChatInterface(stream, type=type).queue()
with connect(chatbot) as client:
job = client.submit("hello")
wait([job])
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_streaming_api_async(self, msg_format, connect):
chatbot = gr.ChatInterface(async_stream, msg_format=msg_format).queue()
@pytest.mark.parametrize("type", ["tuples", "messages"])
def test_streaming_api_async(self, type, connect):
chatbot = gr.ChatInterface(async_stream, type=type).queue()
with connect(chatbot) as client:
job = client.submit("hello")
wait([job])
assert job.outputs() == ["h", "he", "hel", "hell", "hello"]
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_non_streaming_api(self, msg_format, connect):
chatbot = gr.ChatInterface(double, msg_format=msg_format)
@pytest.mark.parametrize("type", ["tuples", "messages"])
def test_non_streaming_api(self, type, connect):
chatbot = gr.ChatInterface(double, type=type)
with connect(chatbot) as client:
result = client.predict("hello")
assert result == "hello hello"
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_non_streaming_api_async(self, msg_format, connect):
chatbot = gr.ChatInterface(async_greet, msg_format=msg_format)
@pytest.mark.parametrize("type", ["tuples", "messages"])
def test_non_streaming_api_async(self, type, connect):
chatbot = gr.ChatInterface(async_greet, type=type)
with connect(chatbot) as client:
result = client.predict("gradio")
assert result == "hi, gradio"
@pytest.mark.parametrize("msg_format", ["tuples", "messages"])
def test_streaming_api_with_additional_inputs(self, msg_format, connect):
@pytest.mark.parametrize("type", ["tuples", "messages"])
def test_streaming_api_with_additional_inputs(self, type, connect):
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
msg_format=msg_format,
type=type,
additional_inputs=["textbox", "slider"],
).queue()
with connect(chatbot) as client: