diff --git a/gradio/data_classes.py b/gradio/data_classes.py
index b9d4bb9322..49ee748441 100644
--- a/gradio/data_classes.py
+++ b/gradio/data_classes.py
@@ -1,3 +1,4 @@
+from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
@@ -18,3 +19,10 @@ class PredictBody(BaseModel):
class ResetBody(BaseModel):
session_hash: str
fn_index: int
+
+
+class InterfaceTypes(Enum):
+ STANDARD = auto()
+ INPUT_ONLY = auto()
+ OUTPUT_ONLY = auto()
+ UNIFIED = auto()
diff --git a/gradio/events.py b/gradio/events.py
index 4595e62ada..e2adcfd94d 100644
--- a/gradio/events.py
+++ b/gradio/events.py
@@ -34,7 +34,7 @@ def set_cancel_events(
class Changeable(Block):
def change(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -97,7 +97,7 @@ class Changeable(Block):
class Clickable(Block):
def click(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -161,7 +161,7 @@ class Clickable(Block):
class Submittable(Block):
def submit(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -226,7 +226,7 @@ class Submittable(Block):
class Editable(Block):
def edit(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -290,7 +290,7 @@ class Editable(Block):
class Clearable(Block):
def clear(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -354,7 +354,7 @@ class Clearable(Block):
class Playable(Block):
def play(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -416,7 +416,7 @@ class Playable(Block):
def pause(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -478,7 +478,7 @@ class Playable(Block):
def stop(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -542,7 +542,7 @@ class Playable(Block):
class Streamable(Block):
def stream(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -608,7 +608,7 @@ class Streamable(Block):
class Blurrable(Block):
def blur(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@@ -665,7 +665,7 @@ class Blurrable(Block):
class Uploadable(Block):
def upload(
self,
- fn: Callable,
+ fn: Callable | None,
inputs: List[Component],
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
diff --git a/gradio/flagging.py b/gradio/flagging.py
index 2eb00cacbc..381d6b97d5 100644
--- a/gradio/flagging.py
+++ b/gradio/flagging.py
@@ -551,10 +551,10 @@ class FlagMethod:
Helper class that contains the flagging button option and callback
"""
- def __init__(self, flagging_callback, flag_option=None):
+ def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
self.__name__ = "Flag"
def __call__(self, *flag_data):
- self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
+ self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option)
diff --git a/gradio/interface.py b/gradio/interface.py
index eac55e064e..8f93b1a557 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -12,34 +12,33 @@ import pkgutil
import re
import warnings
import weakref
-from enum import Enum, auto
-from typing import TYPE_CHECKING, Any, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from markdown_it import MarkdownIt
-from mdit_py_plugins.dollarmath import dollarmath_plugin
-from mdit_py_plugins.footnote import footnote_plugin
+from mdit_py_plugins.dollarmath.index import dollarmath_plugin
+from mdit_py_plugins.footnote.index import footnote_plugin
from gradio import Examples, interpretation, utils
from gradio.blocks import Blocks
from gradio.components import (
Button,
- Component,
Interpretation,
IOComponent,
Markdown,
State,
get_component_instance,
)
+from gradio.data_classes import InterfaceTypes
from gradio.documentation import document, set_documentation_group
from gradio.events import Changeable, Streamable
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
-from gradio.layouts import Column, Row, TabItem, Tabs
+from gradio.layouts import Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline
set_documentation_group("interface")
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
- import transformers
+ from transformers.pipelines.base import Pipeline
@document("launch", "load", "from_pipeline", "integrate", "queue")
@@ -66,12 +65,6 @@ class Interface(Blocks):
# stores references to all currently existing Interface instances
instances: weakref.WeakSet = weakref.WeakSet()
- class InterfaceTypes(Enum):
- STANDARD = auto()
- INPUT_ONLY = auto()
- OUTPUT_ONLY = auto()
- UNIFIED = auto()
-
@classmethod
def get_instances(cls) -> List[Interface]:
"""
@@ -83,9 +76,9 @@ class Interface(Blocks):
def load(
cls,
name: str,
- src: Optional[str] = None,
- api_key: Optional[str] = None,
- alias: Optional[str] = None,
+ src: str | None = None,
+ api_key: str | None = None,
+ alias: str | None = None,
**kwargs,
) -> Interface:
"""
@@ -109,7 +102,7 @@ class Interface(Blocks):
return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)
@classmethod
- def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface:
+ def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface:
"""
Class method that constructs an Interface from a Hugging Face transformers.Pipeline object.
The input and output components are automatically determined from the pipeline.
@@ -131,25 +124,25 @@ class Interface(Blocks):
def __init__(
self,
fn: Callable,
- inputs: Optional[str | Component | List[str | Component]],
- outputs: Optional[str | Component | List[str | Component]],
- examples: Optional[List[Any] | List[List[Any]] | str] = None,
- cache_examples: Optional[bool] = None,
+ inputs: str | IOComponent | List[str | IOComponent] | None,
+ outputs: str | IOComponent | List[str | IOComponent] | None,
+ examples: List[Any] | List[List[Any]] | str | None = None,
+ cache_examples: bool | None = None,
examples_per_page: int = 10,
live: bool = False,
- interpretation: Optional[Callable | str] = None,
+ interpretation: Callable | str | None = None,
num_shap: float = 2.0,
- title: Optional[str] = None,
- description: Optional[str] = None,
- article: Optional[str] = None,
- thumbnail: Optional[str] = None,
- theme: Optional[str] = None,
- css: Optional[str] = None,
- allow_flagging: Optional[str] = None,
- flagging_options: List[str] = None,
+ title: str | None = None,
+ description: str | None = None,
+ article: str | None = None,
+ thumbnail: str | None = None,
+ theme: str = "default",
+ css: str | None = None,
+ allow_flagging: str | None = None,
+ flagging_options: List[str] | None = None,
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
- analytics_enabled: Optional[bool] = None,
+ analytics_enabled: bool | None = None,
batch: bool = False,
max_batch_size: int = 4,
_api_mode: bool = False,
@@ -184,21 +177,11 @@ class Interface(Blocks):
analytics_enabled=analytics_enabled,
mode="interface",
css=css,
- title=title,
+ title=title or "Gradio",
theme=theme,
**kwargs,
)
- self.interface_type = self.InterfaceTypes.STANDARD
- if (inputs is None or inputs == []) and (outputs is None or outputs == []):
- raise ValueError("Must provide at least one of `inputs` or `outputs`")
- elif outputs is None or outputs == []:
- outputs = []
- self.interface_type = self.InterfaceTypes.INPUT_ONLY
- elif inputs is None or inputs == []:
- inputs = []
- self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
-
if isinstance(fn, list):
raise DeprecationWarning(
"The `fn` parameter only accepts a single function, support for a list "
@@ -206,6 +189,19 @@ class Interface(Blocks):
"instead."
)
+ self.interface_type = InterfaceTypes.STANDARD
+ if (inputs is None or inputs == []) and (outputs is None or outputs == []):
+ raise ValueError("Must provide at least one of `inputs` or `outputs`")
+ elif outputs is None or outputs == []:
+ outputs = []
+ self.interface_type = InterfaceTypes.INPUT_ONLY
+ elif inputs is None or inputs == []:
+ inputs = []
+ self.interface_type = InterfaceTypes.OUTPUT_ONLY
+
+ assert isinstance(inputs, (str, list, IOComponent))
+ assert isinstance(outputs, (str, list, IOComponent))
+
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
@@ -234,7 +230,7 @@ class Interface(Blocks):
state_output_index = state_output_indexes[0]
if inputs[state_input_index] == "state":
default = utils.get_default_args(fn)[state_input_index]
- state_variable = State(value=default)
+ state_variable = State(value=default) # type: ignore
else:
state_variable = inputs[state_input_index]
@@ -266,13 +262,14 @@ class Interface(Blocks):
i is o for i, o in zip(self.input_components, self.output_components)
]
if all(same_components):
- self.interface_type = self.InterfaceTypes.UNIFIED
+ self.interface_type = InterfaceTypes.UNIFIED
if self.interface_type in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.OUTPUT_ONLY,
+ InterfaceTypes.STANDARD,
+ InterfaceTypes.OUTPUT_ONLY,
]:
for o in self.output_components:
+ assert isinstance(o, IOComponent)
o.interactive = False # Force output components to be non-interactive
if (
@@ -375,6 +372,8 @@ class Interface(Blocks):
self.flagging_options = flagging_options
self.flagging_callback = flagging_callback
self.flagging_dir = flagging_dir
+ self.batch = batch
+ self.max_batch_size = max_batch_size
self.save_to = None # Used for selenium tests
self.share = None
@@ -395,7 +394,7 @@ class Interface(Blocks):
"allow_flagging": allow_flagging,
"custom_css": self.css is not None,
"theme": self.theme,
- "version": pkgutil.get_data(__name__, "version.txt")
+ "version": (pkgutil.get_data(__name__, "version.txt") or b"")
.decode("ascii")
.strip(),
}
@@ -406,9 +405,11 @@ class Interface(Blocks):
param_names = inspect.getfullargspec(self.fn)[0]
for component, param_name in zip(self.input_components, param_names):
+ assert isinstance(component, IOComponent)
if component.label is None:
component.label = param_name
for i, component in enumerate(self.output_components):
+ assert isinstance(component, IOComponent)
if component.label is None:
if len(self.output_components) == 1:
component.label = "output"
@@ -417,257 +418,342 @@ class Interface(Blocks):
if self.allow_flagging != "never":
if (
- self.interface_type == self.InterfaceTypes.UNIFIED
+ self.interface_type == InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
- self.flagging_callback.setup(self.input_components, self.flagging_dir)
- elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
+ self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
+ elif self.interface_type == InterfaceTypes.INPUT_ONLY:
pass
else:
self.flagging_callback.setup(
- self.input_components + self.output_components, self.flagging_dir
+ self.input_components + self.output_components, self.flagging_dir # type: ignore
)
+ # Render the Gradio UI
with self:
- if self.title:
- Markdown(
- "
"
- + self.title
- + "
"
- )
- if self.description:
- Markdown(self.description)
+ self.render_title_description()
- def render_flag_btns(flagging_options):
- if flagging_options is None:
- return [(Button("Flag"), None)]
- else:
- return [
- (
- Button("Flag as " + flag_option),
- flag_option,
- )
- for flag_option in flagging_options
- ]
+ submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
+ interpretation_btn, interpretation_set = None, None
+ input_component_column, interpret_component_column = None, None
with Row().style(equal_height=False):
if self.interface_type in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.INPUT_ONLY,
- self.InterfaceTypes.UNIFIED,
+ InterfaceTypes.STANDARD,
+ InterfaceTypes.INPUT_ONLY,
+ InterfaceTypes.UNIFIED,
]:
- with Column(variant="panel"):
- input_component_column = Column()
- with input_component_column:
- for component in self.input_components:
- component.render()
- if self.interpretation:
- interpret_component_column = Column(visible=False)
- interpretation_set = []
- with interpret_component_column:
- for component in self.input_components:
- interpretation_set.append(Interpretation(component))
- with Row():
- if self.interface_type in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.INPUT_ONLY,
- ]:
- clear_btn = Button("Clear")
- if not self.live:
- submit_btn = Button("Submit", variant="primary")
- # Stopping jobs only works if the queue is enabled
- # We don't know if the queue is enabled when the interface
- # is created. We use whether a generator function is provided
- # as a proxy of whether the queue will be enabled.
- # Using a generator function without the queue will raise an error.
- if inspect.isgeneratorfunction(fn):
- stop_btn = Button("Stop", variant="stop")
-
- elif self.interface_type == self.InterfaceTypes.UNIFIED:
- clear_btn = Button("Clear")
- submit_btn = Button("Submit", variant="primary")
- if inspect.isgeneratorfunction(fn) and not self.live:
- stop_btn = Button("Stop", variant="stop")
- if self.allow_flagging == "manual":
- flag_btns = render_flag_btns(self.flagging_options)
-
+ (
+ submit_btn,
+ clear_btn,
+ stop_btn,
+ flag_btns,
+ input_component_column,
+ interpret_component_column,
+ interpretation_set,
+ ) = self.render_input_column()
if self.interface_type in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.OUTPUT_ONLY,
+ InterfaceTypes.STANDARD,
+ InterfaceTypes.OUTPUT_ONLY,
]:
+ (
+ submit_btn_out,
+ clear_btn_2_out,
+ stop_btn_2_out,
+ flag_btns_out,
+ interpretation_btn,
+ ) = self.render_output_column(submit_btn)
+ submit_btn = submit_btn or submit_btn_out
+ clear_btn = clear_btn or clear_btn_2_out
+ stop_btn = stop_btn or stop_btn_2_out
+ flag_btns = flag_btns or flag_btns_out
- with Column(variant="panel"):
- for component in self.output_components:
- if not (isinstance(component, State)):
- component.render()
- with Row():
- if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
- clear_btn = Button("Clear")
- submit_btn = Button("Generate", variant="primary")
- if inspect.isgeneratorfunction(fn) and not self.live:
- # Stopping jobs only works if the queue is enabled
- # We don't know if the queue is enabled when the interface
- # is created. We use whether a generator function is provided
- # as a proxy of whether the queue will be enabled.
- # Using a generator function without the queue will raise an error.
- stop_btn = Button("Stop", variant="stop")
- if self.allow_flagging == "manual":
- flag_btns = render_flag_btns(self.flagging_options)
- if self.interpretation:
- interpretation_btn = Button("Interpret")
- if self.live:
- if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
- super().load(self.fn, None, self.output_components)
- submit_btn.click(
- self.fn,
- None,
- self.output_components,
- api_name="predict",
- preprocess=not (self.api_mode),
- postprocess=not (self.api_mode),
- batch=batch,
- max_batch_size=max_batch_size,
- )
- else:
- for component in self.input_components:
- if isinstance(component, Streamable):
- if component.streaming:
- component.stream(
- self.fn,
- self.input_components,
- self.output_components,
- api_name="predict",
- preprocess=not (self.api_mode),
- postprocess=not (self.api_mode),
- )
- continue
- else:
- print(
- "Hint: Set streaming=True for "
- + component.__class__.__name__
- + " component to use live streaming."
- )
- if isinstance(component, Changeable):
- component.change(
- self.fn,
- self.input_components,
- self.output_components,
- api_name="predict",
- preprocess=not (self.api_mode),
- postprocess=not (self.api_mode),
- )
- else:
- pred = submit_btn.click(
- self.fn,
- self.input_components,
- self.output_components,
- api_name="predict",
- scroll_to_output=True,
- preprocess=not (self.api_mode),
- postprocess=not (self.api_mode),
- batch=batch,
- max_batch_size=max_batch_size,
- )
- if inspect.isgeneratorfunction(fn):
- stop_btn.click(
- None,
- inputs=None,
- outputs=None,
- cancels=[pred],
- )
+ assert clear_btn is not None, "Clear button not rendered"
- clear_btn.click(
- None,
- [],
- (
- self.input_components
- + self.output_components
- + (
- [input_component_column]
- if self.interface_type
- in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.INPUT_ONLY,
- self.InterfaceTypes.UNIFIED,
- ]
- else []
- )
- + ([interpret_component_column] if self.interpretation else [])
- ),
- _js=f"""() => {json.dumps(
- [component.cleared_value if hasattr(component, "cleared_value") else None
- for component in self.input_components + self.output_components] + (
- [Column.update(visible=True)]
- if self.interface_type
- in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.INPUT_ONLY,
- self.InterfaceTypes.UNIFIED,
- ]
- else []
- )
- + ([Column.update(visible=False)] if self.interpretation else [])
- )}
- """,
+ self.attach_submit_events(submit_btn, stop_btn)
+ self.attach_clear_events(
+ clear_btn, input_component_column, interpret_component_column
+ )
+ self.attach_interpretation_events(
+ interpretation_btn,
+ interpretation_set,
+ input_component_column,
+ interpret_component_column,
)
- if self.allow_flagging in ["manual", "auto"]:
- if self.interface_type in [
- self.InterfaceTypes.STANDARD,
- self.InterfaceTypes.OUTPUT_ONLY,
- self.InterfaceTypes.UNIFIED,
- ]:
- if self.allow_flagging == "auto":
- flag_btns = [(submit_btn, None)]
- if (
- self.interface_type == self.InterfaceTypes.UNIFIED
- or self.allow_flagging == "auto"
- ):
- flag_components = self.input_components
- else:
- flag_components = self.input_components + self.output_components
- for flag_btn, flag_option in flag_btns:
- flag_method = FlagMethod(self.flagging_callback, flag_option)
- flag_btn.click(
- flag_method,
- inputs=flag_components,
- outputs=[],
- preprocess=False,
- queue=False,
- )
-
- if self.examples:
- non_state_inputs = [
- c for c in self.input_components if not isinstance(c, State)
- ]
- non_state_outputs = [
- c for c in self.output_components if not isinstance(c, State)
- ]
- self.examples_handler = Examples(
- examples=examples,
- inputs=non_state_inputs,
- outputs=non_state_outputs,
- fn=self.fn,
- cache_examples=self.cache_examples,
- examples_per_page=examples_per_page,
- _api_mode=_api_mode,
- batch=batch,
- )
-
- if self.interpretation:
- interpretation_btn.click(
- self.interpret_func,
- inputs=self.input_components + self.output_components,
- outputs=interpretation_set
- + [input_component_column, interpret_component_column],
- preprocess=False,
- )
-
- if self.article:
- Markdown(self.article)
+ self.render_flagging_buttons(flag_btns)
+ self.render_examples()
+ self.render_article()
self.config = self.get_config_file()
+ def render_title_description(self) -> None:
+ if self.title:
+ Markdown(
+ ""
+ + self.title
+ + "
"
+ )
+ if self.description:
+ Markdown(self.description)
+
+ def render_flag_btns(self) -> List[Tuple[Button, str | None]]:
+ if self.flagging_options is None:
+ return [(Button("Flag"), None)]
+ else:
+ return [
+ (
+ Button("Flag as " + flag_option),
+ flag_option,
+ )
+ for flag_option in self.flagging_options
+ ]
+
+ def render_input_column(
+ self,
+ ) -> Tuple[
+ Button | None,
+ Button | None,
+ Button | None,
+ List | None,
+ Column,
+ Column | None,
+ List[Interpretation] | None,
+ ]:
+ submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
+ interpret_component_column, interpretation_set = None, None
+
+ with Column(variant="panel"):
+ input_component_column = Column()
+ with input_component_column:
+ for component in self.input_components:
+ component.render()
+ if self.interpretation:
+ interpret_component_column = Column(visible=False)
+ interpretation_set = []
+ with interpret_component_column:
+ for component in self.input_components:
+ interpretation_set.append(Interpretation(component))
+ with Row():
+ if self.interface_type in [
+ InterfaceTypes.STANDARD,
+ InterfaceTypes.INPUT_ONLY,
+ ]:
+ clear_btn = Button("Clear")
+ if not self.live:
+ submit_btn = Button("Submit", variant="primary")
+ # Stopping jobs only works if the queue is enabled
+ # We don't know if the queue is enabled when the interface
+ # is created. We use whether a generator function is provided
+ # as a proxy of whether the queue will be enabled.
+ # Using a generator function without the queue will raise an error.
+ if inspect.isgeneratorfunction(self.fn):
+ stop_btn = Button("Stop", variant="stop")
+ elif self.interface_type == InterfaceTypes.UNIFIED:
+ clear_btn = Button("Clear")
+ submit_btn = Button("Submit", variant="primary")
+ if inspect.isgeneratorfunction(self.fn) and not self.live:
+ stop_btn = Button("Stop", variant="stop")
+ if self.allow_flagging == "manual":
+ flag_btns = self.render_flag_btns()
+ elif self.allow_flagging == "auto":
+ flag_btns = [(submit_btn, None)]
+ return (
+ submit_btn,
+ clear_btn,
+ stop_btn,
+ flag_btns,
+ input_component_column,
+ interpret_component_column,
+ interpretation_set,
+ )
+
+ def render_output_column(
+ self,
+ submit_btn_in: Button | None,
+ ) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
+ submit_btn = submit_btn_in
+ interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
+
+ with Column(variant="panel"):
+ for component in self.output_components:
+ if not (isinstance(component, State)):
+ component.render()
+ with Row():
+ if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
+ clear_btn = Button("Clear")
+ submit_btn = Button("Generate", variant="primary")
+ if inspect.isgeneratorfunction(self.fn) and not self.live:
+ # Stopping jobs only works if the queue is enabled
+ # We don't know if the queue is enabled when the interface
+ # is created. We use whether a generator function is provided
+ # as a proxy of whether the queue will be enabled.
+ # Using a generator function without the queue will raise an error.
+ stop_btn = Button("Stop", variant="stop")
+ if self.allow_flagging == "manual":
+ flag_btns = self.render_flag_btns()
+ elif self.allow_flagging == "auto":
+ assert submit_btn is not None, "Submit button not rendered"
+ flag_btns = [(submit_btn, None)]
+ if self.interpretation:
+ interpretation_btn = Button("Interpret")
+
+ return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn
+
+ def render_article(self):
+ if self.article:
+ Markdown(self.article)
+
+ def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
+ if self.live:
+ if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
+ assert submit_btn is not None, "Submit button not rendered"
+ super().load(self.fn, None, self.output_components)
+ # For output-only interfaces, the user probably still want a "generate"
+ # button even if the Interface is live
+ submit_btn.click(
+ self.fn,
+ None,
+ self.output_components,
+ api_name="predict",
+ preprocess=not (self.api_mode),
+ postprocess=not (self.api_mode),
+ batch=self.batch,
+ max_batch_size=self.max_batch_size,
+ )
+ else:
+ for component in self.input_components:
+ if isinstance(component, Streamable) and component.streaming:
+ component.stream(
+ self.fn,
+ self.input_components,
+ self.output_components,
+ api_name="predict",
+ preprocess=not (self.api_mode),
+ postprocess=not (self.api_mode),
+ )
+ continue
+ if isinstance(component, Changeable):
+ component.change(
+ self.fn,
+ self.input_components,
+ self.output_components,
+ api_name="predict",
+ preprocess=not (self.api_mode),
+ postprocess=not (self.api_mode),
+ )
+ else:
+ assert submit_btn is not None, "Submit button not rendered"
+ pred = submit_btn.click(
+ self.fn,
+ self.input_components,
+ self.output_components,
+ api_name="predict",
+ scroll_to_output=True,
+ preprocess=not (self.api_mode),
+ postprocess=not (self.api_mode),
+ batch=self.batch,
+ max_batch_size=self.max_batch_size,
+ )
+ if stop_btn:
+ stop_btn.click(
+ None,
+ inputs=None,
+ outputs=None,
+ cancels=[pred],
+ )
+
+ def attach_clear_events(
+ self,
+ clear_btn: Button,
+ input_component_column: Column | None,
+ interpret_component_column: Column | None,
+ ):
+ clear_btn.click(
+ None,
+ [],
+ (
+ self.input_components
+ + self.output_components
+ + ([input_component_column] if input_component_column else [])
+ + ([interpret_component_column] if self.interpretation else [])
+ ), # type: ignore
+ _js=f"""() => {json.dumps(
+ [getattr(component, "cleared_value", None)
+ for component in self.input_components + self.output_components] + (
+ [Column.update(visible=True)]
+ if self.interface_type
+ in [
+ InterfaceTypes.STANDARD,
+ InterfaceTypes.INPUT_ONLY,
+ InterfaceTypes.UNIFIED,
+ ]
+ else []
+ )
+ + ([Column.update(visible=False)] if self.interpretation else [])
+ )}
+ """,
+ )
+
+ def attach_interpretation_events(
+ self,
+ interpretation_btn: Button | None,
+ interpretation_set: List[Interpretation] | None,
+ input_component_column: Column | None,
+ interpret_component_column: Column | None,
+ ):
+ if interpretation_btn:
+ interpretation_btn.click(
+ self.interpret_func,
+ inputs=self.input_components + self.output_components,
+ outputs=interpretation_set
+ or [] + [input_component_column, interpret_component_column], # type: ignore
+ preprocess=False,
+ )
+
+ def render_flagging_buttons(self, flag_btns: List | None):
+ if flag_btns:
+ if self.interface_type in [
+ InterfaceTypes.STANDARD,
+ InterfaceTypes.OUTPUT_ONLY,
+ InterfaceTypes.UNIFIED,
+ ]:
+ if (
+ self.interface_type == InterfaceTypes.UNIFIED
+ or self.allow_flagging == "auto"
+ ):
+ flag_components = self.input_components
+ else:
+ flag_components = self.input_components + self.output_components
+ for flag_btn, flag_option in flag_btns:
+ flag_method = FlagMethod(self.flagging_callback, flag_option)
+ flag_btn.click(
+ flag_method,
+ inputs=flag_components,
+ outputs=[],
+ preprocess=False,
+ queue=False,
+ )
+
+ def render_examples(self):
+ if self.examples:
+ non_state_inputs = [
+ c for c in self.input_components if not isinstance(c, State)
+ ]
+ non_state_outputs = [
+ c for c in self.output_components if not isinstance(c, State)
+ ]
+ self.examples_handler = Examples(
+ examples=self.examples,
+ inputs=non_state_inputs, # type: ignore
+ outputs=non_state_outputs, # type: ignore
+ fn=self.fn,
+ cache_examples=self.cache_examples,
+ examples_per_page=self.examples_per_page,
+ _api_mode=self.api_mode,
+ batch=self.batch,
+ )
+
def __str__(self):
return self.__repr__()
@@ -683,7 +769,7 @@ class Interface(Blocks):
return repr
async def interpret_func(self, *args):
- return await self.interpret(args) + [
+ return await self.interpret(list(args)) + [
Column.update(visible=False),
Column.update(visible=True),
]
@@ -698,20 +784,9 @@ class Interface(Blocks):
def test_launch(self) -> None:
"""
- Passes a few samples through the function to test if the inputs/outputs
- components are consistent with the function parameter and return values.
+ Deprecated.
"""
- print("Test launch: {}()...".format(self.__name__), end=" ")
- raw_input = []
- for input_component in self.input_components:
- if input_component.test_input is None:
- print("SKIPPED")
- break
- else:
- raw_input.append(input_component.test_input)
- else:
- self(raw_input)
- print("PASSED")
+ warnings.warn("The Interface.test_launch() function is deprecated.")
@document()
@@ -725,11 +800,11 @@ class TabbedInterface(Blocks):
def __init__(
self,
interface_list: List[Interface],
- tab_names: Optional[List[str]] = None,
- title: Optional[str] = None,
+ tab_names: List[str] | None = None,
+ title: str | None = None,
theme: str = "default",
- analytics_enabled: Optional[bool] = None,
- css: Optional[str] = None,
+ analytics_enabled: bool | None = None,
+ css: str | None = None,
):
"""
Parameters:
@@ -743,7 +818,7 @@ class TabbedInterface(Blocks):
a Gradio Tabbed Interface for the given interfaces
"""
super().__init__(
- title=title,
+ title=title or "Gradio",
theme=theme,
analytics_enabled=analytics_enabled,
mode="tabbed_interface",
@@ -760,7 +835,7 @@ class TabbedInterface(Blocks):
)
with Tabs():
for (interface, tab_name) in zip(interface_list, tab_names):
- with TabItem(label=tab_name):
+ with Tab(label=tab_name):
interface.render()
diff --git a/scripts/type_check_backend.sh b/scripts/type_check_backend.sh
index 5422d5d0eb..62fa6507a1 100644
--- a/scripts/type_check_backend.sh
+++ b/scripts/type_check_backend.sh
@@ -6,4 +6,4 @@ pip_required
pip install --upgrade pip
pip install pyright
cd gradio
-pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py
+pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py
diff --git a/test/test_interfaces.py b/test/test_interfaces.py
index 0408fe4829..9468922760 100644
--- a/test/test_interfaces.py
+++ b/test/test_interfaces.py
@@ -69,15 +69,6 @@ class TestInterface:
)
assert dataset_check
- def test_test_launch(self):
- with captured_output() as (out, err):
- prediction_fn = lambda x: x
- prediction_fn.__name__ = "prediction_fn"
- interface = Interface(prediction_fn, "textbox", "label")
- interface.test_launch()
- output = out.getvalue().strip()
- assert output == "Test launch: prediction_fn()... PASSED"
-
@mock.patch("time.sleep")
def test_block_thread(self, mock_sleep):
with pytest.raises(KeyboardInterrupt):