mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Adds additional_inputs
to gr.ChatInterface
(#4985)
* adding additional inputs * add param * guide * add is_rendered * add demo * fixing examples * add test * guide * add changeset * Fix typos * Remove label * Revert "Remove label" This reverts commit 10042856151b0ff79412613e6cb176dfc8642117. * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
This commit is contained in:
parent
4b0e98e40a
commit
b74f845303
5
.changeset/witty-pets-rhyme.md
Normal file
5
.changeset/witty-pets-rhyme.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Adds `additional_inputs` to `gr.ChatInterface`
|
1
demo/chatinterface_system_prompt/run.ipynb
Normal file
1
demo/chatinterface_system_prompt/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
18
demo/chatinterface_system_prompt/run.py
Normal file
18
demo/chatinterface_system_prompt/run.py
Normal file
@ -0,0 +1,18 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
def echo(message, history, system_prompt, tokens):
|
||||
response = f"System prompt: {system_prompt}\n Message: {message}."
|
||||
for i in range(min(len(response), int(tokens))):
|
||||
time.sleep(0.05)
|
||||
yield response[: i+1]
|
||||
|
||||
demo = gr.ChatInterface(echo,
|
||||
additional_inputs=[
|
||||
gr.Textbox("You are helpful AI.", label="System Prompt"),
|
||||
gr.Slider(10, 100)
|
||||
]
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue().launch()
|
@ -110,6 +110,7 @@ class Block:
|
||||
self.share_token = secrets.token_urlsafe(32)
|
||||
self._skip_init_processing = _skip_init_processing
|
||||
self.parent: BlockContext | None = None
|
||||
self.is_rendered: bool = False
|
||||
|
||||
if render:
|
||||
self.render()
|
||||
@ -127,6 +128,7 @@ class Block:
|
||||
Context.block.add(self)
|
||||
if Context.root_block is not None:
|
||||
Context.root_block.blocks[self._id] = self
|
||||
self.is_rendered = True
|
||||
if isinstance(self, components.IOComponent):
|
||||
Context.root_block.temp_file_sets.append(self.temp_files)
|
||||
return self
|
||||
@ -144,6 +146,7 @@ class Block:
|
||||
if Context.root_block is not None:
|
||||
try:
|
||||
del Context.root_block.blocks[self._id]
|
||||
self.is_rendered = False
|
||||
except KeyError:
|
||||
pass
|
||||
return self
|
||||
|
@ -6,22 +6,24 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, Generator
|
||||
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.components import (
|
||||
Button,
|
||||
Chatbot,
|
||||
IOComponent,
|
||||
Markdown,
|
||||
State,
|
||||
Textbox,
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.events import Dependency, EventListenerMethod
|
||||
from gradio.helpers import create_examples as Examples # noqa: N812
|
||||
from gradio.layouts import Column, Group, Row
|
||||
from gradio.layouts import Accordion, Column, Group, Row
|
||||
from gradio.themes import ThemeClass as Theme
|
||||
|
||||
set_documentation_group("chatinterface")
|
||||
@ -53,6 +55,8 @@ class ChatInterface(Blocks):
|
||||
*,
|
||||
chatbot: Chatbot | None = None,
|
||||
textbox: Textbox | None = None,
|
||||
additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
|
||||
additional_inputs_accordion_name: str = "Additional Inputs",
|
||||
examples: list[str] | None = None,
|
||||
cache_examples: bool | None = None,
|
||||
title: str | None = None,
|
||||
@ -65,12 +69,15 @@ class ChatInterface(Blocks):
|
||||
retry_btn: str | None | Button = "🔄 Retry",
|
||||
undo_btn: str | None | Button = "↩️ Undo",
|
||||
clear_btn: str | None | Button = "🗑️ Clear",
|
||||
autofocus: bool = True,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
|
||||
chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
|
||||
textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
|
||||
additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
|
||||
additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
|
||||
examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
|
||||
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
||||
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
|
||||
@ -83,6 +90,7 @@ class ChatInterface(Blocks):
|
||||
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
|
||||
autofocus: If True, autofocuses to the textbox when the page loads.
|
||||
"""
|
||||
super().__init__(
|
||||
analytics_enabled=analytics_enabled,
|
||||
@ -91,12 +99,6 @@ class ChatInterface(Blocks):
|
||||
title=title or "Gradio",
|
||||
theme=theme,
|
||||
)
|
||||
if len(inspect.signature(fn).parameters) != 2:
|
||||
warnings.warn(
|
||||
"The function to ChatInterface should take two inputs (message, history) and return a single string response.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.fn = fn
|
||||
self.is_generator = inspect.isgeneratorfunction(self.fn)
|
||||
self.examples = examples
|
||||
@ -106,6 +108,16 @@ class ChatInterface(Blocks):
|
||||
self.cache_examples = cache_examples or False
|
||||
self.buttons: list[Button] = []
|
||||
|
||||
if additional_inputs:
|
||||
if not isinstance(additional_inputs, list):
|
||||
additional_inputs = [additional_inputs]
|
||||
self.additional_inputs = [
|
||||
get_component_instance(i, render=False) for i in additional_inputs # type: ignore
|
||||
]
|
||||
else:
|
||||
self.additional_inputs = []
|
||||
self.additional_inputs_accordion_name = additional_inputs_accordion_name
|
||||
|
||||
with self:
|
||||
if title:
|
||||
Markdown(
|
||||
@ -130,9 +142,10 @@ class ChatInterface(Blocks):
|
||||
self.textbox = Textbox(
|
||||
container=False,
|
||||
show_label=False,
|
||||
label="Message",
|
||||
placeholder="Type a message...",
|
||||
scale=7,
|
||||
autofocus=True,
|
||||
autofocus=autofocus,
|
||||
)
|
||||
if submit_btn:
|
||||
if isinstance(submit_btn, Button):
|
||||
@ -199,12 +212,24 @@ class ChatInterface(Blocks):
|
||||
|
||||
self.examples_handler = Examples(
|
||||
examples=examples,
|
||||
inputs=self.textbox,
|
||||
inputs=[self.textbox] + self.additional_inputs,
|
||||
outputs=self.chatbot,
|
||||
fn=examples_fn,
|
||||
cache_examples=self.cache_examples,
|
||||
)
|
||||
|
||||
any_unrendered_inputs = any(
|
||||
not inp.is_rendered for inp in self.additional_inputs
|
||||
)
|
||||
if self.additional_inputs and any_unrendered_inputs:
|
||||
with Accordion(self.additional_inputs_accordion_name, open=False):
|
||||
for input_component in self.additional_inputs:
|
||||
if not input_component.is_rendered:
|
||||
input_component.render()
|
||||
|
||||
# The example caching must happen after the input components have rendered
|
||||
if cache_examples:
|
||||
client_utils.synchronize_async(self.examples_handler.cache)
|
||||
|
||||
self.saved_input = State()
|
||||
self.chatbot_state = State([])
|
||||
|
||||
@ -230,7 +255,7 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
@ -255,7 +280,7 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
@ -280,7 +305,7 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
@ -358,7 +383,7 @@ class ChatInterface(Blocks):
|
||||
|
||||
self.fake_api_btn.click(
|
||||
api_fn,
|
||||
[self.textbox, self.chatbot_state],
|
||||
[self.textbox, self.chatbot_state] + self.additional_inputs,
|
||||
[self.textbox, self.chatbot_state],
|
||||
api_name="chat",
|
||||
)
|
||||
@ -373,18 +398,26 @@ class ChatInterface(Blocks):
|
||||
return history, history
|
||||
|
||||
def _submit_fn(
|
||||
self, message: str, history_with_input: list[list[str | None]]
|
||||
self,
|
||||
message: str,
|
||||
history_with_input: list[list[str | None]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> tuple[list[list[str | None]], list[list[str | None]]]:
|
||||
history = history_with_input[:-1]
|
||||
response = self.fn(message, history)
|
||||
response = self.fn(message, history, *args, **kwargs)
|
||||
history.append([message, response])
|
||||
return history, history
|
||||
|
||||
def _stream_fn(
|
||||
self, message: str, history_with_input: list[list[str | None]]
|
||||
self,
|
||||
message: str,
|
||||
history_with_input: list[list[str | None]],
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
|
||||
history = history_with_input[:-1]
|
||||
generator = self.fn(message, history)
|
||||
generator = self.fn(message, history, *args, **kwargs)
|
||||
try:
|
||||
first_response = next(generator)
|
||||
update = history + [[message, first_response]]
|
||||
@ -397,16 +430,16 @@ class ChatInterface(Blocks):
|
||||
yield update, update
|
||||
|
||||
def _api_submit_fn(
|
||||
self, message: str, history: list[list[str | None]]
|
||||
self, message: str, history: list[list[str | None]], *args, **kwargs
|
||||
) -> tuple[str, list[list[str | None]]]:
|
||||
response = self.fn(message, history)
|
||||
history.append([message, response])
|
||||
return response, history
|
||||
|
||||
def _api_stream_fn(
|
||||
self, message: str, history: list[list[str | None]]
|
||||
self, message: str, history: list[list[str | None]], *args, **kwargs
|
||||
) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
|
||||
generator = self.fn(message, history)
|
||||
generator = self.fn(message, history, *args, **kwargs)
|
||||
try:
|
||||
first_response = next(generator)
|
||||
yield first_response, history + [[message, first_response]]
|
||||
@ -415,13 +448,16 @@ class ChatInterface(Blocks):
|
||||
for response in generator:
|
||||
yield response, history + [[message, response]]
|
||||
|
||||
def _examples_fn(self, message: str) -> list[list[str | None]]:
|
||||
return [[message, self.fn(message, [])]]
|
||||
def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
|
||||
return [[message, self.fn(message, [], *args, **kwargs)]]
|
||||
|
||||
def _examples_stream_fn(
|
||||
self, message: str
|
||||
self,
|
||||
message: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Generator[list[list[str | None]], None, None]:
|
||||
for response in self.fn(message, []):
|
||||
for response in self.fn(message, [], *args, **kwargs):
|
||||
yield [[message, response]]
|
||||
|
||||
def _delete_prev_fn(
|
||||
|
@ -195,7 +195,7 @@ class Examples:
|
||||
self.non_none_examples = non_none_examples
|
||||
self.inputs = inputs
|
||||
self.inputs_with_examples = inputs_with_examples
|
||||
self.outputs = outputs
|
||||
self.outputs = outputs or []
|
||||
self.fn = fn
|
||||
self.cache_examples = cache_examples
|
||||
self._api_mode = _api_mode
|
||||
@ -250,23 +250,14 @@ class Examples:
|
||||
component to hold the examples"""
|
||||
|
||||
async def load_example(example_id):
|
||||
if self.cache_examples:
|
||||
processed_example = self.non_none_processed_examples[
|
||||
example_id
|
||||
] + await self.load_from_cache(example_id)
|
||||
else:
|
||||
processed_example = self.non_none_processed_examples[example_id]
|
||||
processed_example = self.non_none_processed_examples[example_id]
|
||||
return utils.resolve_singleton(processed_example)
|
||||
|
||||
if Context.root_block:
|
||||
if self.cache_examples and self.outputs:
|
||||
targets = self.inputs_with_examples + self.outputs
|
||||
else:
|
||||
targets = self.inputs_with_examples
|
||||
load_input_event = self.dataset.click(
|
||||
self.load_input_event = self.dataset.click(
|
||||
load_example,
|
||||
inputs=[self.dataset],
|
||||
outputs=targets, # type: ignore
|
||||
outputs=self.inputs_with_examples, # type: ignore
|
||||
show_progress="hidden",
|
||||
postprocess=False,
|
||||
queue=False,
|
||||
@ -275,7 +266,7 @@ class Examples:
|
||||
if self.run_on_click and not self.cache_examples:
|
||||
if self.fn is None:
|
||||
raise ValueError("Cannot run_on_click if no function is provided")
|
||||
load_input_event.then(
|
||||
self.load_input_event.then(
|
||||
self.fn,
|
||||
inputs=self.inputs, # type: ignore
|
||||
outputs=self.outputs, # type: ignore
|
||||
@ -301,25 +292,24 @@ class Examples:
|
||||
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
|
||||
def get_final_item(args): # type: ignore
|
||||
def get_final_item(*args): # type: ignore
|
||||
x = None
|
||||
for x in self.fn(args): # noqa: B007 # type: ignore
|
||||
for x in self.fn(*args): # noqa: B007 # type: ignore
|
||||
pass
|
||||
return x
|
||||
|
||||
fn = get_final_item
|
||||
elif inspect.isasyncgenfunction(self.fn):
|
||||
|
||||
async def get_final_item(args):
|
||||
async def get_final_item(*args):
|
||||
x = None
|
||||
async for x in self.fn(args): # noqa: B007 # type: ignore
|
||||
async for x in self.fn(*args): # noqa: B007 # type: ignore
|
||||
pass
|
||||
return x
|
||||
|
||||
fn = get_final_item
|
||||
else:
|
||||
fn = self.fn
|
||||
|
||||
# create a fake dependency to process the examples and get the predictions
|
||||
dependency, fn_index = Context.root_block.set_event_trigger(
|
||||
event_name="fake_event",
|
||||
@ -352,6 +342,30 @@ class Examples:
|
||||
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
|
||||
Context.root_block.dependencies.remove(dependency)
|
||||
Context.root_block.fns.pop(fn_index)
|
||||
|
||||
# Remove the original load_input_event and replace it with one that
|
||||
# also populates the input. We do it this way to to allow the cache()
|
||||
# method to be called independently of the create() method
|
||||
index = Context.root_block.dependencies.index(self.load_input_event)
|
||||
Context.root_block.dependencies.pop(index)
|
||||
Context.root_block.fns.pop(index)
|
||||
|
||||
async def load_example(example_id):
|
||||
processed_example = self.non_none_processed_examples[
|
||||
example_id
|
||||
] + await self.load_from_cache(example_id)
|
||||
return utils.resolve_singleton(processed_example)
|
||||
|
||||
self.load_input_event = self.dataset.click(
|
||||
load_example,
|
||||
inputs=[self.dataset],
|
||||
outputs=self.inputs_with_examples + self.outputs, # type: ignore
|
||||
show_progress="hidden",
|
||||
postprocess=False,
|
||||
queue=False,
|
||||
api_name=self.api_name, # type: ignore
|
||||
)
|
||||
|
||||
print("Caching complete\n")
|
||||
|
||||
async def load_from_cache(self, example_id: int) -> list[Any]:
|
||||
|
@ -87,7 +87,7 @@ def slow_echo(message, history):
|
||||
gr.ChatInterface(slow_echo).queue().launch()
|
||||
```
|
||||
|
||||
Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions.
|
||||
Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions. While the response is streaming, the "Submit" button turns into a "Stop" button that can be used to stop the generator function. You can customize the appearance and behavior of the "Stop" button using the `stop_btn` parameter.
|
||||
|
||||
## Customizing your chatbot
|
||||
|
||||
@ -125,11 +125,44 @@ gr.ChatInterface(
|
||||
).launch()
|
||||
```
|
||||
|
||||
## Additional Inputs
|
||||
|
||||
You may want to add additional parameters to your chatbot and expose them to your users through the Chatbot UI. For example, suppose you want to add a textbox for a system prompt, or a slider that sets the number of tokens in the chatbot's response. The `ChatInterface` class supports an `additional_inputs` parameter which can be used to add additional input components.
|
||||
|
||||
The `additional_inputs` parameters accepts a component or a list of components. You can pass the component instances directly, or use their string shortcuts (e.g. `"textbox"` instead of `gr.Textbox()`). If you pass in component instances, and they have *not* already been rendered, then the components will appear underneath the chatbot (and any examples) within a `gr.Accordion()`. You can set the label of this accordion using the `additional_inputs_accordion_name` parameter.
|
||||
|
||||
Here's a complete example:
|
||||
|
||||
$code_chatinterface_system_prompt
|
||||
|
||||
If the components you pass into the `additional_inputs` have already been rendered in a parent `gr.Blocks()`, then they will *not* be re-rendered in the accordion. This provides flexibility in deciding where to lay out the input components. In the example below, we position the `gr.Textbox()` on top of the Chatbot UI, while keeping the slider underneath.
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
def echo(message, history, system_prompt, tokens):
|
||||
response = f"System prompt: {system_prompt}\n Message: {message}."
|
||||
for i in range(min(len(response), int(tokens))):
|
||||
time.sleep(0.05)
|
||||
yield response[: i+1]
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
system_prompt = gr.Textbox("You are helpful AI.", label="System Prompt")
|
||||
slider = gr.Slider(10, 100, render=False)
|
||||
|
||||
gr.ChatInterface(
|
||||
echo, additional_inputs=[system_prompt, slider]
|
||||
)
|
||||
|
||||
demo.queue().launch()
|
||||
```
|
||||
|
||||
If you need to create something even more custom, then its best to construct the chatbot UI using the low-level `gr.Blocks()` API. We have [a dedicated guide for that here](/guides/creating-a-custom-chatbot-with-blocks).
|
||||
|
||||
## Using your chatbot via an API
|
||||
|
||||
Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message, and will return the response, internally keeping track of the messages sent so far.
|
||||
Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message (and potentially additional inputs if you have set any using the `additional_inputs` parameter), and will return the response, internally keeping track of the messages sent so far.
|
||||
|
||||
[](https://github.com/gradio-app/gradio/assets/1778297/7b10d6db-6476-4e2e-bebd-ecda802c3b8f)
|
||||
|
||||
@ -251,4 +284,4 @@ def predict(message, history):
|
||||
gr.ChatInterface(predict).queue().launch()
|
||||
```
|
||||
|
||||
With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building more custom Chabot UI, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API.
|
||||
With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building even more custom Chatbot applications, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API.
|
@ -1221,6 +1221,37 @@ class TestRender:
|
||||
io3 = io2.render()
|
||||
assert io2 == io3
|
||||
|
||||
def test_is_rendered(self):
|
||||
t = gr.Textbox()
|
||||
with gr.Blocks():
|
||||
pass
|
||||
assert not t.is_rendered
|
||||
|
||||
t = gr.Textbox()
|
||||
with gr.Blocks():
|
||||
t.render()
|
||||
assert t.is_rendered
|
||||
|
||||
t = gr.Textbox()
|
||||
with gr.Blocks():
|
||||
t.render()
|
||||
t.unrender()
|
||||
assert not t.is_rendered
|
||||
|
||||
with gr.Blocks():
|
||||
t = gr.Textbox()
|
||||
assert t.is_rendered
|
||||
|
||||
with gr.Blocks():
|
||||
t = gr.Textbox()
|
||||
with gr.Blocks():
|
||||
pass
|
||||
assert t.is_rendered
|
||||
|
||||
t = gr.Textbox()
|
||||
gr.Interface(lambda x: x, "textbox", t)
|
||||
assert t.is_rendered
|
||||
|
||||
def test_no_error_if_state_rendered_multiple_times(self):
|
||||
state = gr.State("")
|
||||
gr.TabbedInterface(
|
||||
|
@ -1,8 +1,10 @@
|
||||
import tempfile
|
||||
from concurrent.futures import wait
|
||||
|
||||
import pytest
|
||||
|
||||
import gradio as gr
|
||||
from gradio import helpers
|
||||
|
||||
|
||||
def invalid_fn(message):
|
||||
@ -22,15 +24,17 @@ def count(message, history):
|
||||
return str(len(history))
|
||||
|
||||
|
||||
def echo_system_prompt_plus_message(message, history, system_prompt, tokens):
|
||||
response = f"{system_prompt} {message}"
|
||||
for i in range(min(len(response), int(tokens))):
|
||||
yield response[: i + 1]
|
||||
|
||||
|
||||
class TestInit:
|
||||
def test_no_fn(self):
|
||||
with pytest.raises(TypeError):
|
||||
gr.ChatInterface()
|
||||
|
||||
def test_invalid_fn_inputs(self):
|
||||
with pytest.warns(UserWarning):
|
||||
gr.ChatInterface(invalid_fn)
|
||||
|
||||
def test_configuring_buttons(self):
|
||||
chatbot = gr.ChatInterface(double, submit_btn=None, retry_btn=None)
|
||||
assert chatbot.submit_btn is None
|
||||
@ -74,7 +78,8 @@ class TestInit:
|
||||
assert prediction_hi[0][0] == ["hi", "hi hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_streaming(self):
|
||||
async def test_example_caching_with_streaming(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
stream, examples=["hello", "hi"], cache_examples=True
|
||||
)
|
||||
@ -83,6 +88,40 @@ class TestInit:
|
||||
assert prediction_hello[0][0] == ["hello", "hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_additional_inputs(self, monkeypatch):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
chatbot = gr.ChatInterface(
|
||||
echo_system_prompt_plus_message,
|
||||
additional_inputs=["textbox", "slider"],
|
||||
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "robot hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "ro"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_example_caching_with_additional_inputs_already_rendered(
|
||||
self, monkeypatch
|
||||
):
|
||||
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
|
||||
with gr.Blocks():
|
||||
with gr.Accordion("Inputs"):
|
||||
text = gr.Textbox()
|
||||
slider = gr.Slider()
|
||||
chatbot = gr.ChatInterface(
|
||||
echo_system_prompt_plus_message,
|
||||
additional_inputs=[text, slider],
|
||||
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
|
||||
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
|
||||
assert prediction_hello[0][0] == ["hello", "robot hello"]
|
||||
assert prediction_hi[0][0] == ["hi", "ro"]
|
||||
|
||||
|
||||
class TestAPI:
|
||||
def test_get_api_info(self):
|
||||
@ -104,3 +143,21 @@ class TestAPI:
|
||||
with connect(chatbot) as client:
|
||||
result = client.predict("hello")
|
||||
assert result == "hello hello"
|
||||
|
||||
def test_streaming_api_with_additional_inputs(self, connect):
|
||||
chatbot = gr.ChatInterface(
|
||||
echo_system_prompt_plus_message,
|
||||
additional_inputs=["textbox", "slider"],
|
||||
).queue()
|
||||
with connect(chatbot) as client:
|
||||
job = client.submit("hello", "robot", 7)
|
||||
wait([job])
|
||||
assert job.outputs() == [
|
||||
"r",
|
||||
"ro",
|
||||
"rob",
|
||||
"robo",
|
||||
"robot",
|
||||
"robot ",
|
||||
"robot h",
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user