Add additional_inputs, additional_inputs_accordion parameters to gr.Interface (#6945)

* guide

* new docs

* changes

* interface

* add changeset

* add changeset

* added basic test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-01-10 11:44:00 -08:00 committed by GitHub
parent 71aab1c617
commit ccf317fc97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 96 additions and 9 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Add `additional_inputs`, `additional_inputs_accordion` parameters to `gr.Interface`

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: interface_with_additional_inputs"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def generate_fake_image(prompt, seed, initial_image=None):\n", " return f\"Used seed: {seed}\", \"https://dummyimage.com/300/09f.png\"\n", "\n", "\n", "demo = gr.Interface(\n", " generate_fake_image,\n", " inputs=[\"textbox\"],\n", " outputs=[\"textbox\", \"image\"],\n", " additional_inputs=[\n", " gr.Slider(0, 1000),\n", " \"image\"\n", " ]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n", "\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -0,0 +1,19 @@
import gradio as gr
def generate_fake_image(prompt, seed, initial_image=None):
return f"Used seed: {seed}", "https://dummyimage.com/300/09f.png"
demo = gr.Interface(
generate_fake_image,
inputs=["textbox"],
outputs=["textbox", "image"],
additional_inputs=[
gr.Slider(0, 1000),
"image"
]
)
if __name__ == "__main__":
demo.launch()

View File

@ -508,13 +508,13 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
): ):
""" """
Parameters: Parameters:
theme: a Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the HF Hub (e.g. "gradio/monochrome"). If None, will use the Default theme. theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the HF Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True. analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
mode: a human-friendly name for the kind of Blocks or Interface being created. mode: A human-friendly name for the kind of Blocks or Interface being created. Used internally for analytics.
title: The tab title to display when this is opened in a browser window. title: The tab title to display when this is opened in a browser window.
css: custom css or path to custom css file to apply to entire Blocks. css: Custom css or path to custom css file to apply to entire Blocks.
js: custom js or path to custom js file to run when demo is first loaded. js: Custom js or path to custom js file to run when demo is first loaded.
head: custom html to insert into the head of the page. head: Custom html to insert into the head of the page. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
""" """
self.limiter = None self.limiter = None
if theme is None: if theme is None:

View File

