Adds additional_inputs to gr.ChatInterface (#4985)

* adding additional inputs

* add param

* guide

* add is_rendered

* add demo

* fixing examples

* add test

* guide

* add changeset

* Fix typos

* Remove label

* Revert "Remove label"

This reverts commit 10042856151b0ff79412613e6cb176dfc8642117.

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
Abubakar Abid 2023-07-24 18:55:47 +03:00 committed by GitHub
parent 4b0e98e40a
commit b74f845303
9 changed files with 251 additions and 53 deletions

View File

@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Adds `additional_inputs` to `gr.ChatInterface`

View File

@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@@ -0,0 +1,18 @@
import gradio as gr
import time

def echo(message, history, system_prompt, tokens):
    response = f"System prompt: {system_prompt}\n Message: {message}."
    for i in range(min(len(response), int(tokens))):
        time.sleep(0.05)
        yield response[: i+1]

demo = gr.ChatInterface(echo,
        additional_inputs=[
            gr.Textbox("You are helpful AI.", label="System Prompt"),
            gr.Slider(10, 100)
        ]
        )

if __name__ == "__main__":
    demo.queue().launch()

View File

@@ -110,6 +110,7 @@ class Block:
self.share_token = secrets.token_urlsafe(32)
self._skip_init_processing = _skip_init_processing
self.parent: BlockContext | None = None
self.is_rendered: bool = False
if render:
self.render()
@@ -127,6 +128,7 @@ class Block:
Context.block.add(self)
if Context.root_block is not None:
Context.root_block.blocks[self._id] = self
self.is_rendered = True
if isinstance(self, components.IOComponent):
Context.root_block.temp_file_sets.append(self.temp_files)
return self
@@ -144,6 +146,7 @@ class Block:
if Context.root_block is not None:
try:
del Context.root_block.blocks[self._id]
self.is_rendered = False
except KeyError:
pass
return self
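The new `is_rendered` flag lets callers check whether a component has actually been attached to a Blocks context; `ChatInterface` uses it below to decide whether `additional_inputs` components still need to be rendered in its accordion. A minimal sketch of the behavior, mirroring the new tests added in this commit:

```python
import gradio as gr

t = gr.Textbox()      # instantiated outside any Blocks: never attached
assert not t.is_rendered

t = gr.Textbox()
with gr.Blocks():
    t.render()        # explicitly rendered inside a Blocks context
assert t.is_rendered

t = gr.Textbox()
with gr.Blocks():
    t.render()
    t.unrender()      # unrendering inside the context resets the flag
assert not t.is_rendered
```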

View File

@@ -6,22 +6,24 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
from __future__ import annotations
import inspect
import warnings
from typing import Callable, Generator
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
from gradio.blocks import Blocks
from gradio.components import (
Button,
Chatbot,
IOComponent,
Markdown,
State,
Textbox,
get_component_instance,
)
from gradio.events import Dependency, EventListenerMethod
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.layouts import Column, Group, Row
from gradio.layouts import Accordion, Column, Group, Row
from gradio.themes import ThemeClass as Theme
set_documentation_group("chatinterface")
@@ -53,6 +55,8 @@ class ChatInterface(Blocks):
*,
chatbot: Chatbot | None = None,
textbox: Textbox | None = None,
additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
additional_inputs_accordion_name: str = "Additional Inputs",
examples: list[str] | None = None,
cache_examples: bool | None = None,
title: str | None = None,
@@ -65,12 +69,15 @@
retry_btn: str | None | Button = "🔄 Retry",
undo_btn: str | None | Button = "↩️ Undo",
clear_btn: str | None | Button = "🗑️ Clear",
autofocus: bool = True,
):
"""
Parameters:
fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
@@ -83,6 +90,7 @@ class ChatInterface(Blocks):
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
autofocus: If True, autofocuses to the textbox when the page loads.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@@ -91,12 +99,6 @@
title=title or "Gradio",
theme=theme,
)
if len(inspect.signature(fn).parameters) != 2:
warnings.warn(
"The function to ChatInterface should take two inputs (message, history) and return a single string response.",
UserWarning,
)
self.fn = fn
self.is_generator = inspect.isgeneratorfunction(self.fn)
self.examples = examples
@@ -106,6 +108,16 @@
self.cache_examples = cache_examples or False
self.buttons: list[Button] = []
if additional_inputs:
if not isinstance(additional_inputs, list):
additional_inputs = [additional_inputs]
self.additional_inputs = [
get_component_instance(i, render=False) for i in additional_inputs # type: ignore
]
else:
self.additional_inputs = []
self.additional_inputs_accordion_name = additional_inputs_accordion_name
with self:
if title:
Markdown(
@@ -130,9 +142,10 @@
self.textbox = Textbox(
container=False,
show_label=False,
label="Message",
placeholder="Type a message...",
scale=7,
autofocus=True,
autofocus=autofocus,
)
if submit_btn:
if isinstance(submit_btn, Button):
@@ -199,12 +212,24 @@
self.examples_handler = Examples(
examples=examples,
inputs=self.textbox,
inputs=[self.textbox] + self.additional_inputs,
outputs=self.chatbot,
fn=examples_fn,
cache_examples=self.cache_examples,
)
any_unrendered_inputs = any(
not inp.is_rendered for inp in self.additional_inputs
)
if self.additional_inputs and any_unrendered_inputs:
with Accordion(self.additional_inputs_accordion_name, open=False):
for input_component in self.additional_inputs:
if not input_component.is_rendered:
input_component.render()
# The example caching must happen after the input components have rendered
if cache_examples:
client_utils.synchronize_async(self.examples_handler.cache)
self.saved_input = State()
self.chatbot_state = State([])
@@ -230,7 +255,7 @@
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
)
@@ -255,7 +280,7 @@
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
)
@@ -280,7 +305,7 @@
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state],
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
)
@@ -358,7 +383,7 @@
self.fake_api_btn.click(
api_fn,
[self.textbox, self.chatbot_state],
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state],
api_name="chat",
)
@@ -373,18 +398,26 @@
return history, history
def _submit_fn(
self, message: str, history_with_input: list[list[str | None]]
self,
message: str,
history_with_input: list[list[str | None]],
*args,
**kwargs,
) -> tuple[list[list[str | None]], list[list[str | None]]]:
history = history_with_input[:-1]
response = self.fn(message, history)
response = self.fn(message, history, *args, **kwargs)
history.append([message, response])
return history, history
def _stream_fn(
self, message: str, history_with_input: list[list[str | None]]
self,
message: str,
history_with_input: list[list[str | None]],
*args,
**kwargs,
) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
history = history_with_input[:-1]
generator = self.fn(message, history)
generator = self.fn(message, history, *args, **kwargs)
try:
first_response = next(generator)
update = history + [[message, first_response]]
@@ -397,16 +430,16 @@
yield update, update
def _api_submit_fn(
self, message: str, history: list[list[str | None]]
self, message: str, history: list[list[str | None]], *args, **kwargs
) -> tuple[str, list[list[str | None]]]:
response = self.fn(message, history)
history.append([message, response])
return response, history
def _api_stream_fn(
self, message: str, history: list[list[str | None]]
self, message: str, history: list[list[str | None]], *args, **kwargs
) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
generator = self.fn(message, history)
generator = self.fn(message, history, *args, **kwargs)
try:
first_response = next(generator)
yield first_response, history + [[message, first_response]]
@@ -415,13 +448,16 @@
for response in generator:
yield response, history + [[message, response]]
def _examples_fn(self, message: str) -> list[list[str | None]]:
return [[message, self.fn(message, [])]]
def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
return [[message, self.fn(message, [], *args, **kwargs)]]
def _examples_stream_fn(
self, message: str
self,
message: str,
*args,
**kwargs,
) -> Generator[list[list[str | None]], None, None]:
for response in self.fn(message, []):
for response in self.fn(message, [], *args, **kwargs):
yield [[message, response]]
def _delete_prev_fn(

View File

@@ -195,7 +195,7 @@ class Examples:
self.non_none_examples = non_none_examples
self.inputs = inputs
self.inputs_with_examples = inputs_with_examples
self.outputs = outputs
self.outputs = outputs or []
self.fn = fn
self.cache_examples = cache_examples
self._api_mode = _api_mode
@@ -250,23 +250,14 @@
component to hold the examples"""
async def load_example(example_id):
if self.cache_examples:
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
else:
processed_example = self.non_none_processed_examples[example_id]
processed_example = self.non_none_processed_examples[example_id]
return utils.resolve_singleton(processed_example)
if Context.root_block:
if self.cache_examples and self.outputs:
targets = self.inputs_with_examples + self.outputs
else:
targets = self.inputs_with_examples
load_input_event = self.dataset.click(
self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=targets, # type: ignore
outputs=self.inputs_with_examples, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
@@ -275,7 +266,7 @@
if self.run_on_click and not self.cache_examples:
if self.fn is None:
raise ValueError("Cannot run_on_click if no function is provided")
load_input_event.then(
self.load_input_event.then(
self.fn,
inputs=self.inputs, # type: ignore
outputs=self.outputs, # type: ignore
@@ -301,25 +292,24 @@
if inspect.isgeneratorfunction(self.fn):
def get_final_item(args): # type: ignore
def get_final_item(*args): # type: ignore
x = None
for x in self.fn(args): # noqa: B007 # type: ignore
for x in self.fn(*args): # noqa: B007 # type: ignore
pass
return x
fn = get_final_item
elif inspect.isasyncgenfunction(self.fn):
async def get_final_item(args):
async def get_final_item(*args):
x = None
async for x in self.fn(args): # noqa: B007 # type: ignore
async for x in self.fn(*args): # noqa: B007 # type: ignore
pass
return x
fn = get_final_item
else:
fn = self.fn
# create a fake dependency to process the examples and get the predictions
dependency, fn_index = Context.root_block.set_event_trigger(
event_name="fake_event",
@@ -352,6 +342,30 @@
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)
# Remove the original load_input_event and replace it with one that
# also populates the input. We do it this way to allow the cache()
# method to be called independently of the create() method
index = Context.root_block.dependencies.index(self.load_input_event)
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)
async def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)
self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
api_name=self.api_name, # type: ignore
)
print("Caching complete\n")
async def load_from_cache(self, example_id: int) -> list[Any]:

View File

@@ -87,7 +87,7 @@ def slow_echo(message, history):
gr.ChatInterface(slow_echo).queue().launch()
```
Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions.
Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions. While the response is streaming, the "Submit" button turns into a "Stop" button that can be used to stop the generator function. You can customize the appearance and behavior of the "Stop" button using the `stop_btn` parameter.
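For instance, a minimal sketch that relabels the stop button (assuming, as with the other button parameters, that `stop_btn` accepts a string label, `None`, or a `Button` instance):

```python
gr.ChatInterface(slow_echo, stop_btn="Stop Generating").queue().launch()
```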
## Customizing your chatbot
@@ -125,11 +125,44 @@ gr.ChatInterface(
).launch()
```
## Additional Inputs
You may want to add additional parameters to your chatbot and expose them to your users through the Chatbot UI. For example, suppose you want to add a textbox for a system prompt, or a slider that sets the number of tokens in the chatbot's response. The `ChatInterface` class supports an `additional_inputs` parameter which can be used to add additional input components.
The `additional_inputs` parameter accepts a component or a list of components. You can pass the component instances directly, or use their string shortcuts (e.g. `"textbox"` instead of `gr.Textbox()`). If you pass in component instances, and they have *not* already been rendered, then the components will appear underneath the chatbot (and any examples) within a `gr.Accordion()`. You can set the label of this accordion using the `additional_inputs_accordion_name` parameter.
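For instance, here is a minimal sketch that uses string shortcuts (the function body and return value are illustrative):

```python
import gradio as gr

def predict(message, history, system_prompt, max_tokens):
    return f"{system_prompt}: {message}"

# "textbox" and "slider" are shorthand for gr.Textbox() and gr.Slider()
# with default settings; since neither is rendered yet, both appear
# under the chatbot inside an accordion
gr.ChatInterface(predict, additional_inputs=["textbox", "slider"]).launch()
```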
Here's a complete example:
$code_chatinterface_system_prompt
If the components you pass into the `additional_inputs` have already been rendered in a parent `gr.Blocks()`, then they will *not* be re-rendered in the accordion. This provides flexibility in deciding where to lay out the input components. In the example below, we position the `gr.Textbox()` on top of the Chatbot UI, while keeping the slider underneath.
```python
import gradio as gr
import time
def echo(message, history, system_prompt, tokens):
    response = f"System prompt: {system_prompt}\n Message: {message}."
    for i in range(min(len(response), int(tokens))):
        time.sleep(0.05)
        yield response[: i+1]

with gr.Blocks() as demo:
    system_prompt = gr.Textbox("You are helpful AI.", label="System Prompt")
    slider = gr.Slider(10, 100, render=False)

    gr.ChatInterface(
        echo, additional_inputs=[system_prompt, slider]
    )

demo.queue().launch()
```
If you need to create something even more custom, then it's best to construct the chatbot UI using the low-level `gr.Blocks()` API. We have [a dedicated guide for that here](/guides/creating-a-custom-chatbot-with-blocks).
## Using your chatbot via an API
Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message, and will return the response, internally keeping track of the messages sent so far.
Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message (and potentially additional inputs if you have set any using the `additional_inputs` parameter), and will return the response, internally keeping track of the messages sent so far.
[](https://github.com/gradio-app/gradio/assets/1778297/7b10d6db-6476-4e2e-bebd-ecda802c3b8f)
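As a sketch, here is how such an endpoint might be queried with the `gradio_client` library (the Space URL is hypothetical, and the trailing arguments correspond to the `additional_inputs` of the system-prompt demo above):

```python
from gradio_client import Client

# Hypothetical Space hosting the chatinterface_system_prompt demo
client = Client("https://your-username-chat-demo.hf.space/")

# Additional inputs are passed positionally after the message:
# here, the system prompt string and the token slider value
response = client.predict("hello", "You are helpful AI.", 50, api_name="/chat")
print(response)
```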
@@ -251,4 +284,4 @@ def predict(message, history):
gr.ChatInterface(predict).queue().launch()
```
With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building more custom Chabot UI, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API.
With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building even more custom Chatbot applications, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API.

View File

@@ -1221,6 +1221,37 @@ class TestRender:
io3 = io2.render()
assert io2 == io3
def test_is_rendered(self):
t = gr.Textbox()
with gr.Blocks():
pass
assert not t.is_rendered
t = gr.Textbox()
with gr.Blocks():
t.render()
assert t.is_rendered
t = gr.Textbox()
with gr.Blocks():
t.render()
t.unrender()
assert not t.is_rendered
with gr.Blocks():
t = gr.Textbox()
assert t.is_rendered
with gr.Blocks():
t = gr.Textbox()
with gr.Blocks():
pass
assert t.is_rendered
t = gr.Textbox()
gr.Interface(lambda x: x, "textbox", t)
assert t.is_rendered
def test_no_error_if_state_rendered_multiple_times(self):
state = gr.State("")
gr.TabbedInterface(

View File

@@ -1,8 +1,10 @@
import tempfile
from concurrent.futures import wait
import pytest
import gradio as gr
from gradio import helpers
def invalid_fn(message):
@@ -22,15 +24,17 @@ def count(message, history):
return str(len(history))
def echo_system_prompt_plus_message(message, history, system_prompt, tokens):
response = f"{system_prompt} {message}"
for i in range(min(len(response), int(tokens))):
yield response[: i + 1]
class TestInit:
def test_no_fn(self):
with pytest.raises(TypeError):
gr.ChatInterface()
def test_invalid_fn_inputs(self):
with pytest.warns(UserWarning):
gr.ChatInterface(invalid_fn)
def test_configuring_buttons(self):
chatbot = gr.ChatInterface(double, submit_btn=None, retry_btn=None)
assert chatbot.submit_btn is None
@@ -74,7 +78,8 @@ class TestInit:
assert prediction_hi[0][0] == ["hi", "hi hi"]
@pytest.mark.asyncio
async def test_example_caching_with_streaming(self):
async def test_example_caching_with_streaming(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
@@ -83,6 +88,40 @@
assert prediction_hello[0][0] == ["hello", "hello"]
assert prediction_hi[0][0] == ["hi", "hi"]
@pytest.mark.asyncio
async def test_example_caching_with_additional_inputs(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "robot hello"]
assert prediction_hi[0][0] == ["hi", "ro"]
@pytest.mark.asyncio
async def test_example_caching_with_additional_inputs_already_rendered(
self, monkeypatch
):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
with gr.Blocks():
with gr.Accordion("Inputs"):
text = gr.Textbox()
slider = gr.Slider()
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=[text, slider],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "robot hello"]
assert prediction_hi[0][0] == ["hi", "ro"]
class TestAPI:
def test_get_api_info(self):
@@ -104,3 +143,21 @@
with connect(chatbot) as client:
result = client.predict("hello")
assert result == "hello hello"
def test_streaming_api_with_additional_inputs(self, connect):
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
).queue()
with connect(chatbot) as client:
job = client.submit("hello", "robot", 7)
wait([job])
assert job.outputs() == [
"r",
"ro",
"rob",
"robo",
"robot",
"robot ",
"robot h",
]