diff --git a/.changeset/witty-pets-rhyme.md b/.changeset/witty-pets-rhyme.md new file mode 100644 index 0000000000..eaf2c16692 --- /dev/null +++ b/.changeset/witty-pets-rhyme.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Adds `additional_inputs` to `gr.ChatInterface` diff --git a/demo/chatinterface_system_prompt/run.ipynb b/demo/chatinterface_system_prompt/run.ipynb new file mode 100644 index 0000000000..c92e7ea5d9 --- /dev/null +++ b/demo/chatinterface_system_prompt/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatinterface_system_prompt/run.py b/demo/chatinterface_system_prompt/run.py new file mode 100644 index 0000000000..e8b1422c4f --- /dev/null +++ b/demo/chatinterface_system_prompt/run.py @@ -0,0 +1,18 @@ +import gradio as gr +import time + +def echo(message, history, system_prompt, tokens): + response = f"System prompt: {system_prompt}\n Message: {message}." 
+ for i in range(min(len(response), int(tokens))): + time.sleep(0.05) + yield response[: i+1] + +demo = gr.ChatInterface(echo, + additional_inputs=[ + gr.Textbox("You are helpful AI.", label="System Prompt"), + gr.Slider(10, 100) + ] + ) + +if __name__ == "__main__": + demo.queue().launch() \ No newline at end of file diff --git a/gradio/blocks.py b/gradio/blocks.py index f441d39d9b..8386ecf21f 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -110,6 +110,7 @@ class Block: self.share_token = secrets.token_urlsafe(32) self._skip_init_processing = _skip_init_processing self.parent: BlockContext | None = None + self.is_rendered: bool = False if render: self.render() @@ -127,6 +128,7 @@ class Block: Context.block.add(self) if Context.root_block is not None: Context.root_block.blocks[self._id] = self + self.is_rendered = True if isinstance(self, components.IOComponent): Context.root_block.temp_file_sets.append(self.temp_files) return self @@ -144,6 +146,7 @@ class Block: if Context.root_block is not None: try: del Context.root_block.blocks[self._id] + self.is_rendered = False except KeyError: pass return self diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 0c9fc3b85c..aaee1c820e 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -6,22 +6,24 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat from __future__ import annotations import inspect -import warnings from typing import Callable, Generator +from gradio_client import utils as client_utils from gradio_client.documentation import document, set_documentation_group from gradio.blocks import Blocks from gradio.components import ( Button, Chatbot, + IOComponent, Markdown, State, Textbox, + get_component_instance, ) from gradio.events import Dependency, EventListenerMethod from gradio.helpers import create_examples as Examples # noqa: N812 -from gradio.layouts import Column, Group, Row +from gradio.layouts import Accordion, Column, Group, Row from gradio.themes import ThemeClass as Theme set_documentation_group("chatinterface") @@ -53,6 +55,8 @@ class ChatInterface(Blocks): *, chatbot: Chatbot | None = None, textbox: Textbox | None = None, + additional_inputs: str | IOComponent | list[str | IOComponent] | None = None, + additional_inputs_accordion_name: str = "Additional Inputs", examples: list[str] | None = None, cache_examples: bool | None = None, title: str | None = None, @@ -65,12 +69,15 @@ class ChatInterface(Blocks): retry_btn: str | None | Button = "🔄 Retry", undo_btn: str | None | Button = "↩ī¸ Undo", clear_btn: str | None | Button = "🗑ī¸ Clear", + autofocus: bool = True, ): """ Parameters: fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format. chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created. textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created. + additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. 
If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion. + additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided. examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input. cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False. title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window. @@ -83,6 +90,7 @@ class ChatInterface(Blocks): retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used. undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used. clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used. + autofocus: If True, autofocuses to the textbox when the page loads. """ super().__init__( analytics_enabled=analytics_enabled, @@ -91,12 +99,6 @@ class ChatInterface(Blocks): title=title or "Gradio", theme=theme, ) - if len(inspect.signature(fn).parameters) != 2: - warnings.warn( - "The function to ChatInterface should take two inputs (message, history) and return a single string response.", - UserWarning, - ) - self.fn = fn self.is_generator = inspect.isgeneratorfunction(self.fn) self.examples = examples @@ -106,6 +108,16 @@ class ChatInterface(Blocks): self.cache_examples = cache_examples or False self.buttons: list[Button] = [] + if additional_inputs: + if not isinstance(additional_inputs, list): + additional_inputs = [additional_inputs] + self.additional_inputs = [ + get_component_instance(i, render=False) for i in additional_inputs # type: ignore + ] + else: + self.additional_inputs = [] + self.additional_inputs_accordion_name = additional_inputs_accordion_name + with self: if title: Markdown( @@ -130,9 +142,10 @@ class ChatInterface(Blocks): self.textbox = Textbox( container=False, show_label=False, + label="Message", placeholder="Type a message...", scale=7, - autofocus=True, + autofocus=autofocus, ) if submit_btn: if isinstance(submit_btn, Button): @@ -199,12 +212,24 @@ class ChatInterface(Blocks): self.examples_handler = Examples( examples=examples, - inputs=self.textbox, + inputs=[self.textbox] + self.additional_inputs, outputs=self.chatbot, fn=examples_fn, - cache_examples=self.cache_examples, ) + any_unrendered_inputs = any( + not inp.is_rendered for inp in self.additional_inputs + ) + if self.additional_inputs and any_unrendered_inputs: + with Accordion(self.additional_inputs_accordion_name, open=False): + for input_component in self.additional_inputs: + if not input_component.is_rendered: + input_component.render() + + # The example caching must happen after the input components have rendered + if cache_examples: + client_utils.synchronize_async(self.examples_handler.cache) + self.saved_input = State() self.chatbot_state = State([]) @@ -230,7 +255,7 @@ class ChatInterface(Blocks): ) .then( submit_fn, - [self.saved_input, self.chatbot_state], + [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, ) @@ -255,7 +280,7 @@ class ChatInterface(Blocks): ) .then( 
submit_fn, - [self.saved_input, self.chatbot_state], + [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, ) @@ -280,7 +305,7 @@ class ChatInterface(Blocks): ) .then( submit_fn, - [self.saved_input, self.chatbot_state], + [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, ) @@ -358,7 +383,7 @@ class ChatInterface(Blocks): self.fake_api_btn.click( api_fn, - [self.textbox, self.chatbot_state], + [self.textbox, self.chatbot_state] + self.additional_inputs, [self.textbox, self.chatbot_state], api_name="chat", ) @@ -373,18 +398,26 @@ class ChatInterface(Blocks): return history, history def _submit_fn( - self, message: str, history_with_input: list[list[str | None]] + self, + message: str, + history_with_input: list[list[str | None]], + *args, + **kwargs, ) -> tuple[list[list[str | None]], list[list[str | None]]]: history = history_with_input[:-1] - response = self.fn(message, history) + response = self.fn(message, history, *args, **kwargs) history.append([message, response]) return history, history def _stream_fn( - self, message: str, history_with_input: list[list[str | None]] + self, + message: str, + history_with_input: list[list[str | None]], + *args, + **kwargs, ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]: history = history_with_input[:-1] - generator = self.fn(message, history) + generator = self.fn(message, history, *args, **kwargs) try: first_response = next(generator) update = history + [[message, first_response]] @@ -397,16 +430,16 @@ class ChatInterface(Blocks): yield update, update def _api_submit_fn( - self, message: str, history: list[list[str | None]] + self, message: str, history: list[list[str | None]], *args, **kwargs ) -> tuple[str, list[list[str | None]]]: - response = self.fn(message, history) + response = self.fn(message, history, *args, **kwargs) history.append([message, response]) return response, history def _api_stream_fn( - self, message: str, history: list[list[str | None]] + self, message: str, history: list[list[str | None]], *args, **kwargs ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]: - generator = self.fn(message, history) + generator = self.fn(message, history, *args, **kwargs) try: first_response = next(generator) yield first_response, history + [[message, first_response]] @@ -415,13 +448,16 @@ class ChatInterface(Blocks): for response in generator: yield response, history + [[message, response]] - def _examples_fn(self, message: str) -> list[list[str | None]]: - return [[message, self.fn(message, [])]] + def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]: + return [[message, self.fn(message, [], *args, **kwargs)]] def _examples_stream_fn( - self, message: str + self, + message: str, + *args, + **kwargs, ) -> Generator[list[list[str | None]], None, None]: - for response in self.fn(message, []): + for response in self.fn(message, [], *args, **kwargs): yield [[message, response]] def _delete_prev_fn( diff --git a/gradio/helpers.py b/gradio/helpers.py index 660c89de70..188bd6059b 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -195,7 +195,7 @@ class Examples: self.non_none_examples = non_none_examples self.inputs = inputs self.inputs_with_examples = inputs_with_examples - self.outputs = outputs + self.outputs = outputs or [] self.fn = fn self.cache_examples = cache_examples self._api_mode = _api_mode @@ -250,23 +250,14 @@ class Examples: component to hold the examples""" async def
load_example(example_id): - if self.cache_examples: - processed_example = self.non_none_processed_examples[ - example_id - ] + await self.load_from_cache(example_id) - else: - processed_example = self.non_none_processed_examples[example_id] + processed_example = self.non_none_processed_examples[example_id] return utils.resolve_singleton(processed_example) if Context.root_block: - if self.cache_examples and self.outputs: - targets = self.inputs_with_examples + self.outputs - else: - targets = self.inputs_with_examples - load_input_event = self.dataset.click( + self.load_input_event = self.dataset.click( load_example, inputs=[self.dataset], - outputs=targets, # type: ignore + outputs=self.inputs_with_examples, # type: ignore show_progress="hidden", postprocess=False, queue=False, @@ -275,7 +266,7 @@ class Examples: if self.run_on_click and not self.cache_examples: if self.fn is None: raise ValueError("Cannot run_on_click if no function is provided") - load_input_event.then( + self.load_input_event.then( self.fn, inputs=self.inputs, # type: ignore outputs=self.outputs, # type: ignore @@ -301,25 +292,24 @@ class Examples: if inspect.isgeneratorfunction(self.fn): - def get_final_item(args): # type: ignore + def get_final_item(*args): # type: ignore x = None - for x in self.fn(args): # noqa: B007 # type: ignore + for x in self.fn(*args): # noqa: B007 # type: ignore pass return x fn = get_final_item elif inspect.isasyncgenfunction(self.fn): - async def get_final_item(args): + async def get_final_item(*args): x = None - async for x in self.fn(args): # noqa: B007 # type: ignore + async for x in self.fn(*args): # noqa: B007 # type: ignore pass return x fn = get_final_item else: fn = self.fn - # create a fake dependency to process the examples and get the predictions dependency, fn_index = Context.root_block.set_event_trigger( event_name="fake_event", @@ -352,6 +342,30 @@ class Examples: # Remove the "fake_event" to prevent bugs in loading interfaces from spaces Context.root_block.dependencies.remove(dependency) Context.root_block.fns.pop(fn_index) + + # Remove the original load_input_event and replace it with one that + # also populates the input. We do it this way to allow the cache() + # method to be called independently of the create() method + index = Context.root_block.dependencies.index(self.load_input_event) + Context.root_block.dependencies.pop(index) + Context.root_block.fns.pop(index) + + async def load_example(example_id): + processed_example = self.non_none_processed_examples[ + example_id + ] + await self.load_from_cache(example_id) + return utils.resolve_singleton(processed_example) + + self.load_input_event = self.dataset.click( + load_example, + inputs=[self.dataset], + outputs=self.inputs_with_examples + self.outputs, # type: ignore + show_progress="hidden", + postprocess=False, + queue=False, + api_name=self.api_name, # type: ignore + ) + print("Caching complete\n") async def load_from_cache(self, example_id: int) -> list[Any]: diff --git a/guides/04_chatbots/01_creating-a-chatbot-fast.md b/guides/04_chatbots/01_creating-a-chatbot-fast.md index 0b307537f9..353dcd6d5c 100644 --- a/guides/04_chatbots/01_creating-a-chatbot-fast.md +++ b/guides/04_chatbots/01_creating-a-chatbot-fast.md @@ -87,7 +87,7 @@ def slow_echo(message, history): gr.ChatInterface(slow_echo).queue().launch() ``` -Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions.
+Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions. While the response is streaming, the "Submit" button turns into a "Stop" button that can be used to stop the generator function. You can customize the appearance and behavior of the "Stop" button using the `stop_btn` parameter. ## Customizing your chatbot @@ -125,11 +125,44 @@ gr.ChatInterface( ).launch() ``` +## Additional Inputs + +You may want to add additional parameters to your chatbot and expose them to your users through the Chatbot UI. For example, suppose you want to add a textbox for a system prompt, or a slider that sets the number of tokens in the chatbot's response. The `ChatInterface` class supports an `additional_inputs` parameter which can be used to add additional input components. + +The `additional_inputs` parameter accepts a component or a list of components. You can pass the component instances directly, or use their string shortcuts (e.g. `"textbox"` instead of `gr.Textbox()`). If you pass in component instances, and they have *not* already been rendered, then the components will appear underneath the chatbot (and any examples) within a `gr.Accordion()`. You can set the label of this accordion using the `additional_inputs_accordion_name` parameter. + +Here's a complete example: + +$code_chatinterface_system_prompt + +If the components you pass into the `additional_inputs` have already been rendered in a parent `gr.Blocks()`, then they will *not* be re-rendered in the accordion. This provides flexibility in deciding where to lay out the input components. In the example below, we position the `gr.Textbox()` on top of the Chatbot UI, while keeping the slider underneath. + +```python +import gradio as gr +import time + +def echo(message, history, system_prompt, tokens): + response = f"System prompt: {system_prompt}\n Message: {message}." + for i in range(min(len(response), int(tokens))): + time.sleep(0.05) + yield response[: i+1] + +with gr.Blocks() as demo: + system_prompt = gr.Textbox("You are helpful AI.", label="System Prompt") + slider = gr.Slider(10, 100, render=False) + + gr.ChatInterface( + echo, additional_inputs=[system_prompt, slider] + ) + +demo.queue().launch() +``` + If you need to create something even more custom, then it's best to construct the chatbot UI using the low-level `gr.Blocks()` API. We have [a dedicated guide for that here](/guides/creating-a-custom-chatbot-with-blocks). ## Using your chatbot via an API -Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message, and will return the response, internally keeping track of the messages sent so far. +Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message (and potentially additional inputs if you have set any using the `additional_inputs` parameter), and will return the response, internally keeping track of the messages sent so far. [](https://github.com/gradio-app/gradio/assets/1778297/7b10d6db-6476-4e2e-bebd-ecda802c3b8f) @@ -251,4 +284,4 @@ def predict(message, history): gr.ChatInterface(predict).queue().launch() ``` -With those examples, you should be all set to create your own Gradio Chatbot demos soon!
For building more custom Chabot UI, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API. \ No newline at end of file +With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building even more custom Chatbot applications, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API. \ No newline at end of file diff --git a/test/test_blocks.py b/test/test_blocks.py index a61423137e..4126a1fb67 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -1221,6 +1221,37 @@ class TestRender: io3 = io2.render() assert io2 == io3 + def test_is_rendered(self): + t = gr.Textbox() + with gr.Blocks(): + pass + assert not t.is_rendered + + t = gr.Textbox() + with gr.Blocks(): + t.render() + assert t.is_rendered + + t = gr.Textbox() + with gr.Blocks(): + t.render() + t.unrender() + assert not t.is_rendered + + with gr.Blocks(): + t = gr.Textbox() + assert t.is_rendered + + with gr.Blocks(): + t = gr.Textbox() + with gr.Blocks(): + pass + assert t.is_rendered + + t = gr.Textbox() + gr.Interface(lambda x: x, "textbox", t) + assert t.is_rendered + def test_no_error_if_state_rendered_multiple_times(self): state = gr.State("") gr.TabbedInterface( diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py index b398325da0..c15ddc3585 100644 --- a/test/test_chat_interface.py +++ b/test/test_chat_interface.py @@ -1,8 +1,10 @@ +import tempfile from concurrent.futures import wait import pytest import gradio as gr +from gradio import helpers def invalid_fn(message): @@ -22,15 +24,17 @@ def count(message, history): return str(len(history)) +def echo_system_prompt_plus_message(message, history, system_prompt, tokens): + response = f"{system_prompt} {message}" + for i in range(min(len(response), int(tokens))): + yield response[: i + 1] + + class TestInit: def test_no_fn(self): with pytest.raises(TypeError): gr.ChatInterface() - def test_invalid_fn_inputs(self): - with pytest.warns(UserWarning): - gr.ChatInterface(invalid_fn) - def test_configuring_buttons(self): chatbot = gr.ChatInterface(double, submit_btn=None, retry_btn=None) assert chatbot.submit_btn is None @@ -74,7 +78,8 @@ class TestInit: assert prediction_hi[0][0] == ["hi", "hi hi"] @pytest.mark.asyncio - async def test_example_caching_with_streaming(self): + async def test_example_caching_with_streaming(self, monkeypatch): + monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp()) chatbot = gr.ChatInterface( stream, examples=["hello", "hi"], cache_examples=True ) @@ -83,6 +88,40 @@ class TestInit: assert prediction_hello[0][0] == ["hello", "hello"] assert prediction_hi[0][0] == ["hi", "hi"] + @pytest.mark.asyncio + async def test_example_caching_with_additional_inputs(self, monkeypatch): + monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp()) + chatbot = gr.ChatInterface( + echo_system_prompt_plus_message, + additional_inputs=["textbox", "slider"], + examples=[["hello", "robot", 100], ["hi", "robot", 2]], + cache_examples=True, + ) + prediction_hello = await chatbot.examples_handler.load_from_cache(0) + prediction_hi = await chatbot.examples_handler.load_from_cache(1) + assert prediction_hello[0][0] == ["hello", "robot hello"] + assert prediction_hi[0][0] == ["hi", "ro"] + + @pytest.mark.asyncio + async def test_example_caching_with_additional_inputs_already_rendered( + self, monkeypatch + ): + monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp()) + with 
gr.Blocks(): + with gr.Accordion("Inputs"): + text = gr.Textbox() + slider = gr.Slider() + chatbot = gr.ChatInterface( + echo_system_prompt_plus_message, + additional_inputs=[text, slider], + examples=[["hello", "robot", 100], ["hi", "robot", 2]], + cache_examples=True, + ) + prediction_hello = await chatbot.examples_handler.load_from_cache(0) + prediction_hi = await chatbot.examples_handler.load_from_cache(1) + assert prediction_hello[0][0] == ["hello", "robot hello"] + assert prediction_hi[0][0] == ["hi", "ro"] + class TestAPI: def test_get_api_info(self): @@ -104,3 +143,21 @@ class TestAPI: with connect(chatbot) as client: result = client.predict("hello") assert result == "hello hello" + + def test_streaming_api_with_additional_inputs(self, connect): + chatbot = gr.ChatInterface( + echo_system_prompt_plus_message, + additional_inputs=["textbox", "slider"], + ).queue() + with connect(chatbot) as client: + job = client.submit("hello", "robot", 7) + wait([job]) + assert job.outputs() == [ + "r", + "ro", + "rob", + "robo", + "robot", + "robot ", + "robot h", + ]
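As a usage illustration of the API surface exercised by `test_streaming_api_with_additional_inputs` above, here is a minimal `gradio_client` sketch of querying a `ChatInterface` that sets `additional_inputs` via its `/chat` endpoint. The URL and argument values are placeholders; the positional arguments follow the message in the same order as the `additional_inputs` list.

```python
from gradio_client import Client

# Placeholder URL: point this at your running demo or a Spaces repo id.
client = Client("http://127.0.0.1:7860/")

# The /chat endpoint expects the message first, then the additional
# inputs in the order they were passed to `additional_inputs`
# (here: a system-prompt textbox, then a token-count slider).
job = client.submit("hello", "You are helpful AI.", 20, api_name="/chat")

print(job.result())   # final response once the generator finishes
print(job.outputs())  # intermediate streamed responses, as in the test
```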