@ -28,7 +28,7 @@ from gradio.data_classes import InterfaceTypes
from gradio.events import Events, on from gradio.events import Events, on
from gradio.exceptions import RenderError from gradio.exceptions import RenderError
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Column, Row, Tab, Tabs from gradio.layouts import Accordion, Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline from gradio.pipelines import load_from_pipeline
from gradio.themes import ThemeClass as Theme from gradio.themes import ThemeClass as Theme
@ -115,6 +115,10 @@ class Interface(Blocks):
_api_mode: bool = False, _api_mode: bool = False,
allow_duplication: bool = False, allow_duplication: bool = False,
concurrency_limit: int | None | Literal["default"] = "default", concurrency_limit: int | None | Literal["default"] = "default",
js: str | None = None,
head: str | None = None,
additional_inputs: str | Component | list[str | Component] | None = None,
additional_inputs_accordion: str | Accordion | None = None,
**kwargs, **kwargs,
): ):
""" """
@ -142,6 +146,10 @@ class Interface(Blocks):
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None, the name of the prediction function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event. api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None, the name of the prediction function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event.
allow_duplication: If True, then will show a 'Duplicate Spaces' button on Hugging Face Spaces. allow_duplication: If True, then will show a 'Duplicate Spaces' button on Hugging Face Spaces.
concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which itself is 1 by default). concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which itself is 1 by default).
js: Custom js or path to custom js file to run when demo is first loaded.
head: Custom html to insert into the head of the page. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
additional_inputs: A single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. These components will be rendered in an accordion below the main input components. By default, no additional input components will be displayed.
additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
""" """
super().__init__( super().__init__(
analytics_enabled=analytics_enabled, analytics_enabled=analytics_enabled,
@ -149,6 +157,8 @@ class Interface(Blocks):
css=css, css=css,
title=title or "Gradio", title=title or "Gradio",
theme=theme, theme=theme,
js=js,
head=head,
**kwargs, **kwargs,
) )
self.api_name: str | Literal[False] | None = api_name self.api_name: str | Literal[False] | None = api_name
@ -161,6 +171,8 @@ class Interface(Blocks):
elif inputs is None or inputs == []: elif inputs is None or inputs == []:
inputs = [] inputs = []
self.interface_type = InterfaceTypes.OUTPUT_ONLY self.interface_type = InterfaceTypes.OUTPUT_ONLY
if additional_inputs is None:
additional_inputs = []
assert isinstance(inputs, (str, list, Component)) assert isinstance(inputs, (str, list, Component))
assert isinstance(outputs, (str, list, Component)) assert isinstance(outputs, (str, list, Component))
@ -169,6 +181,8 @@ class Interface(Blocks):
inputs = [inputs] inputs = [inputs]
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
if not isinstance(additional_inputs, list):
additional_inputs = [additional_inputs]
if self.space_id and cache_examples is None: if self.space_id and cache_examples is None:
self.cache_examples = True self.cache_examples = True
@ -207,10 +221,36 @@ class Interface(Blocks):
) )
self.cache_examples = False self.cache_examples = False
self.input_components = [ self.main_input_components = [
get_component_instance(i, unrender=True) get_component_instance(i, unrender=True)
for i in inputs # type: ignore for i in inputs # type: ignore
] ]
self.additional_input_components = [
get_component_instance(i, unrender=True)
for i in additional_inputs # type: ignore
]
if additional_inputs_accordion is None:
self.additional_inputs_accordion_params = {
"label": "Additional Inputs",
"open": False,
}
elif isinstance(additional_inputs_accordion, str):
self.additional_inputs_accordion_params = {
"label": additional_inputs_accordion
}
elif isinstance(additional_inputs_accordion, Accordion):
self.additional_inputs_accordion_params = (
additional_inputs_accordion.recover_kwargs(
additional_inputs_accordion.get_config()
)
)
else:
raise ValueError(
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
)
self.input_components = (
self.main_input_components + self.additional_input_components
)
self.output_components = [ self.output_components = [
get_component_instance(o, unrender=True) get_component_instance(o, unrender=True)
for o in outputs # type: ignore for o in outputs # type: ignore
@ -442,7 +482,11 @@ class Interface(Blocks):
with Column(variant="panel"): with Column(variant="panel"):
input_component_column = Column() input_component_column = Column()
with input_component_column: with input_component_column:
for component in self.input_components: for component in self.main_input_components:
component.render()
if self.additional_input_components:
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
for component in self.additional_input_components:
component.render() component.render()
with Row(): with Row():
if self.interface_type in [ if self.interface_type in [

View File

@ -83,6 +83,18 @@ Another useful keyword argument is `label=`, which is present in every `Componen
gr.Number(label='Age', info='In years, must be greater than 0') gr.Number(label='Age', info='In years, must be greater than 0')
``` ```
## Additional Inputs within an Accordion
If your prediction function takes many inputs, you may want to hide some of them within a collapsed accordion to avoid cluttering the UI. The `Interface` class takes an `additional_inputs` argument which is similar to `inputs` but any input components included here are not visible by default. The user must click on the accordion to show these components. The additional inputs are passed into the prediction function, in order, after the standard inputs.
You can customize the appearance of the accordion by using the optional `additional_inputs_accordion` argument, which accepts a string (in which case, it becomes the label of the accordion), or an instance of the `gr.Accordion()` class (e.g. this lets you control whether the accordion is open or closed by default).
Here's an example:
$code_interface_with_additional_inputs
$demo_interface_with_additional_inputs
## Flagging ## Flagging
By default, an `Interface` will have "Flag" button. When a user testing your `Interface` sees input with interesting output, such as erroneous or unexpected model behaviour, they can flag the input for you to review. Within the directory provided by the `flagging_dir=` argument to the `Interface` constructor, a CSV file will log the flagged inputs. If the interface involves file data, such as for Image and Audio components, folders will be created to store those flagged data as well. By default, an `Interface` will have "Flag" button. When a user testing your `Interface` sees input with interesting output, such as erroneous or unexpected model behaviour, they can flag the input for you to review. Within the directory provided by the `flagging_dir=` argument to the `Interface` constructor, a CSV file will log the flagged inputs. If the interface involves file data, such as for Image and Audio components, folders will be created to store those flagged data as well.

View File

@ -177,6 +177,12 @@ class TestInterface:
Interface(fn=str, inputs=t, outputs=Textbox()) Interface(fn=str, inputs=t, outputs=Textbox())
assert t.label == "input 0" assert t.label == "input 0"
def test_interface_additional_components_are_included_as_inputs(self):
t = Textbox()
s = gradio.Slider(0, 100)
io = Interface(fn=str, inputs=t, outputs=Textbox(), additional_inputs=s)
assert io.input_components == [t, s]
class TestTabbedInterface: class TestTabbedInterface:
def test_tabbed_interface_config_matches_manual_tab(self): def test_tabbed_interface_config_matches_manual_tab(self):