diff --git a/gradio/data_classes.py b/gradio/data_classes.py index b9d4bb9322..49ee748441 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -1,3 +1,4 @@ +from enum import Enum, auto from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel @@ -18,3 +19,10 @@ class PredictBody(BaseModel): class ResetBody(BaseModel): session_hash: str fn_index: int + + +class InterfaceTypes(Enum): + STANDARD = auto() + INPUT_ONLY = auto() + OUTPUT_ONLY = auto() + UNIFIED = auto() diff --git a/gradio/events.py b/gradio/events.py index 4595e62ada..e2adcfd94d 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -34,7 +34,7 @@ def set_cancel_events( class Changeable(Block): def change( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -97,7 +97,7 @@ class Changeable(Block): class Clickable(Block): def click( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -161,7 +161,7 @@ class Clickable(Block): class Submittable(Block): def submit( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -226,7 +226,7 @@ class Submittable(Block): class Editable(Block): def edit( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -290,7 +290,7 @@ class Editable(Block): class Clearable(Block): def clear( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -354,7 +354,7 @@ class Clearable(Block): class Playable(Block): def play( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -416,7 +416,7 @@ class Playable(Block): def pause( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -478,7 +478,7 @@ class Playable(Block): def stop( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -542,7 +542,7 @@ class Playable(Block): class Streamable(Block): def stream( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -608,7 +608,7 @@ class Streamable(Block): class Blurrable(Block): def blur( self, - fn: Callable, + fn: Callable | None, inputs: Component | List[Component] | Set[Component] | None = None, outputs: Component | List[Component] | None = None, api_name: str | None = None, @@ -665,7 +665,7 @@ class Blurrable(Block): class Uploadable(Block): def upload( self, - fn: Callable, + fn: Callable | None, inputs: List[Component], outputs: Component | List[Component] | None = None, api_name: str | None = None, diff --git a/gradio/flagging.py b/gradio/flagging.py index 2eb00cacbc..381d6b97d5 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -551,10 +551,10 @@ class FlagMethod: Helper class that contains the flagging button option and callback """ - def __init__(self, flagging_callback, flag_option=None): + def __init__(self, flagging_callback: FlaggingCallback, flag_option=None): self.flagging_callback = flagging_callback self.flag_option = flag_option self.__name__ = "Flag" def __call__(self, *flag_data): - self.flagging_callback.flag(flag_data, flag_option=self.flag_option) + self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option) diff --git a/gradio/interface.py b/gradio/interface.py index eac55e064e..8f93b1a557 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -12,34 +12,33 @@ import pkgutil import re import warnings import weakref -from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, List, Tuple from markdown_it import MarkdownIt -from mdit_py_plugins.dollarmath import dollarmath_plugin -from mdit_py_plugins.footnote import footnote_plugin +from mdit_py_plugins.dollarmath.index import dollarmath_plugin +from mdit_py_plugins.footnote.index import footnote_plugin from gradio import Examples, interpretation, utils from gradio.blocks import Blocks from gradio.components import ( Button, - Component, Interpretation, IOComponent, Markdown, State, get_component_instance, ) +from gradio.data_classes import InterfaceTypes from gradio.documentation import document, set_documentation_group from gradio.events import Changeable, Streamable from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod -from gradio.layouts import Column, Row, TabItem, Tabs +from gradio.layouts import Column, Row, Tab, Tabs from gradio.pipelines import load_from_pipeline set_documentation_group("interface") if TYPE_CHECKING: # Only import for type checking (is False at runtime). - import transformers + from transformers.pipelines.base import Pipeline @document("launch", "load", "from_pipeline", "integrate", "queue") @@ -66,12 +65,6 @@ class Interface(Blocks): # stores references to all currently existing Interface instances instances: weakref.WeakSet = weakref.WeakSet() - class InterfaceTypes(Enum): - STANDARD = auto() - INPUT_ONLY = auto() - OUTPUT_ONLY = auto() - UNIFIED = auto() - @classmethod def get_instances(cls) -> List[Interface]: """ @@ -83,9 +76,9 @@ class Interface(Blocks): def load( cls, name: str, - src: Optional[str] = None, - api_key: Optional[str] = None, - alias: Optional[str] = None, + src: str | None = None, + api_key: str | None = None, + alias: str | None = None, **kwargs, ) -> Interface: """ @@ -109,7 +102,7 @@ class Interface(Blocks): return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs) @classmethod - def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface: + def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface: """ Class method that constructs an Interface from a Hugging Face transformers.Pipeline object. The input and output components are automatically determined from the pipeline. @@ -131,25 +124,25 @@ class Interface(Blocks): def __init__( self, fn: Callable, - inputs: Optional[str | Component | List[str | Component]], - outputs: Optional[str | Component | List[str | Component]], - examples: Optional[List[Any] | List[List[Any]] | str] = None, - cache_examples: Optional[bool] = None, + inputs: str | IOComponent | List[str | IOComponent] | None, + outputs: str | IOComponent | List[str | IOComponent] | None, + examples: List[Any] | List[List[Any]] | str | None = None, + cache_examples: bool | None = None, examples_per_page: int = 10, live: bool = False, - interpretation: Optional[Callable | str] = None, + interpretation: Callable | str | None = None, num_shap: float = 2.0, - title: Optional[str] = None, - description: Optional[str] = None, - article: Optional[str] = None, - thumbnail: Optional[str] = None, - theme: Optional[str] = None, - css: Optional[str] = None, - allow_flagging: Optional[str] = None, - flagging_options: List[str] = None, + title: str | None = None, + description: str | None = None, + article: str | None = None, + thumbnail: str | None = None, + theme: str = "default", + css: str | None = None, + allow_flagging: str | None = None, + flagging_options: List[str] | None = None, flagging_dir: str = "flagged", flagging_callback: FlaggingCallback = CSVLogger(), - analytics_enabled: Optional[bool] = None, + analytics_enabled: bool | None = None, batch: bool = False, max_batch_size: int = 4, _api_mode: bool = False, @@ -184,21 +177,11 @@ class Interface(Blocks): analytics_enabled=analytics_enabled, mode="interface", css=css, - title=title, + title=title or "Gradio", theme=theme, **kwargs, ) - self.interface_type = self.InterfaceTypes.STANDARD - if (inputs is None or inputs == []) and (outputs is None or outputs == []): - raise ValueError("Must provide at least one of `inputs` or `outputs`") - elif outputs is None or outputs == []: - outputs = [] - self.interface_type = self.InterfaceTypes.INPUT_ONLY - elif inputs is None or inputs == []: - inputs = [] - self.interface_type = self.InterfaceTypes.OUTPUT_ONLY - if isinstance(fn, list): raise DeprecationWarning( "The `fn` parameter only accepts a single function, support for a list " @@ -206,6 +189,19 @@ class Interface(Blocks): "instead." ) + self.interface_type = InterfaceTypes.STANDARD + if (inputs is None or inputs == []) and (outputs is None or outputs == []): + raise ValueError("Must provide at least one of `inputs` or `outputs`") + elif outputs is None or outputs == []: + outputs = [] + self.interface_type = InterfaceTypes.INPUT_ONLY + elif inputs is None or inputs == []: + inputs = [] + self.interface_type = InterfaceTypes.OUTPUT_ONLY + + assert isinstance(inputs, (str, list, IOComponent)) + assert isinstance(outputs, (str, list, IOComponent)) + if not isinstance(inputs, list): inputs = [inputs] if not isinstance(outputs, list): @@ -234,7 +230,7 @@ class Interface(Blocks): state_output_index = state_output_indexes[0] if inputs[state_input_index] == "state": default = utils.get_default_args(fn)[state_input_index] - state_variable = State(value=default) + state_variable = State(value=default) # type: ignore else: state_variable = inputs[state_input_index] @@ -266,13 +262,14 @@ class Interface(Blocks): i is o for i, o in zip(self.input_components, self.output_components) ] if all(same_components): - self.interface_type = self.InterfaceTypes.UNIFIED + self.interface_type = InterfaceTypes.UNIFIED if self.interface_type in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.OUTPUT_ONLY, + InterfaceTypes.STANDARD, + InterfaceTypes.OUTPUT_ONLY, ]: for o in self.output_components: + assert isinstance(o, IOComponent) o.interactive = False # Force output components to be non-interactive if ( @@ -375,6 +372,8 @@ class Interface(Blocks): self.flagging_options = flagging_options self.flagging_callback = flagging_callback self.flagging_dir = flagging_dir + self.batch = batch + self.max_batch_size = max_batch_size self.save_to = None # Used for selenium tests self.share = None @@ -395,7 +394,7 @@ class Interface(Blocks): "allow_flagging": allow_flagging, "custom_css": self.css is not None, "theme": self.theme, - "version": pkgutil.get_data(__name__, "version.txt") + "version": (pkgutil.get_data(__name__, "version.txt") or b"") .decode("ascii") .strip(), } @@ -406,9 +405,11 @@ class Interface(Blocks): param_names = inspect.getfullargspec(self.fn)[0] for component, param_name in zip(self.input_components, param_names): + assert isinstance(component, IOComponent) if component.label is None: component.label = param_name for i, component in enumerate(self.output_components): + assert isinstance(component, IOComponent) if component.label is None: if len(self.output_components) == 1: component.label = "output" @@ -417,257 +418,342 @@ class Interface(Blocks): if self.allow_flagging != "never": if ( - self.interface_type == self.InterfaceTypes.UNIFIED + self.interface_type == InterfaceTypes.UNIFIED or self.allow_flagging == "auto" ): - self.flagging_callback.setup(self.input_components, self.flagging_dir) - elif self.interface_type == self.InterfaceTypes.INPUT_ONLY: + self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore + elif self.interface_type == InterfaceTypes.INPUT_ONLY: pass else: self.flagging_callback.setup( - self.input_components + self.output_components, self.flagging_dir + self.input_components + self.output_components, self.flagging_dir # type: ignore ) + # Render the Gradio UI with self: - if self.title: - Markdown( - "

" - + self.title - + "

" - ) - if self.description: - Markdown(self.description) + self.render_title_description() - def render_flag_btns(flagging_options): - if flagging_options is None: - return [(Button("Flag"), None)] - else: - return [ - ( - Button("Flag as " + flag_option), - flag_option, - ) - for flag_option in flagging_options - ] + submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None + interpretation_btn, interpretation_set = None, None + input_component_column, interpret_component_column = None, None with Row().style(equal_height=False): if self.interface_type in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.INPUT_ONLY, - self.InterfaceTypes.UNIFIED, + InterfaceTypes.STANDARD, + InterfaceTypes.INPUT_ONLY, + InterfaceTypes.UNIFIED, ]: - with Column(variant="panel"): - input_component_column = Column() - with input_component_column: - for component in self.input_components: - component.render() - if self.interpretation: - interpret_component_column = Column(visible=False) - interpretation_set = [] - with interpret_component_column: - for component in self.input_components: - interpretation_set.append(Interpretation(component)) - with Row(): - if self.interface_type in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.INPUT_ONLY, - ]: - clear_btn = Button("Clear") - if not self.live: - submit_btn = Button("Submit", variant="primary") - # Stopping jobs only works if the queue is enabled - # We don't know if the queue is enabled when the interface - # is created. We use whether a generator function is provided - # as a proxy of whether the queue will be enabled. - # Using a generator function without the queue will raise an error. - if inspect.isgeneratorfunction(fn): - stop_btn = Button("Stop", variant="stop") - - elif self.interface_type == self.InterfaceTypes.UNIFIED: - clear_btn = Button("Clear") - submit_btn = Button("Submit", variant="primary") - if inspect.isgeneratorfunction(fn) and not self.live: - stop_btn = Button("Stop", variant="stop") - if self.allow_flagging == "manual": - flag_btns = render_flag_btns(self.flagging_options) - + ( + submit_btn, + clear_btn, + stop_btn, + flag_btns, + input_component_column, + interpret_component_column, + interpretation_set, + ) = self.render_input_column() if self.interface_type in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.OUTPUT_ONLY, + InterfaceTypes.STANDARD, + InterfaceTypes.OUTPUT_ONLY, ]: + ( + submit_btn_out, + clear_btn_2_out, + stop_btn_2_out, + flag_btns_out, + interpretation_btn, + ) = self.render_output_column(submit_btn) + submit_btn = submit_btn or submit_btn_out + clear_btn = clear_btn or clear_btn_2_out + stop_btn = stop_btn or stop_btn_2_out + flag_btns = flag_btns or flag_btns_out - with Column(variant="panel"): - for component in self.output_components: - if not (isinstance(component, State)): - component.render() - with Row(): - if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY: - clear_btn = Button("Clear") - submit_btn = Button("Generate", variant="primary") - if inspect.isgeneratorfunction(fn) and not self.live: - # Stopping jobs only works if the queue is enabled - # We don't know if the queue is enabled when the interface - # is created. We use whether a generator function is provided - # as a proxy of whether the queue will be enabled. - # Using a generator function without the queue will raise an error. - stop_btn = Button("Stop", variant="stop") - if self.allow_flagging == "manual": - flag_btns = render_flag_btns(self.flagging_options) - if self.interpretation: - interpretation_btn = Button("Interpret") - if self.live: - if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY: - super().load(self.fn, None, self.output_components) - submit_btn.click( - self.fn, - None, - self.output_components, - api_name="predict", - preprocess=not (self.api_mode), - postprocess=not (self.api_mode), - batch=batch, - max_batch_size=max_batch_size, - ) - else: - for component in self.input_components: - if isinstance(component, Streamable): - if component.streaming: - component.stream( - self.fn, - self.input_components, - self.output_components, - api_name="predict", - preprocess=not (self.api_mode), - postprocess=not (self.api_mode), - ) - continue - else: - print( - "Hint: Set streaming=True for " - + component.__class__.__name__ - + " component to use live streaming." - ) - if isinstance(component, Changeable): - component.change( - self.fn, - self.input_components, - self.output_components, - api_name="predict", - preprocess=not (self.api_mode), - postprocess=not (self.api_mode), - ) - else: - pred = submit_btn.click( - self.fn, - self.input_components, - self.output_components, - api_name="predict", - scroll_to_output=True, - preprocess=not (self.api_mode), - postprocess=not (self.api_mode), - batch=batch, - max_batch_size=max_batch_size, - ) - if inspect.isgeneratorfunction(fn): - stop_btn.click( - None, - inputs=None, - outputs=None, - cancels=[pred], - ) + assert clear_btn is not None, "Clear button not rendered" - clear_btn.click( - None, - [], - ( - self.input_components - + self.output_components - + ( - [input_component_column] - if self.interface_type - in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.INPUT_ONLY, - self.InterfaceTypes.UNIFIED, - ] - else [] - ) - + ([interpret_component_column] if self.interpretation else []) - ), - _js=f"""() => {json.dumps( - [component.cleared_value if hasattr(component, "cleared_value") else None - for component in self.input_components + self.output_components] + ( - [Column.update(visible=True)] - if self.interface_type - in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.INPUT_ONLY, - self.InterfaceTypes.UNIFIED, - ] - else [] - ) - + ([Column.update(visible=False)] if self.interpretation else []) - )} - """, + self.attach_submit_events(submit_btn, stop_btn) + self.attach_clear_events( + clear_btn, input_component_column, interpret_component_column + ) + self.attach_interpretation_events( + interpretation_btn, + interpretation_set, + input_component_column, + interpret_component_column, ) - if self.allow_flagging in ["manual", "auto"]: - if self.interface_type in [ - self.InterfaceTypes.STANDARD, - self.InterfaceTypes.OUTPUT_ONLY, - self.InterfaceTypes.UNIFIED, - ]: - if self.allow_flagging == "auto": - flag_btns = [(submit_btn, None)] - if ( - self.interface_type == self.InterfaceTypes.UNIFIED - or self.allow_flagging == "auto" - ): - flag_components = self.input_components - else: - flag_components = self.input_components + self.output_components - for flag_btn, flag_option in flag_btns: - flag_method = FlagMethod(self.flagging_callback, flag_option) - flag_btn.click( - flag_method, - inputs=flag_components, - outputs=[], - preprocess=False, - queue=False, - ) - - if self.examples: - non_state_inputs = [ - c for c in self.input_components if not isinstance(c, State) - ] - non_state_outputs = [ - c for c in self.output_components if not isinstance(c, State) - ] - self.examples_handler = Examples( - examples=examples, - inputs=non_state_inputs, - outputs=non_state_outputs, - fn=self.fn, - cache_examples=self.cache_examples, - examples_per_page=examples_per_page, - _api_mode=_api_mode, - batch=batch, - ) - - if self.interpretation: - interpretation_btn.click( - self.interpret_func, - inputs=self.input_components + self.output_components, - outputs=interpretation_set - + [input_component_column, interpret_component_column], - preprocess=False, - ) - - if self.article: - Markdown(self.article) + self.render_flagging_buttons(flag_btns) + self.render_examples() + self.render_article() self.config = self.get_config_file() + def render_title_description(self) -> None: + if self.title: + Markdown( + "

" + + self.title + + "

" + ) + if self.description: + Markdown(self.description) + + def render_flag_btns(self) -> List[Tuple[Button, str | None]]: + if self.flagging_options is None: + return [(Button("Flag"), None)] + else: + return [ + ( + Button("Flag as " + flag_option), + flag_option, + ) + for flag_option in self.flagging_options + ] + + def render_input_column( + self, + ) -> Tuple[ + Button | None, + Button | None, + Button | None, + List | None, + Column, + Column | None, + List[Interpretation] | None, + ]: + submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None + interpret_component_column, interpretation_set = None, None + + with Column(variant="panel"): + input_component_column = Column() + with input_component_column: + for component in self.input_components: + component.render() + if self.interpretation: + interpret_component_column = Column(visible=False) + interpretation_set = [] + with interpret_component_column: + for component in self.input_components: + interpretation_set.append(Interpretation(component)) + with Row(): + if self.interface_type in [ + InterfaceTypes.STANDARD, + InterfaceTypes.INPUT_ONLY, + ]: + clear_btn = Button("Clear") + if not self.live: + submit_btn = Button("Submit", variant="primary") + # Stopping jobs only works if the queue is enabled + # We don't know if the queue is enabled when the interface + # is created. We use whether a generator function is provided + # as a proxy of whether the queue will be enabled. + # Using a generator function without the queue will raise an error. + if inspect.isgeneratorfunction(self.fn): + stop_btn = Button("Stop", variant="stop") + elif self.interface_type == InterfaceTypes.UNIFIED: + clear_btn = Button("Clear") + submit_btn = Button("Submit", variant="primary") + if inspect.isgeneratorfunction(self.fn) and not self.live: + stop_btn = Button("Stop", variant="stop") + if self.allow_flagging == "manual": + flag_btns = self.render_flag_btns() + elif self.allow_flagging == "auto": + flag_btns = [(submit_btn, None)] + return ( + submit_btn, + clear_btn, + stop_btn, + flag_btns, + input_component_column, + interpret_component_column, + interpretation_set, + ) + + def render_output_column( + self, + submit_btn_in: Button | None, + ) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]: + submit_btn = submit_btn_in + interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None + + with Column(variant="panel"): + for component in self.output_components: + if not (isinstance(component, State)): + component.render() + with Row(): + if self.interface_type == InterfaceTypes.OUTPUT_ONLY: + clear_btn = Button("Clear") + submit_btn = Button("Generate", variant="primary") + if inspect.isgeneratorfunction(self.fn) and not self.live: + # Stopping jobs only works if the queue is enabled + # We don't know if the queue is enabled when the interface + # is created. We use whether a generator function is provided + # as a proxy of whether the queue will be enabled. + # Using a generator function without the queue will raise an error. + stop_btn = Button("Stop", variant="stop") + if self.allow_flagging == "manual": + flag_btns = self.render_flag_btns() + elif self.allow_flagging == "auto": + assert submit_btn is not None, "Submit button not rendered" + flag_btns = [(submit_btn, None)] + if self.interpretation: + interpretation_btn = Button("Interpret") + + return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn + + def render_article(self): + if self.article: + Markdown(self.article) + + def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None): + if self.live: + if self.interface_type == InterfaceTypes.OUTPUT_ONLY: + assert submit_btn is not None, "Submit button not rendered" + super().load(self.fn, None, self.output_components) + # For output-only interfaces, the user probably still want a "generate" + # button even if the Interface is live + submit_btn.click( + self.fn, + None, + self.output_components, + api_name="predict", + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + ) + else: + for component in self.input_components: + if isinstance(component, Streamable) and component.streaming: + component.stream( + self.fn, + self.input_components, + self.output_components, + api_name="predict", + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + ) + continue + if isinstance(component, Changeable): + component.change( + self.fn, + self.input_components, + self.output_components, + api_name="predict", + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + ) + else: + assert submit_btn is not None, "Submit button not rendered" + pred = submit_btn.click( + self.fn, + self.input_components, + self.output_components, + api_name="predict", + scroll_to_output=True, + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + ) + if stop_btn: + stop_btn.click( + None, + inputs=None, + outputs=None, + cancels=[pred], + ) + + def attach_clear_events( + self, + clear_btn: Button, + input_component_column: Column | None, + interpret_component_column: Column | None, + ): + clear_btn.click( + None, + [], + ( + self.input_components + + self.output_components + + ([input_component_column] if input_component_column else []) + + ([interpret_component_column] if self.interpretation else []) + ), # type: ignore + _js=f"""() => {json.dumps( + [getattr(component, "cleared_value", None) + for component in self.input_components + self.output_components] + ( + [Column.update(visible=True)] + if self.interface_type + in [ + InterfaceTypes.STANDARD, + InterfaceTypes.INPUT_ONLY, + InterfaceTypes.UNIFIED, + ] + else [] + ) + + ([Column.update(visible=False)] if self.interpretation else []) + )} + """, + ) + + def attach_interpretation_events( + self, + interpretation_btn: Button | None, + interpretation_set: List[Interpretation] | None, + input_component_column: Column | None, + interpret_component_column: Column | None, + ): + if interpretation_btn: + interpretation_btn.click( + self.interpret_func, + inputs=self.input_components + self.output_components, + outputs=interpretation_set + or [] + [input_component_column, interpret_component_column], # type: ignore + preprocess=False, + ) + + def render_flagging_buttons(self, flag_btns: List | None): + if flag_btns: + if self.interface_type in [ + InterfaceTypes.STANDARD, + InterfaceTypes.OUTPUT_ONLY, + InterfaceTypes.UNIFIED, + ]: + if ( + self.interface_type == InterfaceTypes.UNIFIED + or self.allow_flagging == "auto" + ): + flag_components = self.input_components + else: + flag_components = self.input_components + self.output_components + for flag_btn, flag_option in flag_btns: + flag_method = FlagMethod(self.flagging_callback, flag_option) + flag_btn.click( + flag_method, + inputs=flag_components, + outputs=[], + preprocess=False, + queue=False, + ) + + def render_examples(self): + if self.examples: + non_state_inputs = [ + c for c in self.input_components if not isinstance(c, State) + ] + non_state_outputs = [ + c for c in self.output_components if not isinstance(c, State) + ] + self.examples_handler = Examples( + examples=self.examples, + inputs=non_state_inputs, # type: ignore + outputs=non_state_outputs, # type: ignore + fn=self.fn, + cache_examples=self.cache_examples, + examples_per_page=self.examples_per_page, + _api_mode=self.api_mode, + batch=self.batch, + ) + def __str__(self): return self.__repr__() @@ -683,7 +769,7 @@ class Interface(Blocks): return repr async def interpret_func(self, *args): - return await self.interpret(args) + [ + return await self.interpret(list(args)) + [ Column.update(visible=False), Column.update(visible=True), ] @@ -698,20 +784,9 @@ class Interface(Blocks): def test_launch(self) -> None: """ - Passes a few samples through the function to test if the inputs/outputs - components are consistent with the function parameter and return values. + Deprecated. """ - print("Test launch: {}()...".format(self.__name__), end=" ") - raw_input = [] - for input_component in self.input_components: - if input_component.test_input is None: - print("SKIPPED") - break - else: - raw_input.append(input_component.test_input) - else: - self(raw_input) - print("PASSED") + warnings.warn("The Interface.test_launch() function is deprecated.") @document() @@ -725,11 +800,11 @@ class TabbedInterface(Blocks): def __init__( self, interface_list: List[Interface], - tab_names: Optional[List[str]] = None, - title: Optional[str] = None, + tab_names: List[str] | None = None, + title: str | None = None, theme: str = "default", - analytics_enabled: Optional[bool] = None, - css: Optional[str] = None, + analytics_enabled: bool | None = None, + css: str | None = None, ): """ Parameters: @@ -743,7 +818,7 @@ class TabbedInterface(Blocks): a Gradio Tabbed Interface for the given interfaces """ super().__init__( - title=title, + title=title or "Gradio", theme=theme, analytics_enabled=analytics_enabled, mode="tabbed_interface", @@ -760,7 +835,7 @@ class TabbedInterface(Blocks): ) with Tabs(): for (interface, tab_name) in zip(interface_list, tab_names): - with TabItem(label=tab_name): + with Tab(label=tab_name): interface.render() diff --git a/scripts/type_check_backend.sh b/scripts/type_check_backend.sh index 5422d5d0eb..62fa6507a1 100644 --- a/scripts/type_check_backend.sh +++ b/scripts/type_check_backend.sh @@ -6,4 +6,4 @@ pip_required pip install --upgrade pip pip install pyright cd gradio -pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py +pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py diff --git a/test/test_interfaces.py b/test/test_interfaces.py index 0408fe4829..9468922760 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -69,15 +69,6 @@ class TestInterface: ) assert dataset_check - def test_test_launch(self): - with captured_output() as (out, err): - prediction_fn = lambda x: x - prediction_fn.__name__ = "prediction_fn" - interface = Interface(prediction_fn, "textbox", "label") - interface.test_launch() - output = out.getvalue().strip() - assert output == "Test launch: prediction_fn()... PASSED" - @mock.patch("time.sleep") def test_block_thread(self, mock_sleep): with pytest.raises(KeyboardInterrupt):