mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Refactored interface.py
(#2902)
* started pathlib * blocks.py * more changes * fixes * typing * formatting * typing * renaming files * changelog * script * changelog * lint * routes * renamed * state * formatting * state * type check script * remove strictness * switched to pyright * switched to pyright * fixed flaky tests * fixed test xray * fixed load test * fixed blocks tests * formatting * fixed components test * uncomment tests * fixed interpretation tests * formatting * last tests hopefully * argh lint * component * fixed based on review * refactor * components.py t yping * components.py * formatting * lint script * merge * merge * lint * pathlib * lint * events too * lint script * fixing tests * lint * examples * serializing * more files * formatting * flagging.py * added to lint script * fixed tab * interface.py * attempt fix * refactoring interface * interface refactor * formatting * fix for live interfaces * lint * serialize fix * formatting * all demos queue * added type check * formatting
This commit is contained in:
parent
5310782ed9
commit
d46f0cd1ed
@ -1,3 +1,4 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -18,3 +19,10 @@ class PredictBody(BaseModel):
|
||||
class ResetBody(BaseModel):
|
||||
session_hash: str
|
||||
fn_index: int
|
||||
|
||||
|
||||
class InterfaceTypes(Enum):
|
||||
STANDARD = auto()
|
||||
INPUT_ONLY = auto()
|
||||
OUTPUT_ONLY = auto()
|
||||
UNIFIED = auto()
|
||||
|
@ -34,7 +34,7 @@ def set_cancel_events(
|
||||
class Changeable(Block):
|
||||
def change(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -97,7 +97,7 @@ class Changeable(Block):
|
||||
class Clickable(Block):
|
||||
def click(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -161,7 +161,7 @@ class Clickable(Block):
|
||||
class Submittable(Block):
|
||||
def submit(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -226,7 +226,7 @@ class Submittable(Block):
|
||||
class Editable(Block):
|
||||
def edit(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -290,7 +290,7 @@ class Editable(Block):
|
||||
class Clearable(Block):
|
||||
def clear(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -354,7 +354,7 @@ class Clearable(Block):
|
||||
class Playable(Block):
|
||||
def play(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -416,7 +416,7 @@ class Playable(Block):
|
||||
|
||||
def pause(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -478,7 +478,7 @@ class Playable(Block):
|
||||
|
||||
def stop(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -542,7 +542,7 @@ class Playable(Block):
|
||||
class Streamable(Block):
|
||||
def stream(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -608,7 +608,7 @@ class Streamable(Block):
|
||||
class Blurrable(Block):
|
||||
def blur(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
@ -665,7 +665,7 @@ class Blurrable(Block):
|
||||
class Uploadable(Block):
|
||||
def upload(
|
||||
self,
|
||||
fn: Callable,
|
||||
fn: Callable | None,
|
||||
inputs: List[Component],
|
||||
outputs: Component | List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
|
@ -551,10 +551,10 @@ class FlagMethod:
|
||||
Helper class that contains the flagging button option and callback
|
||||
"""
|
||||
|
||||
def __init__(self, flagging_callback, flag_option=None):
|
||||
def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
|
||||
self.flagging_callback = flagging_callback
|
||||
self.flag_option = flag_option
|
||||
self.__name__ = "Flag"
|
||||
|
||||
def __call__(self, *flag_data):
|
||||
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
|
||||
self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option)
|
||||
|
@ -12,34 +12,33 @@ import pkgutil
|
||||
import re
|
||||
import warnings
|
||||
import weakref
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.dollarmath import dollarmath_plugin
|
||||
from mdit_py_plugins.footnote import footnote_plugin
|
||||
from mdit_py_plugins.dollarmath.index import dollarmath_plugin
|
||||
from mdit_py_plugins.footnote.index import footnote_plugin
|
||||
|
||||
from gradio import Examples, interpretation, utils
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.components import (
|
||||
Button,
|
||||
Component,
|
||||
Interpretation,
|
||||
IOComponent,
|
||||
Markdown,
|
||||
State,
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.data_classes import InterfaceTypes
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.events import Changeable, Streamable
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
|
||||
from gradio.layouts import Column, Row, TabItem, Tabs
|
||||
from gradio.layouts import Column, Row, Tab, Tabs
|
||||
from gradio.pipelines import load_from_pipeline
|
||||
|
||||
set_documentation_group("interface")
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import transformers
|
||||
from transformers.pipelines.base import Pipeline
|
||||
|
||||
|
||||
@document("launch", "load", "from_pipeline", "integrate", "queue")
|
||||
@ -66,12 +65,6 @@ class Interface(Blocks):
|
||||
# stores references to all currently existing Interface instances
|
||||
instances: weakref.WeakSet = weakref.WeakSet()
|
||||
|
||||
class InterfaceTypes(Enum):
|
||||
STANDARD = auto()
|
||||
INPUT_ONLY = auto()
|
||||
OUTPUT_ONLY = auto()
|
||||
UNIFIED = auto()
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls) -> List[Interface]:
|
||||
"""
|
||||
@ -83,9 +76,9 @@ class Interface(Blocks):
|
||||
def load(
|
||||
cls,
|
||||
name: str,
|
||||
src: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
alias: Optional[str] = None,
|
||||
src: str | None = None,
|
||||
api_key: str | None = None,
|
||||
alias: str | None = None,
|
||||
**kwargs,
|
||||
) -> Interface:
|
||||
"""
|
||||
@ -109,7 +102,7 @@ class Interface(Blocks):
|
||||
return super().load(name=name, src=src, api_key=api_key, alias=alias, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pipeline(cls, pipeline: transformers.Pipeline, **kwargs) -> Interface:
|
||||
def from_pipeline(cls, pipeline: Pipeline, **kwargs) -> Interface:
|
||||
"""
|
||||
Class method that constructs an Interface from a Hugging Face transformers.Pipeline object.
|
||||
The input and output components are automatically determined from the pipeline.
|
||||
@ -131,25 +124,25 @@ class Interface(Blocks):
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: Optional[str | Component | List[str | Component]],
|
||||
outputs: Optional[str | Component | List[str | Component]],
|
||||
examples: Optional[List[Any] | List[List[Any]] | str] = None,
|
||||
cache_examples: Optional[bool] = None,
|
||||
inputs: str | IOComponent | List[str | IOComponent] | None,
|
||||
outputs: str | IOComponent | List[str | IOComponent] | None,
|
||||
examples: List[Any] | List[List[Any]] | str | None = None,
|
||||
cache_examples: bool | None = None,
|
||||
examples_per_page: int = 10,
|
||||
live: bool = False,
|
||||
interpretation: Optional[Callable | str] = None,
|
||||
interpretation: Callable | str | None = None,
|
||||
num_shap: float = 2.0,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
article: Optional[str] = None,
|
||||
thumbnail: Optional[str] = None,
|
||||
theme: Optional[str] = None,
|
||||
css: Optional[str] = None,
|
||||
allow_flagging: Optional[str] = None,
|
||||
flagging_options: List[str] = None,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
article: str | None = None,
|
||||
thumbnail: str | None = None,
|
||||
theme: str = "default",
|
||||
css: str | None = None,
|
||||
allow_flagging: str | None = None,
|
||||
flagging_options: List[str] | None = None,
|
||||
flagging_dir: str = "flagged",
|
||||
flagging_callback: FlaggingCallback = CSVLogger(),
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
analytics_enabled: bool | None = None,
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
_api_mode: bool = False,
|
||||
@ -184,21 +177,11 @@ class Interface(Blocks):
|
||||
analytics_enabled=analytics_enabled,
|
||||
mode="interface",
|
||||
css=css,
|
||||
title=title,
|
||||
title=title or "Gradio",
|
||||
theme=theme,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.interface_type = self.InterfaceTypes.STANDARD
|
||||
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
|
||||
raise ValueError("Must provide at least one of `inputs` or `outputs`")
|
||||
elif outputs is None or outputs == []:
|
||||
outputs = []
|
||||
self.interface_type = self.InterfaceTypes.INPUT_ONLY
|
||||
elif inputs is None or inputs == []:
|
||||
inputs = []
|
||||
self.interface_type = self.InterfaceTypes.OUTPUT_ONLY
|
||||
|
||||
if isinstance(fn, list):
|
||||
raise DeprecationWarning(
|
||||
"The `fn` parameter only accepts a single function, support for a list "
|
||||
@ -206,6 +189,19 @@ class Interface(Blocks):
|
||||
"instead."
|
||||
)
|
||||
|
||||
self.interface_type = InterfaceTypes.STANDARD
|
||||
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
|
||||
raise ValueError("Must provide at least one of `inputs` or `outputs`")
|
||||
elif outputs is None or outputs == []:
|
||||
outputs = []
|
||||
self.interface_type = InterfaceTypes.INPUT_ONLY
|
||||
elif inputs is None or inputs == []:
|
||||
inputs = []
|
||||
self.interface_type = InterfaceTypes.OUTPUT_ONLY
|
||||
|
||||
assert isinstance(inputs, (str, list, IOComponent))
|
||||
assert isinstance(outputs, (str, list, IOComponent))
|
||||
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(outputs, list):
|
||||
@ -234,7 +230,7 @@ class Interface(Blocks):
|
||||
state_output_index = state_output_indexes[0]
|
||||
if inputs[state_input_index] == "state":
|
||||
default = utils.get_default_args(fn)[state_input_index]
|
||||
state_variable = State(value=default)
|
||||
state_variable = State(value=default) # type: ignore
|
||||
else:
|
||||
state_variable = inputs[state_input_index]
|
||||
|
||||
@ -266,13 +262,14 @@ class Interface(Blocks):
|
||||
i is o for i, o in zip(self.input_components, self.output_components)
|
||||
]
|
||||
if all(same_components):
|
||||
self.interface_type = self.InterfaceTypes.UNIFIED
|
||||
self.interface_type = InterfaceTypes.UNIFIED
|
||||
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.OUTPUT_ONLY,
|
||||
]:
|
||||
for o in self.output_components:
|
||||
assert isinstance(o, IOComponent)
|
||||
o.interactive = False # Force output components to be non-interactive
|
||||
|
||||
if (
|
||||
@ -375,6 +372,8 @@ class Interface(Blocks):
|
||||
self.flagging_options = flagging_options
|
||||
self.flagging_callback = flagging_callback
|
||||
self.flagging_dir = flagging_dir
|
||||
self.batch = batch
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
self.save_to = None # Used for selenium tests
|
||||
self.share = None
|
||||
@ -395,7 +394,7 @@ class Interface(Blocks):
|
||||
"allow_flagging": allow_flagging,
|
||||
"custom_css": self.css is not None,
|
||||
"theme": self.theme,
|
||||
"version": pkgutil.get_data(__name__, "version.txt")
|
||||
"version": (pkgutil.get_data(__name__, "version.txt") or b"")
|
||||
.decode("ascii")
|
||||
.strip(),
|
||||
}
|
||||
@ -406,9 +405,11 @@ class Interface(Blocks):
|
||||
|
||||
param_names = inspect.getfullargspec(self.fn)[0]
|
||||
for component, param_name in zip(self.input_components, param_names):
|
||||
assert isinstance(component, IOComponent)
|
||||
if component.label is None:
|
||||
component.label = param_name
|
||||
for i, component in enumerate(self.output_components):
|
||||
assert isinstance(component, IOComponent)
|
||||
if component.label is None:
|
||||
if len(self.output_components) == 1:
|
||||
component.label = "output"
|
||||
@ -417,257 +418,342 @@ class Interface(Blocks):
|
||||
|
||||
if self.allow_flagging != "never":
|
||||
if (
|
||||
self.interface_type == self.InterfaceTypes.UNIFIED
|
||||
self.interface_type == InterfaceTypes.UNIFIED
|
||||
or self.allow_flagging == "auto"
|
||||
):
|
||||
self.flagging_callback.setup(self.input_components, self.flagging_dir)
|
||||
elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
|
||||
self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
|
||||
elif self.interface_type == InterfaceTypes.INPUT_ONLY:
|
||||
pass
|
||||
else:
|
||||
self.flagging_callback.setup(
|
||||
self.input_components + self.output_components, self.flagging_dir
|
||||
self.input_components + self.output_components, self.flagging_dir # type: ignore
|
||||
)
|
||||
|
||||
# Render the Gradio UI
|
||||
with self:
|
||||
if self.title:
|
||||
Markdown(
|
||||
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
||||
+ self.title
|
||||
+ "</h1>"
|
||||
)
|
||||
if self.description:
|
||||
Markdown(self.description)
|
||||
self.render_title_description()
|
||||
|
||||
def render_flag_btns(flagging_options):
|
||||
if flagging_options is None:
|
||||
return [(Button("Flag"), None)]
|
||||
else:
|
||||
return [
|
||||
(
|
||||
Button("Flag as " + flag_option),
|
||||
flag_option,
|
||||
)
|
||||
for flag_option in flagging_options
|
||||
]
|
||||
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
|
||||
interpretation_btn, interpretation_set = None, None
|
||||
input_component_column, interpret_component_column = None, None
|
||||
|
||||
with Row().style(equal_height=False):
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.INPUT_ONLY,
|
||||
InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
with Column(variant="panel"):
|
||||
input_component_column = Column()
|
||||
with input_component_column:
|
||||
for component in self.input_components:
|
||||
component.render()
|
||||
if self.interpretation:
|
||||
interpret_component_column = Column(visible=False)
|
||||
interpretation_set = []
|
||||
with interpret_component_column:
|
||||
for component in self.input_components:
|
||||
interpretation_set.append(Interpretation(component))
|
||||
with Row():
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
]:
|
||||
clear_btn = Button("Clear")
|
||||
if not self.live:
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
if inspect.isgeneratorfunction(fn):
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
|
||||
elif self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
if inspect.isgeneratorfunction(fn) and not self.live:
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = render_flag_btns(self.flagging_options)
|
||||
|
||||
(
|
||||
submit_btn,
|
||||
clear_btn,
|
||||
stop_btn,
|
||||
flag_btns,
|
||||
input_component_column,
|
||||
interpret_component_column,
|
||||
interpretation_set,
|
||||
) = self.render_input_column()
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.OUTPUT_ONLY,
|
||||
]:
|
||||
(
|
||||
submit_btn_out,
|
||||
clear_btn_2_out,
|
||||
stop_btn_2_out,
|
||||
flag_btns_out,
|
||||
interpretation_btn,
|
||||
) = self.render_output_column(submit_btn)
|
||||
submit_btn = submit_btn or submit_btn_out
|
||||
clear_btn = clear_btn or clear_btn_2_out
|
||||
stop_btn = stop_btn or stop_btn_2_out
|
||||
flag_btns = flag_btns or flag_btns_out
|
||||
|
||||
with Column(variant="panel"):
|
||||
for component in self.output_components:
|
||||
if not (isinstance(component, State)):
|
||||
component.render()
|
||||
with Row():
|
||||
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Generate", variant="primary")
|
||||
if inspect.isgeneratorfunction(fn) and not self.live:
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = render_flag_btns(self.flagging_options)
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button("Interpret")
|
||||
if self.live:
|
||||
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
|
||||
super().load(self.fn, None, self.output_components)
|
||||
submit_btn.click(
|
||||
self.fn,
|
||||
None,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
else:
|
||||
for component in self.input_components:
|
||||
if isinstance(component, Streamable):
|
||||
if component.streaming:
|
||||
component.stream(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
continue
|
||||
else:
|
||||
print(
|
||||
"Hint: Set streaming=True for "
|
||||
+ component.__class__.__name__
|
||||
+ " component to use live streaming."
|
||||
)
|
||||
if isinstance(component, Changeable):
|
||||
component.change(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
else:
|
||||
pred = submit_btn.click(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
if inspect.isgeneratorfunction(fn):
|
||||
stop_btn.click(
|
||||
None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[pred],
|
||||
)
|
||||
assert clear_btn is not None, "Clear button not rendered"
|
||||
|
||||
clear_btn.click(
|
||||
None,
|
||||
[],
|
||||
(
|
||||
self.input_components
|
||||
+ self.output_components
|
||||
+ (
|
||||
[input_component_column]
|
||||
if self.interface_type
|
||||
in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]
|
||||
else []
|
||||
)
|
||||
+ ([interpret_component_column] if self.interpretation else [])
|
||||
),
|
||||
_js=f"""() => {json.dumps(
|
||||
[component.cleared_value if hasattr(component, "cleared_value") else None
|
||||
for component in self.input_components + self.output_components] + (
|
||||
[Column.update(visible=True)]
|
||||
if self.interface_type
|
||||
in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.INPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]
|
||||
else []
|
||||
)
|
||||
+ ([Column.update(visible=False)] if self.interpretation else [])
|
||||
)}
|
||||
""",
|
||||
self.attach_submit_events(submit_btn, stop_btn)
|
||||
self.attach_clear_events(
|
||||
clear_btn, input_component_column, interpret_component_column
|
||||
)
|
||||
self.attach_interpretation_events(
|
||||
interpretation_btn,
|
||||
interpretation_set,
|
||||
input_component_column,
|
||||
interpret_component_column,
|
||||
)
|
||||
|
||||
if self.allow_flagging in ["manual", "auto"]:
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
if self.allow_flagging == "auto":
|
||||
flag_btns = [(submit_btn, None)]
|
||||
if (
|
||||
self.interface_type == self.InterfaceTypes.UNIFIED
|
||||
or self.allow_flagging == "auto"
|
||||
):
|
||||
flag_components = self.input_components
|
||||
else:
|
||||
flag_components = self.input_components + self.output_components
|
||||
for flag_btn, flag_option in flag_btns:
|
||||
flag_method = FlagMethod(self.flagging_callback, flag_option)
|
||||
flag_btn.click(
|
||||
flag_method,
|
||||
inputs=flag_components,
|
||||
outputs=[],
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if self.examples:
|
||||
non_state_inputs = [
|
||||
c for c in self.input_components if not isinstance(c, State)
|
||||
]
|
||||
non_state_outputs = [
|
||||
c for c in self.output_components if not isinstance(c, State)
|
||||
]
|
||||
self.examples_handler = Examples(
|
||||
examples=examples,
|
||||
inputs=non_state_inputs,
|
||||
outputs=non_state_outputs,
|
||||
fn=self.fn,
|
||||
cache_examples=self.cache_examples,
|
||||
examples_per_page=examples_per_page,
|
||||
_api_mode=_api_mode,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
if self.interpretation:
|
||||
interpretation_btn.click(
|
||||
self.interpret_func,
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=interpretation_set
|
||||
+ [input_component_column, interpret_component_column],
|
||||
preprocess=False,
|
||||
)
|
||||
|
||||
if self.article:
|
||||
Markdown(self.article)
|
||||
self.render_flagging_buttons(flag_btns)
|
||||
self.render_examples()
|
||||
self.render_article()
|
||||
|
||||
self.config = self.get_config_file()
|
||||
|
||||
def render_title_description(self) -> None:
|
||||
if self.title:
|
||||
Markdown(
|
||||
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
||||
+ self.title
|
||||
+ "</h1>"
|
||||
)
|
||||
if self.description:
|
||||
Markdown(self.description)
|
||||
|
||||
def render_flag_btns(self) -> List[Tuple[Button, str | None]]:
|
||||
if self.flagging_options is None:
|
||||
return [(Button("Flag"), None)]
|
||||
else:
|
||||
return [
|
||||
(
|
||||
Button("Flag as " + flag_option),
|
||||
flag_option,
|
||||
)
|
||||
for flag_option in self.flagging_options
|
||||
]
|
||||
|
||||
def render_input_column(
|
||||
self,
|
||||
) -> Tuple[
|
||||
Button | None,
|
||||
Button | None,
|
||||
Button | None,
|
||||
List | None,
|
||||
Column,
|
||||
Column | None,
|
||||
List[Interpretation] | None,
|
||||
]:
|
||||
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
|
||||
interpret_component_column, interpretation_set = None, None
|
||||
|
||||
with Column(variant="panel"):
|
||||
input_component_column = Column()
|
||||
with input_component_column:
|
||||
for component in self.input_components:
|
||||
component.render()
|
||||
if self.interpretation:
|
||||
interpret_component_column = Column(visible=False)
|
||||
interpretation_set = []
|
||||
with interpret_component_column:
|
||||
for component in self.input_components:
|
||||
interpretation_set.append(Interpretation(component))
|
||||
with Row():
|
||||
if self.interface_type in [
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.INPUT_ONLY,
|
||||
]:
|
||||
clear_btn = Button("Clear")
|
||||
if not self.live:
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
elif self.interface_type == InterfaceTypes.UNIFIED:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
if inspect.isgeneratorfunction(self.fn) and not self.live:
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = self.render_flag_btns()
|
||||
elif self.allow_flagging == "auto":
|
||||
flag_btns = [(submit_btn, None)]
|
||||
return (
|
||||
submit_btn,
|
||||
clear_btn,
|
||||
stop_btn,
|
||||
flag_btns,
|
||||
input_component_column,
|
||||
interpret_component_column,
|
||||
interpretation_set,
|
||||
)
|
||||
|
||||
def render_output_column(
|
||||
self,
|
||||
submit_btn_in: Button | None,
|
||||
) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
|
||||
submit_btn = submit_btn_in
|
||||
interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
|
||||
|
||||
with Column(variant="panel"):
|
||||
for component in self.output_components:
|
||||
if not (isinstance(component, State)):
|
||||
component.render()
|
||||
with Row():
|
||||
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Generate", variant="primary")
|
||||
if inspect.isgeneratorfunction(self.fn) and not self.live:
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = self.render_flag_btns()
|
||||
elif self.allow_flagging == "auto":
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
flag_btns = [(submit_btn, None)]
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button("Interpret")
|
||||
|
||||
return submit_btn, clear_btn, stop_btn, flag_btns, interpretation_btn
|
||||
|
||||
def render_article(self):
|
||||
if self.article:
|
||||
Markdown(self.article)
|
||||
|
||||
def attach_submit_events(self, submit_btn: Button | None, stop_btn: Button | None):
|
||||
if self.live:
|
||||
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
super().load(self.fn, None, self.output_components)
|
||||
# For output-only interfaces, the user probably still want a "generate"
|
||||
# button even if the Interface is live
|
||||
submit_btn.click(
|
||||
self.fn,
|
||||
None,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
else:
|
||||
for component in self.input_components:
|
||||
if isinstance(component, Streamable) and component.streaming:
|
||||
component.stream(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
continue
|
||||
if isinstance(component, Changeable):
|
||||
component.change(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
else:
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
pred = submit_btn.click(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name="predict",
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
if stop_btn:
|
||||
stop_btn.click(
|
||||
None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[pred],
|
||||
)
|
||||
|
||||
def attach_clear_events(
|
||||
self,
|
||||
clear_btn: Button,
|
||||
input_component_column: Column | None,
|
||||
interpret_component_column: Column | None,
|
||||
):
|
||||
clear_btn.click(
|
||||
None,
|
||||
[],
|
||||
(
|
||||
self.input_components
|
||||
+ self.output_components
|
||||
+ ([input_component_column] if input_component_column else [])
|
||||
+ ([interpret_component_column] if self.interpretation else [])
|
||||
), # type: ignore
|
||||
_js=f"""() => {json.dumps(
|
||||
[getattr(component, "cleared_value", None)
|
||||
for component in self.input_components + self.output_components] + (
|
||||
[Column.update(visible=True)]
|
||||
if self.interface_type
|
||||
in [
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.INPUT_ONLY,
|
||||
InterfaceTypes.UNIFIED,
|
||||
]
|
||||
else []
|
||||
)
|
||||
+ ([Column.update(visible=False)] if self.interpretation else [])
|
||||
)}
|
||||
""",
|
||||
)
|
||||
|
||||
def attach_interpretation_events(
|
||||
self,
|
||||
interpretation_btn: Button | None,
|
||||
interpretation_set: List[Interpretation] | None,
|
||||
input_component_column: Column | None,
|
||||
interpret_component_column: Column | None,
|
||||
):
|
||||
if interpretation_btn:
|
||||
interpretation_btn.click(
|
||||
self.interpret_func,
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=interpretation_set
|
||||
or [] + [input_component_column, interpret_component_column], # type: ignore
|
||||
preprocess=False,
|
||||
)
|
||||
|
||||
def render_flagging_buttons(self, flag_btns: List | None):
|
||||
if flag_btns:
|
||||
if self.interface_type in [
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.OUTPUT_ONLY,
|
||||
InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
if (
|
||||
self.interface_type == InterfaceTypes.UNIFIED
|
||||
or self.allow_flagging == "auto"
|
||||
):
|
||||
flag_components = self.input_components
|
||||
else:
|
||||
flag_components = self.input_components + self.output_components
|
||||
for flag_btn, flag_option in flag_btns:
|
||||
flag_method = FlagMethod(self.flagging_callback, flag_option)
|
||||
flag_btn.click(
|
||||
flag_method,
|
||||
inputs=flag_components,
|
||||
outputs=[],
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def render_examples(self):
|
||||
if self.examples:
|
||||
non_state_inputs = [
|
||||
c for c in self.input_components if not isinstance(c, State)
|
||||
]
|
||||
non_state_outputs = [
|
||||
c for c in self.output_components if not isinstance(c, State)
|
||||
]
|
||||
self.examples_handler = Examples(
|
||||
examples=self.examples,
|
||||
inputs=non_state_inputs, # type: ignore
|
||||
outputs=non_state_outputs, # type: ignore
|
||||
fn=self.fn,
|
||||
cache_examples=self.cache_examples,
|
||||
examples_per_page=self.examples_per_page,
|
||||
_api_mode=self.api_mode,
|
||||
batch=self.batch,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
@ -683,7 +769,7 @@ class Interface(Blocks):
|
||||
return repr
|
||||
|
||||
async def interpret_func(self, *args):
|
||||
return await self.interpret(args) + [
|
||||
return await self.interpret(list(args)) + [
|
||||
Column.update(visible=False),
|
||||
Column.update(visible=True),
|
||||
]
|
||||
@ -698,20 +784,9 @@ class Interface(Blocks):
|
||||
|
||||
def test_launch(self) -> None:
|
||||
"""
|
||||
Passes a few samples through the function to test if the inputs/outputs
|
||||
components are consistent with the function parameter and return values.
|
||||
Deprecated.
|
||||
"""
|
||||
print("Test launch: {}()...".format(self.__name__), end=" ")
|
||||
raw_input = []
|
||||
for input_component in self.input_components:
|
||||
if input_component.test_input is None:
|
||||
print("SKIPPED")
|
||||
break
|
||||
else:
|
||||
raw_input.append(input_component.test_input)
|
||||
else:
|
||||
self(raw_input)
|
||||
print("PASSED")
|
||||
warnings.warn("The Interface.test_launch() function is deprecated.")
|
||||
|
||||
|
||||
@document()
|
||||
@ -725,11 +800,11 @@ class TabbedInterface(Blocks):
|
||||
def __init__(
|
||||
self,
|
||||
interface_list: List[Interface],
|
||||
tab_names: Optional[List[str]] = None,
|
||||
title: Optional[str] = None,
|
||||
tab_names: List[str] | None = None,
|
||||
title: str | None = None,
|
||||
theme: str = "default",
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
css: Optional[str] = None,
|
||||
analytics_enabled: bool | None = None,
|
||||
css: str | None = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -743,7 +818,7 @@ class TabbedInterface(Blocks):
|
||||
a Gradio Tabbed Interface for the given interfaces
|
||||
"""
|
||||
super().__init__(
|
||||
title=title,
|
||||
title=title or "Gradio",
|
||||
theme=theme,
|
||||
analytics_enabled=analytics_enabled,
|
||||
mode="tabbed_interface",
|
||||
@ -760,7 +835,7 @@ class TabbedInterface(Blocks):
|
||||
)
|
||||
with Tabs():
|
||||
for (interface, tab_name) in zip(interface_list, tab_names):
|
||||
with TabItem(label=tab_name):
|
||||
with Tab(label=tab_name):
|
||||
interface.render()
|
||||
|
||||
|
||||
|
@ -6,4 +6,4 @@ pip_required
|
||||
pip install --upgrade pip
|
||||
pip install pyright
|
||||
cd gradio
|
||||
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py
|
||||
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py
|
||||
|
@ -69,15 +69,6 @@ class TestInterface:
|
||||
)
|
||||
assert dataset_check
|
||||
|
||||
def test_test_launch(self):
|
||||
with captured_output() as (out, err):
|
||||
prediction_fn = lambda x: x
|
||||
prediction_fn.__name__ = "prediction_fn"
|
||||
interface = Interface(prediction_fn, "textbox", "label")
|
||||
interface.test_launch()
|
||||
output = out.getvalue().strip()
|
||||
assert output == "Test launch: prediction_fn()... PASSED"
|
||||
|
||||
@mock.patch("time.sleep")
|
||||
def test_block_thread(self, mock_sleep):
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
|
Loading…
x
Reference in New Issue
Block a user