Refactored interface.py (#2902)

* started pathlib

* blocks.py

* more changes

* fixes

* typing

* formatting

* typing

* renaming files

* changelog

* script

* changelog

* lint

* routes

* renamed

* state

* formatting

* state

* type check script

* remove strictness

* switched to pyright

* switched to pyright

* fixed flaky tests

* fixed test xray

* fixed load test

* fixed blocks tests

* formatting

* fixed components test

* uncomment tests

* fixed interpretation tests

* formatting

* last tests hopefully

* argh lint

* component

* fixed based on review

* refactor

* components.py t yping

* components.py

* formatting

* lint script

* merge

* merge

* lint

* pathlib

* lint

* events too

* lint script

* fixing tests

* lint

* examples

* serializing

* more files

* formatting

* flagging.py

* added to lint script

* fixed tab

* interface.py

* attempt fix

* refactoring interface

* interface refactor

* formatting

* fix for live interfaces

* lint

* serialize fix

* formatting

* all demos queue

* added type check

* formatting
This commit is contained in:
Abubakar Abid 2022-12-29 15:30:44 -05:00 committed by GitHub
parent 5310782ed9
commit d46f0cd1ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 393 additions and 319 deletions

View File

@ -1,3 +1,4 @@
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
@ -18,3 +19,10 @@ class PredictBody(BaseModel):
class ResetBody(BaseModel):
session_hash: str
fn_index: int
class InterfaceTypes(Enum):
STANDARD = auto()
INPUT_ONLY = auto()
OUTPUT_ONLY = auto()
UNIFIED = auto()

View File

@ -34,7 +34,7 @@ def set_cancel_events(
class Changeable(Block):
def change(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -97,7 +97,7 @@ class Changeable(Block):
class Clickable(Block):
def click(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -161,7 +161,7 @@ class Clickable(Block):
class Submittable(Block):
def submit(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -226,7 +226,7 @@ class Submittable(Block):
class Editable(Block):
def edit(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -290,7 +290,7 @@ class Editable(Block):
class Clearable(Block):
def clear(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -354,7 +354,7 @@ class Clearable(Block):
class Playable(Block):
def play(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -416,7 +416,7 @@ class Playable(Block):
def pause(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -478,7 +478,7 @@ class Playable(Block):
def stop(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -542,7 +542,7 @@ class Playable(Block):
class Streamable(Block):
def stream(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -608,7 +608,7 @@ class Streamable(Block):
class Blurrable(Block):
def blur(
self,
fn: Callable,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
api_name: str | None = None,
@ -665,7 +665,7 @@ class Blurrable(Block):
class Uploadable(Block):
def upload(
self,
fn: Callable,
fn: Callable | None,
inputs: List[Component],
outputs: Component | List[Component] | None = None,
api_name: str | None = None,

View File

@ -551,10 +551,10 @@ class FlagMethod:
Helper class that contains the flagging button option and callback
"""
def __init__(self, flagging_callback, flag_option=None):
def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
self.__name__ = "Flag"
def __call__(self, *flag_data):
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option)

View File

@ -12,34 +12,33 @@ import pkgutil
import re
import warnings
import weakref
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Callable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from markdown_it import MarkdownIt
from mdit_py_plugins.dollarmath import dollarmath_plugin
from mdit_py_plugins.footnote import footnote_plugin
from mdit_py_plugins.dollarmath.index import dollarmath_plugin
from mdit_py_plugins.footnote.index import footnote_plugin
from gradio import Examples, interpretation, utils
from gradio.blocks import Blocks
from gradio.components import (
Button,
Component,
Interpretation,
IOComponent,
Markdown,
State,
get_component_instance,
)
from gradio.data_classes import InterfaceTypes
from gradio.documentation import document, set_documentation_group
from gradio.events import Changeable, Streamable
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Column, Row, TabItem, Tabs
from gradio.layouts import Column, Row, Tab, Tabs
from gradio.pipelines import load_from_pipeline
set_documentation_group("interface")
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
from transformers.pipelines.base import Pipeline
@document("launch", "load", "from_pipeline", "integrate", "queue")
@ -66,12 +65,6 @@ class Interface(Blocks):
# stores references to all currently existing Interface instances
instances: weakref.WeakSet = weakref.WeakSet()
class InterfaceTypes(Enum):
STANDARD = auto()
INPUT_ONLY = auto()
OUTPUT_ONLY = auto()
UNIFIED = auto()
@classmethod
def get_instances(cls) -> List[Interface]:
"""
@ -83,9 +76,9 @@ class Interface(Blocks):
def load(
cls,
name: str,
src: Optional[str] = None,
api_key: Optional[str] = None,
alias: Optional[str] = None,
src: str | None = None,
api_key: str | None = None,
alias: str | None = None,
**kwargs,
) -> Interface:
"""
@ -109,7 +102,7 @@ class Interface(Blocks):
return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)
@classmethod
def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface:
def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface:
"""
Class method that constructs an Interface from a Hugging Face transformers.Pipeline object.
The input and output components are automatically determined from the pipeline.
@ -131,25 +124,25 @@ class Interface(Blocks):
def __init__(
self,
fn: Callable,
inputs: Optional[str | Component | List[str | Component]],
outputs: Optional[str | Component | List[str | Component]],
examples: Optional[List[Any] | List[List[Any]] | str] = None,
cache_examples: Optional[bool] = None,
inputs: str | IOComponent | List[str | IOComponent] | None,
outputs: str | IOComponent | List[str | IOComponent] | None,
examples: List[Any] | List[List[Any]] | str | None = None,
cache_examples: bool | None = None,
examples_per_page: int = 10,
live: bool = False,
interpretation: Optional[Callable | str] = None,
interpretation: Callable | str | None = None,
num_shap: float = 2.0,
title: Optional[str] = None,
description: Optional[str] = None,
article: Optional[str] = None,
thumbnail: Optional[str] = None,
theme: Optional[str] = None,
css: Optional[str] = None,
allow_flagging: Optional[str] = None,
flagging_options: List[str] = None,
title: str | None = None,
description: str | None = None,
article: str | None = None,
thumbnail: str | None = None,
theme: str = "default",
css: str | None = None,
allow_flagging: str | None = None,
flagging_options: List[str] | None = None,
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
analytics_enabled: Optional[bool] = None,
analytics_enabled: bool | None = None,
batch: bool = False,
max_batch_size: int = 4,
_api_mode: bool = False,
@ -184,21 +177,11 @@ class Interface(Blocks):
analytics_enabled=analytics_enabled,
mode="interface",
css=css,
title=title,
title=title or "Gradio",
theme=theme,
**kwargs,
)
self.interface_type = self.InterfaceTypes.STANDARD
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
raise ValueError("Must provide at least one of `inputs` or `outputs`")
elif outputs is None or outputs == []:
outputs = []
self.interface_type = self.InterfaceTypes.INPUT_ONLY
elif inputs is None or inputs == []:
inputs = []
self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
if isinstance(fn, list):
raise DeprecationWarning(
"The `fn` parameter only accepts a single function, support for a list "
@ -206,6 +189,19 @@ class Interface(Blocks):
"instead."
)
self.interface_type = InterfaceTypes.STANDARD
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
raise ValueError("Must provide at least one of `inputs` or `outputs`")
elif outputs is None or outputs == []:
outputs = []
self.interface_type = InterfaceTypes.INPUT_ONLY
elif inputs is None or inputs == []:
inputs = []
self.interface_type = InterfaceTypes.OUTPUT_ONLY
assert isinstance(inputs, (str, list, IOComponent))
assert isinstance(outputs, (str, list, IOComponent))
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
@ -234,7 +230,7 @@ class Interface(Blocks):
state_output_index = state_output_indexes[0]
if inputs[state_input_index] == "state":
default = utils.get_default_args(fn)[state_input_index]
state_variable = State(value=default)
state_variable = State(value=default) # type: ignore
else:
state_variable = inputs[state_input_index]
@ -266,13 +262,14 @@ class Interface(Blocks):
i is o for i, o in zip(self.input_components, self.output_components)
]
if all(same_components):
self.interface_type = self.InterfaceTypes.UNIFIED
self.interface_type = InterfaceTypes.UNIFIED
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
InterfaceTypes.STANDARD,
InterfaceTypes.OUTPUT_ONLY,
]:
for o in self.output_components:
assert isinstance(o, IOComponent)
o.interactive = False # Force output components to be non-interactive
if (
@ -375,6 +372,8 @@ class Interface(Blocks):
self.flagging_options = flagging_options
self.flagging_callback = flagging_callback
self.flagging_dir = flagging_dir
self.batch = batch
self.max_batch_size = max_batch_size
self.save_to = None # Used for selenium tests
self.share = None
@ -395,7 +394,7 @@ class Interface(Blocks):
"allow_flagging": allow_flagging,
"custom_css": self.css is not None,
"theme": self.theme,
"version": pkgutil.get_data(__name__, "version.txt")
"version": (pkgutil.get_data(__name__, "version.txt") or b"")
.decode("ascii")
.strip(),
}
@ -406,9 +405,11 @@ class Interface(Blocks):
param_names = inspect.getfullargspec(self.fn)[0]
for component, param_name in zip(self.input_components, param_names):
assert isinstance(component, IOComponent)
if component.label is None:
component.label = param_name
for i, component in enumerate(self.output_components):
assert isinstance(component, IOComponent)
if component.label is None:
if len(self.output_components) == 1:
component.label = "output"
@ -417,257 +418,342 @@ class Interface(Blocks):
if self.allow_flagging != "never":
if (
self.interface_type == self.InterfaceTypes.UNIFIED
self.interface_type == InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
self.flagging_callback.setup(self.input_components, self.flagging_dir)
elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
elif self.interface_type == InterfaceTypes.INPUT_ONLY:
pass
else:
self.flagging_callback.setup(
self.input_components + self.output_components, self.flagging_dir
self.input_components + self.output_components, self.flagging_dir # type: ignore
)
# Render the Gradio UI
with self:
if self.title:
Markdown(
"<h1 style='text-align: center; margin-bottom: 1rem'>"
+ self.title
+ "</h1>"
)
if self.description:
Markdown(self.description)
self.render_title_description()
def render_flag_btns(flagging_options):
if flagging_options is None:
return [(Button("Flag"), None)]
else:
return [
(
Button("Flag as " + flag_option),
flag_option,
)
for flag_option in flagging_options
]
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
interpretation_btn, interpretation_set = None, None
input_component_column, interpret_component_column = None, None
with Row().style(equal_height=False):
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
InterfaceTypes.STANDARD,
InterfaceTypes.INPUT_ONLY,
InterfaceTypes.UNIFIED,
]:
with Column(variant="panel"):
input_component_column = Column()
with input_component_column:
for component in self.input_components:
component.render()
if self.interpretation:
interpret_component_column = Column(visible=False)
interpretation_set = []
with interpret_component_column:
for component in self.input_components:
interpretation_set.append(Interpretation(component))
with Row():
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
]:
clear_btn = Button("Clear")
if not self.live:
submit_btn = Button("Submit", variant="primary")
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
if inspect.isgeneratorfunction(fn):
stop_btn = Button("Stop", variant="stop")
elif self.interface_type == self.InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
if inspect.isgeneratorfunction(fn) and not self.live:
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = render_flag_btns(self.flagging_options)
(
submit_btn,
clear_btn,
stop_btn,
flag_btns,
input_component_column,
interpret_component_column,
interpretation_set,
) = self.render_input_column()
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
InterfaceTypes.STANDARD,
InterfaceTypes.OUTPUT_ONLY,
]:
(
submit_btn_out,
clear_btn_2_out,
stop_btn_2_out,
flag_btns_out,
interpretation_btn,
) = self.render_output_column(submit_btn)
submit_btn = submit_btn or submit_btn_out
clear_btn = clear_btn or clear_btn_2_out
stop_btn = stop_btn or stop_btn_2_out
flag_btns = flag_btns or flag_btns_out
with Column(variant="panel"):
for component in self.output_components:
if not (isinstance(component, State)):
component.render()
with Row():
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
clear_btn = Button("Clear")
submit_btn = Button("Generate", variant="primary")
if inspect.isgeneratorfunction(fn) and not self.live:
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = render_flag_btns(self.flagging_options)
if self.interpretation:
interpretation_btn = Button("Interpret")
if self.live:
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
super().load(self.fn, None, self.output_components)
submit_btn.click(
self.fn,
None,
self.output_components,
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=batch,
max_batch_size=max_batch_size,
)
else:
for component in self.input_components:
if isinstance(component, Streamable):
if component.streaming:
component.stream(
self.fn,
self.input_components,
self.output_components,
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
)
continue
else:
print(
"Hint: Set streaming=True for "
+ component.__class__.__name__
+ " component to use live streaming."
)
if isinstance(component, Changeable):
component.change(
self.fn,
self.input_components,
self.output_components,
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
)
else:
pred = submit_btn.click(
self.fn,
self.input_components,
self.output_components,
api_name="predict",
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=batch,
max_batch_size=max_batch_size,
)
if inspect.isgeneratorfunction(fn):
stop_btn.click(
None,
inputs=None,
outputs=None,
cancels=[pred],
)
assert clear_btn is not None, "Clear button not rendered"
clear_btn.click(
None,
[],
(
self.input_components
+ self.output_components
+ (
[input_component_column]
if self.interface_type
in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]
else []
)
+ ([interpret_component_column] if self.interpretation else [])
),
_js=f"""() => {json.dumps(
[component.cleared_value if hasattr(component, "cleared_value") else None
for component in self.input_components + self.output_components] + (
[Column.update(visible=True)]
if self.interface_type
in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.INPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]
else []
)
+ ([Column.update(visible=False)] if self.interpretation else [])
)}
""",
self.attach_submit_events(submit_btn, stop_btn)
self.attach_clear_events(
clear_btn, input_component_column, interpret_component_column
)
self.attach_interpretation_events(
interpretation_btn,
interpretation_set,
input_component_column,
interpret_component_column,
)
if self.allow_flagging in ["manual", "auto"]:
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]:
if self.allow_flagging == "auto":
flag_btns = [(submit_btn, None)]
if (
self.interface_type == self.InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components
for flag_btn, flag_option in flag_btns:
flag_method = FlagMethod(self.flagging_callback, flag_option)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=[],
preprocess=False,
queue=False,
)
if self.examples:
non_state_inputs = [
c for c in self.input_components if not isinstance(c, State)
]
non_state_outputs = [
c for c in self.output_components if not isinstance(c, State)
]
self.examples_handler = Examples(
examples=examples,
inputs=non_state_inputs,
outputs=non_state_outputs,
fn=self.fn,
cache_examples=self.cache_examples,
examples_per_page=examples_per_page,
_api_mode=_api_mode,
batch=batch,
)
if self.interpretation:
interpretation_btn.click(
self.interpret_func,
inputs=self.input_components + self.output_components,
outputs=interpretation_set
+ [input_component_column, interpret_component_column],
preprocess=False,
)
if self.article:
Markdown(self.article)
self.render_flagging_buttons(flag_btns)
self.render_examples()
self.render_article()
self.config = self.get_config_file()
def render_title_description(self) -> None:
if self.title:
Markdown(
"<h1 style='text-align: center; margin-bottom: 1rem'>"
+ self.title
+ "</h1>"
)
if self.description:
Markdown(self.description)
def render_flag_btns(self) -> List[Tuple[Button, str | None]]:
if self.flagging_options is None:
return [(Button("Flag"), None)]
else:
return [
(
Button("Flag as " + flag_option),
flag_option,
)
for flag_option in self.flagging_options
]
def render_input_column(
self,
) -> Tuple[
Button | None,
Button | None,
Button | None,
List | None,
Column,
Column | None,
List[Interpretation] | None,
]:
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
interpret_component_column, interpretation_set = None, None
with Column(variant="panel"):
input_component_column = Column()
with input_component_column:
for component in self.input_components:
component.render()
if self.interpretation:
interpret_component_column = Column(visible=False)
interpretation_set = []
with interpret_component_column:
for component in self.input_components:
interpretation_set.append(Interpretation(component))
with Row():
if self.interface_type in [
InterfaceTypes.STANDARD,
InterfaceTypes.INPUT_ONLY,
]:
clear_btn = Button("Clear")
if not self.live:
submit_btn = Button("Submit", variant="primary")
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
if inspect.isgeneratorfunction(self.fn):
stop_btn = Button("Stop", variant="stop")
elif self.interface_type == InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
if inspect.isgeneratorfunction(self.fn) and not self.live:
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = self.render_flag_btns()
elif self.allow_flagging == "auto":
flag_btns = [(submit_btn, None)]
return (
submit_btn,
clear_btn,
stop_btn,
flag_btns,
input_component_column,
interpret_component_column,
interpretation_set,
)
def render_output_column(
self,
submit_btn_in: Button | None,
) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
submit_btn = submit_btn_in
interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
with Column(variant="panel"):
for component in self.output_components:
if not (isinstance(component, State)):
component.render()
with Row():
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
clear_btn = Button("Clear")
submit_btn = Button("Generate", variant="primary")
if inspect.isgeneratorfunction(self.fn) and not self.live:
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = self.render_flag_btns()
elif self.allow_flagging == "auto":
assert submit_btn is not None, "Submit button not rendered"
flag_btns = [(submit_btn, None)]
if self.interpretation:
interpretation_btn = Button("Interpret")
return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn
def render_article(self):
if self.article:
Markdown(self.article)
def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
if self.live:
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
assert submit_btn is not None, "Submit button not rendered"
super().load(self.fn, None, self.output_components)
# For output-only interfaces, the user probably still want a "generate"
# button even if the Interface is live
submit_btn.click(
self.fn,
None,
self.output_components,
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
else:
for component in self.input_components:
if isinstance(component, Streamable) and component.streaming:
component.stream(
self.fn,
self.input_components,
self.output_components,
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
)
continue
if isinstance(component, Changeable):
component.change(
self.fn,
self.input_components,
self.output_components,
api_name="predict",
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
)
else:
assert submit_btn is not None, "Submit button not rendered"
pred = submit_btn.click(
self.fn,
self.input_components,
self.output_components,
api_name="predict",
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
batch=self.batch,
max_batch_size=self.max_batch_size,
)
if stop_btn:
stop_btn.click(
None,
inputs=None,
outputs=None,
cancels=[pred],
)
def attach_clear_events(
self,
clear_btn: Button,
input_component_column: Column | None,
interpret_component_column: Column | None,
):
clear_btn.click(
None,
[],
(
self.input_components
+ self.output_components
+ ([input_component_column] if input_component_column else [])
+ ([interpret_component_column] if self.interpretation else [])
), # type: ignore
_js=f"""() => {json.dumps(
[getattr(component, "cleared_value", None)
for component in self.input_components + self.output_components] + (
[Column.update(visible=True)]
if self.interface_type
in [
InterfaceTypes.STANDARD,
InterfaceTypes.INPUT_ONLY,
InterfaceTypes.UNIFIED,
]
else []
)
+ ([Column.update(visible=False)] if self.interpretation else [])
)}
""",
)
def attach_interpretation_events(
self,
interpretation_btn: Button | None,
interpretation_set: List[Interpretation] | None,
input_component_column: Column | None,
interpret_component_column: Column | None,
):
if interpretation_btn:
interpretation_btn.click(
self.interpret_func,
inputs=self.input_components + self.output_components,
outputs=interpretation_set
or [] + [input_component_column, interpret_component_column], # type: ignore
preprocess=False,
)
def render_flagging_buttons(self, flag_btns: List | None):
if flag_btns:
if self.interface_type in [
InterfaceTypes.STANDARD,
InterfaceTypes.OUTPUT_ONLY,
InterfaceTypes.UNIFIED,
]:
if (
self.interface_type == InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components
for flag_btn, flag_option in flag_btns:
flag_method = FlagMethod(self.flagging_callback, flag_option)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=[],
preprocess=False,
queue=False,
)
def render_examples(self):
if self.examples:
non_state_inputs = [
c for c in self.input_components if not isinstance(c, State)
]
non_state_outputs = [
c for c in self.output_components if not isinstance(c, State)
]
self.examples_handler = Examples(
examples=self.examples,
inputs=non_state_inputs, # type: ignore
outputs=non_state_outputs, # type: ignore
fn=self.fn,
cache_examples=self.cache_examples,
examples_per_page=self.examples_per_page,
_api_mode=self.api_mode,
batch=self.batch,
)
def __str__(self):
return self.__repr__()
@ -683,7 +769,7 @@ class Interface(Blocks):
return repr
async def interpret_func(self, *args):
return await self.interpret(args) + [
return await self.interpret(list(args)) + [
Column.update(visible=False),
Column.update(visible=True),
]
@ -698,20 +784,9 @@ class Interface(Blocks):
def test_launch(self) -> None:
"""
Passes a few samples through the function to test if the inputs/outputs
components are consistent with the function parameter and return values.
Deprecated.
"""
print("Test launch: {}()...".format(self.__name__), end=" ")
raw_input = []
for input_component in self.input_components:
if input_component.test_input is None:
print("SKIPPED")
break
else:
raw_input.append(input_component.test_input)
else:
self(raw_input)
print("PASSED")
warnings.warn("The Interface.test_launch() function is deprecated.")
@document()
@ -725,11 +800,11 @@ class TabbedInterface(Blocks):
def __init__(
self,
interface_list: List[Interface],
tab_names: Optional[List[str]] = None,
title: Optional[str] = None,
tab_names: List[str] | None = None,
title: str | None = None,
theme: str = "default",
analytics_enabled: Optional[bool] = None,
css: Optional[str] = None,
analytics_enabled: bool | None = None,
css: str | None = None,
):
"""
Parameters:
@ -743,7 +818,7 @@ class TabbedInterface(Blocks):
a Gradio Tabbed Interface for the given interfaces
"""
super().__init__(
title=title,
title=title or "Gradio",
theme=theme,
analytics_enabled=analytics_enabled,
mode="tabbed_interface",
@ -760,7 +835,7 @@ class TabbedInterface(Blocks):
)
with Tabs():
for (interface, tab_name) in zip(interface_list, tab_names):
with TabItem(label=tab_name):
with Tab(label=tab_name):
interface.render()

View File

@ -6,4 +6,4 @@ pip_required
pip install --upgrade pip
pip install pyright
cd gradio
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py

View File

@ -69,15 +69,6 @@ class TestInterface:
)
assert dataset_check
def test_test_launch(self):
with captured_output() as (out, err):
prediction_fn = lambda x: x
prediction_fn.__name__ = "prediction_fn"
interface = Interface(prediction_fn, "textbox", "label")
interface.test_launch()
output = out.getvalue().strip()
assert output == "Test launch: prediction_fn()... PASSED"
@mock.patch("time.sleep")
def test_block_thread(self, mock_sleep):
with pytest.raises(KeyboardInterrupt):