Add additional_inputs, additional_inputs_accordion parameters to gr.Interface (#6945)

* guide

* new docs

* changes

* interface

* add changeset

* add changeset

* added basic test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-01-10 11:44:00 -08:00 committed by GitHub
parent 71aab1c617
commit ccf317fc97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 9 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Add `additional_inputs`, `additional_inputs_accordion` parameters to `gr.Interface`

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: interface_with_additional_inputs"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def generate_fake_image(prompt, seed, initial_image=None):\n", " return f\"Used seed: {seed}\", \"https://dummyimage.com/300/09f.png\"\n", "\n", "\n", "demo = gr.Interface(\n", " generate_fake_image,\n", " inputs=[\"textbox\"],\n", " outputs=[\"textbox\", \"image\"],\n", " additional_inputs=[\n", " gr.Slider(0, 1000),\n", " \"image\"\n", " ]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n", "\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -0,0 +1,19 @@
import gradio as gr
def generate_fake_image(prompt, seed, initial_image=None):
return f"Used seed: {seed}", "https://dummyimage.com/300/09f.png"
demo = gr.Interface(
generate_fake_image,
inputs=["textbox"],
outputs=["textbox", "image"],
additional_inputs=[
gr.Slider(0, 1000),
"image"
]
)
if __name__ == "__main__":
demo.launch()

View File

@ -508,13 +508,13 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
):
"""
Parameters:
theme: a Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the HF Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
mode: a human-friendly name for the kind of Blocks or Interface being created.
theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the HF Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
mode: A human-friendly name for the kind of Blocks or Interface being created. Used internally for analytics.
title: The tab title to display when this is opened in a browser window.
css: custom css or path to custom css file to apply to entire Blocks.
js: custom js or path to custom js file to run when demo is first loaded.
head: custom html to insert into the head of the page.
css: Custom css or path to custom css file to apply to entire Blocks.
js: Custom js or path to custom js file to run when demo is first loaded.
head: Custom html to insert into the head of the page. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
"""
self.limiter = None
if theme is None:

View File

@ -28,7 +28,7 @@ from gradio.data_classes import InterfaceTypes
from gradio.events import Events, on
from gradio.exceptions import RenderError
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Column, Row, Tab, Tabs
from gradio.layouts import Accordion, Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline
from gradio.themes import ThemeClass as Theme
@ -115,6 +115,10 @@ class Interface(Blocks):
_api_mode: bool = False,
allow_duplication: bool = False,
concurrency_limit: int | None | Literal["default"] = "default",
js: str | None = None,
head: str | None = None,
additional_inputs: str | Component | list[str | Component] | None = None,
additional_inputs_accordion: str | Accordion | None = None,
**kwargs,
):
"""
@ -142,6 +146,10 @@ class Interface(Blocks):
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None, the name of the prediction function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event.
allow_duplication: If True, then will show a 'Duplicate Spaces' button on Hugging Face Spaces.
concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which itself is 1 by default).
js: Custom js or path to custom js file to run when demo is first loaded.
head: Custom html to insert into the head of the page. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
additional_inputs: A single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. These components will be rendered in an accordion below the main input components. By default, no additional input components will be displayed.
additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -149,6 +157,8 @@ class Interface(Blocks):
css=css,
title=title or "Gradio",
theme=theme,
js=js,
head=head,
**kwargs,
)
self.api_name: str | Literal[False] | None = api_name
@ -161,6 +171,8 @@ class Interface(Blocks):
elif inputs is None or inputs == []:
inputs = []
self.interface_type = InterfaceTypes.OUTPUT_ONLY
if additional_inputs is None:
additional_inputs = []
assert isinstance(inputs, (str, list, Component))
assert isinstance(outputs, (str, list, Component))
@ -169,6 +181,8 @@ class Interface(Blocks):
inputs = [inputs]
if not isinstance(outputs, list):
outputs = [outputs]
if not isinstance(additional_inputs, list):
additional_inputs = [additional_inputs]
if self.space_id and cache_examples is None:
self.cache_examples = True
@ -207,10 +221,36 @@ class Interface(Blocks):
)
self.cache_examples = False
self.input_components = [
self.main_input_components = [
get_component_instance(i, unrender=True)
for i in inputs # type: ignore
]
self.additional_input_components = [
get_component_instance(i, unrender=True)
for i in additional_inputs # type: ignore
]
if additional_inputs_accordion is None:
self.additional_inputs_accordion_params = {
"label": "Additional Inputs",
"open": False,
}
elif isinstance(additional_inputs_accordion, str):
self.additional_inputs_accordion_params = {
"label": additional_inputs_accordion
}
elif isinstance(additional_inputs_accordion, Accordion):
self.additional_inputs_accordion_params = (
additional_inputs_accordion.recover_kwargs(
additional_inputs_accordion.get_config()
)
)
else:
raise ValueError(
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
)
self.input_components = (
self.main_input_components + self.additional_input_components
)
self.output_components = [
get_component_instance(o, unrender=True)
for o in outputs # type: ignore
@ -442,8 +482,12 @@ class Interface(Blocks):
with Column(variant="panel"):
input_component_column = Column()
with input_component_column:
for component in self.input_components:
for component in self.main_input_components:
component.render()
if self.additional_input_components:
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
for component in self.additional_input_components:
component.render()
with Row():
if self.interface_type in [
InterfaceTypes.STANDARD,

View File

@ -83,6 +83,18 @@ Another useful keyword argument is `label=`, which is present in every `Componen
gr.Number(label='Age', info='In years, must be greater than 0')
```
## Additional Inputs within an Accordion
If your prediction function takes many inputs, you may want to hide some of them within a collapsed accordion to avoid cluttering the UI. The `Interface` class takes an `additional_inputs` argument which is similar to `inputs` but any input components included here are not visible by default. The user must click on the accordion to show these components. The additional inputs are passed into the prediction function, in order, after the standard inputs.
You can customize the appearance of the accordion by using the optional `additional_inputs_accordion` argument, which accepts a string (in which case, it becomes the label of the accordion), or an instance of the `gr.Accordion()` class (e.g. this lets you control whether the accordion is open or closed by default).
Here's an example:
$code_interface_with_additional_inputs
$demo_interface_with_additional_inputs
## Flagging
By default, an `Interface` will have "Flag" button. When a user testing your `Interface` sees input with interesting output, such as erroneous or unexpected model behaviour, they can flag the input for you to review. Within the directory provided by the `flagging_dir=` argument to the `Interface` constructor, a CSV file will log the flagged inputs. If the interface involves file data, such as for Image and Audio components, folders will be created to store those flagged data as well.

View File

@ -177,6 +177,12 @@ class TestInterface:
Interface(fn=str, inputs=t, outputs=Textbox())
assert t.label == "input 0"
def test_interface_additional_components_are_included_as_inputs(self):
t = Textbox()
s = gradio.Slider(0, 100)
io = Interface(fn=str, inputs=t, outputs=Textbox(), additional_inputs=s)
assert io.input_components == [t, s]
class TestTabbedInterface:
def test_tabbed_interface_config_matches_manual_tab(self):