Enforcing typing in blocks.py and context.py (#2887)

* started pathlib

* blocks.py

* more changes

* fixes

* typing

* formatting

* typing

* renaming files

* changelog

* script

* changelog

* lint

* routes

* renamed

* state

* formatting

* state

* type check script

* remove strictness

* switched to pyright

* switched to pyright

* fixed flaky tests

* fixed test xray

* fixed load test

* fixed blocks tests

* formatting

* fixed components test

* uncomment tests

* fixed interpretation tests

* formatting

* last tests hopefully

* argh lint

* component

* fixed based on review

* refactor
This commit is contained in:
Abubakar Abid 2022-12-27 16:54:47 -05:00 committed by GitHub
parent e7cca92831
commit de0c41c1c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 263 additions and 190 deletions

View File

@ -55,6 +55,10 @@ jobs:
command: |
. venv/bin/activate
bash scripts/lint_backend.sh
- run:
command: |
. venv/bin/activate
bash scripts/type_check_backend.sh
- run:
command: |
. venv/bin/activate

View File

@ -21,6 +21,7 @@ No changes to highlight.
## Full Changelog:
* The `default_enabled` parameter of the `Blocks.queue` method has no effect by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2876](https://github.com/gradio-app/gradio/pull/2876)
* Added typing to several Python files in codebase by [@abidlabs](https://github.com/abidlabs) in [PR 2887](https://github.com/gradio-app/gradio/pull/2887)
## Contributors Shoutout:
* @JaySmithWpg for making their first contribution to gradio!

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from abc import abstractmethod
import copy
import getpass
import inspect
@ -12,30 +13,21 @@ import time
import typing
import warnings
import webbrowser
from pathlib import Path
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
import anyio
import requests
from anyio import CapacityLimiter
from typing_extensions import Literal
from gradio import (
components,
encryptor,
external,
networking,
queue,
queueing,
routes,
strings,
utils,
@ -59,11 +51,9 @@ set_documentation_group("blocks")
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import comet_ml
import mlflow
import wandb
from fastapi.applications import FastAPI
from gradio.components import Component, IOComponent
from gradio.components import Component
class Block:
@ -84,6 +74,8 @@ class Block:
self.root_url = root_url
self._skip_init_processing = _skip_init_processing
self._style = {}
self.parent: BlockContext | None = None
if render:
self.render()
check_deprecated_parameters(self.__class__.__name__, **kwargs)
@ -135,6 +127,9 @@ class Block:
else self.__class__.__name__.lower()
)
def get_expected_parent(self) -> Type[BlockContext] | None:
return None
def set_event_trigger(
self,
event_name: str,
@ -145,7 +140,7 @@ class Block:
postprocess: bool = True,
scroll_to_output: bool = False,
show_progress: bool = True,
api_name: AnyStr | None = None,
api_name: str | None = None,
js: str | None = None,
no_target: bool = False,
queue: bool | None = None,
@ -207,8 +202,10 @@ class Block:
"Either batch is True or every is non-zero but not both."
)
if every:
if every and fn:
fn = get_continuous_fn(fn, every)
elif every:
raise ValueError("Cannot set a value for `every` without a `fn`.")
Context.root_block.fns.append(
BlockFunction(fn, inputs, outputs, preprocess, postprocess, inputs_as_dict)
@ -250,11 +247,16 @@ class Block:
"root_url": self.root_url,
}
@staticmethod
@abstractmethod
def update(**kwargs) -> Dict:
return {}
@classmethod
def get_specific_update(cls, generic_update):
def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
del generic_update["__type__"]
generic_update = cls.update(**generic_update)
return generic_update
specific_update = cls.update(**generic_update)
return specific_update
class BlockContext(Block):
@ -269,7 +271,7 @@ class BlockContext(Block):
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
render: If False, this will not be included in the Blocks config file at all.
"""
self.children = []
self.children: List[Block] = []
super().__init__(visible=visible, render=render, **kwargs)
def __enter__(self):
@ -277,7 +279,7 @@ class BlockContext(Block):
Context.block = self
return self
def add(self, child):
def add(self, child: Block):
child.parent = self
self.children.append(child)
@ -285,7 +287,7 @@ class BlockContext(Block):
children = []
pseudo_parent = None
for child in self.children:
expected_parent = getattr(child.__class__, "expected_parent", False)
expected_parent = child.get_expected_parent()
if not expected_parent or isinstance(self, expected_parent):
pseudo_parent = None
children.append(child)
@ -298,7 +300,8 @@ class BlockContext(Block):
pseudo_parent = expected_parent(render=False)
children.append(pseudo_parent)
pseudo_parent.children = [child]
Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
if Context.root_block:
Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
child.parent = pseudo_parent
self.children = children
@ -317,7 +320,7 @@ class BlockContext(Block):
class BlockFunction:
def __init__(
self,
fn: Optional[Callable],
fn: Callable | None,
inputs: List[Component],
outputs: List[Component],
preprocess: bool,
@ -393,7 +396,7 @@ def update(**kwargs) -> dict:
),
gr.Textbox(lines=2),
live=True,
).launch()
).launch()abstrac
"""
kwargs["__type__"] = "generic_update"
return kwargs
@ -423,14 +426,19 @@ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool =
update_dict.pop("value")
prediction_value = delete_none(update_dict, skip_value=True)
if "value" in prediction_value and postprocess:
assert isinstance(
block, components.IOComponent
), f"Component {block.__class__} does not support value"
prediction_value["value"] = block.postprocess(prediction_value["value"])
return prediction_value
def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
def convert_component_dict_to_list(
outputs_ids: List[int], predictions: Dict
) -> List | Dict:
"""
Converts a dictionary of component updates into a list of updates in the order of
the outputs_ids and including every output component.
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
"""
@ -453,7 +461,9 @@ def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) ->
def add_request_to_inputs(
fn: Callable, inputs: List[Any], request: routes.Request | List[routes.Request]
fn: Callable,
inputs: List[Any],
request: routes.Request | List[routes.Request] | None,
):
"""
Adds the FastAPI Request object to the inputs of a function if the type of the parameter is FastAPI.Request.
@ -508,17 +518,17 @@ class Blocks(BlockContext):
def __init__(
self,
theme: str = "default",
analytics_enabled: Optional[bool] = None,
analytics_enabled: bool | None = None,
mode: str = "blocks",
title: str = "Gradio",
css: Optional[str] = None,
css: str | None = None,
**kwargs,
):
"""
Parameters:
theme: which theme to use - right now, only "default" is supported.
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
mode: a human-friendly name for the kind of Blocks interface being created.
mode: a human-friendly name for the kind of Blocks or Interface being created.
title: The tab title to display when this is opened in a browser window.
css: custom css or path to custom css file to apply to entire Blocks
"""
@ -531,7 +541,7 @@ class Blocks(BlockContext):
self.enable_queue = None
self.max_threads = 40
self.show_error = True
if css is not None and os.path.exists(css):
if css is not None and Path(css).exists():
with open(css) as css_file:
self.css = css_file.read()
else:
@ -558,7 +568,7 @@ class Blocks(BlockContext):
self.height = None
self.api_open = True
self.ip_address = None
self.ip_address = ""
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
self.favicon_path = None
self.auth = None
@ -568,6 +578,13 @@ class Blocks(BlockContext):
self.title = title
self.show_api = True
# Only used when an Interface is loaded from a config
self.predict = None
self.input_components = None
self.output_components = None
self.__name__ = None
self.api_mode = None
if self.analytics_enabled:
self.ip_address = utils.get_local_ip_address()
data = {
@ -575,7 +592,7 @@ class Blocks(BlockContext):
"ip_address": self.ip_address,
"custom_css": self.css is not None,
"theme": self.theme,
"version": pkgutil.get_data(__name__, "version.txt")
"version": (pkgutil.get_data(__name__, "version.txt") or b"")
.decode("ascii")
.strip(),
}
@ -611,7 +628,7 @@ class Blocks(BlockContext):
block_config["props"]["root_url"] = root_url + "/"
# Any component has already processed its initial value, so we skip that step here
block = cls(**block_config["props"], _skip_init_processing=True)
if style:
if style and isinstance(block, components.IOComponent):
block.style(**style)
return block
@ -619,10 +636,14 @@ class Blocks(BlockContext):
for child_config in children_list:
id = child_config["id"]
block = get_block_instance(id)
original_mapping[id] = block
children = child_config.get("children")
if children is not None:
assert isinstance(
block, BlockContext
), f"Invalid config, Block with id {id} has children but is not a BlockContext."
with block:
iterate_over_children(children)
@ -665,7 +686,7 @@ class Blocks(BlockContext):
first_dependency = dependency
# Allows some use of Interface-specific methods with loaded Spaces
if first_dependency:
if first_dependency and Context.root_block:
blocks.predict = [fns[0]]
blocks.input_components = [
Context.root_block.blocks[i] for i in first_dependency["inputs"]
@ -673,11 +694,8 @@ class Blocks(BlockContext):
blocks.output_components = [
Context.root_block.blocks[o] for o in first_dependency["outputs"]
]
if config.get("mode", "blocks") == "interface":
blocks.__name__ = "Interface"
blocks.mode = "interface"
blocks.api_mode = True
blocks.__name__ = "Interface"
blocks.api_mode = True
return blocks
@ -764,7 +782,7 @@ class Blocks(BlockContext):
if inspect.isasyncgenfunction(block_fn.fn):
return False
if inspect.isgeneratorfunction(block_fn.fn):
raise False
return False
for input_id in dependency["inputs"]:
block = self.blocks[input_id]
if getattr(block, "stateful", False):
@ -776,7 +794,7 @@ class Blocks(BlockContext):
return True
def __call__(self, *inputs, fn_index: int = 0, api_name: str = None):
def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
"""
Allows Blocks objects to be called as functions. Supply the parameters to the
function as positional arguments. To choose which function to call, use the
@ -788,7 +806,7 @@ class Blocks(BlockContext):
api_name: The api_name of the dependency to call. Will take precedence over fn_index.
"""
if api_name is not None:
fn_index = next(
inferred_fn_index = next(
(
i
for i, d in enumerate(self.dependencies)
@ -796,8 +814,9 @@ class Blocks(BlockContext):
),
None,
)
if fn_index is None:
if inferred_fn_index is None:
raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
fn_index = inferred_fn_index
if not (self.is_callable(fn_index)):
raise ValueError(
"This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
@ -814,6 +833,7 @@ class Blocks(BlockContext):
fn_index=fn_index,
inputs=processed_inputs,
request=None,
state={},
)
outputs = outputs["data"]
@ -834,6 +854,7 @@ class Blocks(BlockContext):
):
"""Calls and times function with given index and preprocessed input."""
block_fn = self.fns[fn_index]
assert block_fn.fn, f"function with index {fn_index} not defined."
is_generating = False
if block_fn.inputs_as_dict:
@ -856,6 +877,8 @@ class Blocks(BlockContext):
prediction = await anyio.to_thread.run_sync(
block_fn.fn, *processed_input, limiter=self.limiter
)
else:
prediction = None
if inspect.isasyncgenfunction(block_fn.fn):
raise ValueError("Gradio does not support async generators.")
@ -892,7 +915,10 @@ class Blocks(BlockContext):
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block: IOComponent = self.blocks[input_id]
block = self.blocks[input_id]
assert isinstance(
block, components.IOComponent
), f"{block.__class__} Component with id {input_id} not a valid input component."
serialized_input = block.serialize(inputs[i])
processed_input.append(serialized_input)
@ -903,7 +929,10 @@ class Blocks(BlockContext):
predictions = []
for o, output_id in enumerate(dependency["outputs"]):
block: IOComponent = self.blocks[output_id]
block = self.blocks[output_id]
assert isinstance(
block, components.IOComponent
), f"{block.__class__} Component with id {output_id} not a valid output component."
deserialized = block.deserialize(outputs[o])
predictions.append(deserialized)
@ -916,7 +945,10 @@ class Blocks(BlockContext):
if block_fn.preprocess:
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block: IOComponent = self.blocks[input_id]
block = self.blocks[input_id]
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {input_id} not a valid input component."
if getattr(block, "stateful", False):
processed_input.append(state.get(input_id))
else:
@ -926,7 +958,7 @@ class Blocks(BlockContext):
return processed_input
def postprocess_data(
self, fn_index: int, predictions: List[Any], state: Dict[int, Any]
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
@ -938,7 +970,9 @@ class Blocks(BlockContext):
)
if len(dependency["outputs"]) == 1 and not (batch):
predictions = (predictions,)
predictions = [
predictions,
]
output = []
for i, output_id in enumerate(dependency["outputs"]):
@ -953,12 +987,16 @@ class Blocks(BlockContext):
else:
prediction_value = predictions[i]
if utils.is_update(prediction_value):
assert isinstance(prediction_value, dict)
prediction_value = postprocess_update_dict(
block=block,
update_dict=prediction_value,
postprocess=block_fn.postprocess,
)
elif block_fn.postprocess:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {output_id} not a valid output component."
prediction_value = block.postprocess(prediction_value)
output.append(prediction_value)
return output
@ -967,9 +1005,8 @@ class Blocks(BlockContext):
self,
fn_index: int,
inputs: List[Any],
state: Dict[int, Any],
request: routes.Request | List[routes.Request] | None = None,
username: str = None,
state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
iterators: Dict[int, Any] | None = None,
) -> Dict[str, Any]:
"""
@ -1003,10 +1040,16 @@ class Blocks(BlockContext):
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
)
inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
result = await self.call_function(fn_index, zip(*inputs), None, request)
inputs = [
self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
]
result = await self.call_function(
fn_index, list(zip(*inputs)), None, request
)
preds = result["prediction"]
data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
data = [
self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
]
data = list(zip(*data))
is_generating, iterator = None, None
else:
@ -1098,10 +1141,10 @@ class Blocks(BlockContext):
@class_or_instancemethod
def load(
self_or_cls,
fn: Optional[Callable] = None,
inputs: Optional[List[Component]] = None,
outputs: Optional[List[Component]] = None,
api_name: AnyStr = None,
fn: Callable | None = None,
inputs: List[Component] | None = None,
outputs: List[Component] | None = None,
api_name: str | None = None,
scroll_to_output: bool = False,
show_progress: bool = True,
queue=None,
@ -1110,12 +1153,12 @@ class Blocks(BlockContext):
preprocess: bool = True,
postprocess: bool = True,
every: float | None = None,
_js: Optional[str] = None,
_js: str | None = None,
*,
name: Optional[str] = None,
src: Optional[str] = None,
api_key: Optional[str] = None,
alias: Optional[str] = None,
name: str | None = None,
src: str | None = None,
api_key: str | None = None,
alias: str | None = None,
**kwargs,
) -> Blocks | Dict[str, Any] | None:
"""
@ -1192,7 +1235,7 @@ class Blocks(BlockContext):
def queue(
self,
concurrency_count: int = 1,
status_update_rate: float | str = "auto",
status_update_rate: float | Literal["auto"] = "auto",
client_position_to_load_data: int | None = None,
default_enabled: bool | None = None,
api_open: bool = True,
@ -1221,7 +1264,7 @@ class Blocks(BlockContext):
self.api_open = api_open
if client_position_to_load_data is not None:
warnings.warn("The client_position_to_load_data parameter is deprecated.")
self._queue = queue.Queue(
self._queue = queueing.Queue(
live_updates=status_update_rate == "auto",
concurrency_count=concurrency_count,
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
@ -1233,26 +1276,26 @@ class Blocks(BlockContext):
def launch(
self,
inline: bool = None,
inline: bool | None = None,
inbrowser: bool = False,
share: Optional[bool] = None,
share: bool | None = None,
debug: bool = False,
enable_queue: bool = None,
enable_queue: bool | None = None,
max_threads: int = 40,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
auth_message: str | None = None,
prevent_thread_lock: bool = False,
show_error: bool = False,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
server_name: str | None = None,
server_port: int | None = None,
show_tips: bool = False,
height: int = 500,
width: int | str = "100%",
encrypt: bool = False,
favicon_path: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_keyfile_password: Optional[str] = None,
favicon_path: str | None = None,
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
quiet: bool = False,
show_api: bool = True,
_frontend: bool = True,
@ -1302,8 +1345,9 @@ class Blocks(BlockContext):
and not isinstance(auth[0], tuple)
and not isinstance(auth[0], list)
):
auth = [auth]
self.auth = auth
self.auth = [auth]
else:
self.auth = auth
self.auth_message = auth_message
self.show_tips = show_tips
self.show_error = show_error
@ -1352,7 +1396,9 @@ class Blocks(BlockContext):
)
if self.is_running:
self.server_app.launchable = self
assert isinstance(
self.local_url, str
), f"Invalid local_url: {self.local_url}"
if not (quiet):
print(
"Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
@ -1448,14 +1494,14 @@ class Blocks(BlockContext):
self.share_url = None
if inbrowser:
link = self.share_url if self.share else self.local_url
link = self.share_url if self.share and self.share_url else self.local_url
webbrowser.open(link)
# Check if running in a Python notebook in which case, display inline
if inline is None:
inline = utils.ipython_check() and (auth is None)
inline = utils.ipython_check() and (self.auth is None)
if inline:
if auth is not None:
if self.auth is not None:
print(
"Warning: authentication is not supported inline. Please"
"click the link to access the interface in a new tab."
@ -1463,7 +1509,7 @@ class Blocks(BlockContext):
try:
from IPython.display import HTML, Javascript, display # type: ignore
if self.share:
if self.share and self.share_url:
while not networking.url_ok(self.share_url):
time.sleep(0.25)
display(
@ -1546,9 +1592,9 @@ class Blocks(BlockContext):
def integrate(
self,
comet_ml: comet_ml.Experiment = None,
wandb: ModuleType("wandb") = None,
mlflow: ModuleType("mlflow") = None,
comet_ml: comet_ml.Experiment | None = None,
wandb: ModuleType | None = None,
mlflow: ModuleType | None = None,
) -> None:
"""
A catch-all method for integrating with other libraries. This method should be run after launch()
@ -1564,9 +1610,11 @@ class Blocks(BlockContext):
if self.share_url is not None:
comet_ml.log_text("gradio: " + self.share_url)
comet_ml.end()
else:
elif self.local_url:
comet_ml.log_text("gradio: " + self.local_url)
comet_ml.end()
else:
raise ValueError("Please run `launch()` first.")
if wandb is not None:
analytics_integration = "WandB"
if self.share_url is not None:
@ -1627,23 +1675,23 @@ class Blocks(BlockContext):
def attach_load_events(self):
"""Add a load event for every component whose initial value should be randomized."""
for component in Context.root_block.blocks.values():
if (
isinstance(component, components.IOComponent)
and component.load_event_to_attach
):
load_fn, every = component.load_event_to_attach
# Use set_event_trigger to avoid ambiguity between load class/instance method
self.set_event_trigger(
"load",
load_fn,
None,
component,
no_target=True,
queue=False,
every=every,
)
if Context.root_block:
for component in Context.root_block.blocks.values():
if (
isinstance(component, components.IOComponent)
and component.load_event_to_attach
):
load_fn, every = component.load_event_to_attach
# Use set_event_trigger to avoid ambiguity between load class/instance method
self.set_event_trigger(
"load",
load_fn,
None,
component,
no_target=True,
queue=False,
every=every,
)
def startup_events(self):
"""Events that should be run when the app containing this block starts up."""

View File

@ -18,7 +18,7 @@ from copy import deepcopy
from enum import Enum
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
import altair as alt
import matplotlib.figure
@ -28,7 +28,7 @@ import PIL
import PIL.ImageOps
from ffmpy import FFmpeg
from markdown_it import MarkdownIt
from mdit_py_plugins.dollarmath import dollarmath_plugin
from mdit_py_plugins.dollarmath.index import dollarmath_plugin
from pandas.api.types import is_numeric_dtype
from PIL import Image as _Image # using _ to minimize namespace pollution
@ -94,6 +94,18 @@ class Component(Block):
**super().get_config(),
}
def preprocess(self, x: Any) -> Any:
"""
Any preprocessing needed to be performed on function input.
"""
return x
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.
"""
return y
class IOComponent(Component, Serializable):
"""
@ -140,12 +152,6 @@ class IOComponent(Component, Serializable):
**super().get_config(),
}
def preprocess(self, x: Any) -> Any:
"""
Any preprocessing needed to be performed on function input.
"""
return x
def set_interpret_parameters(self):
"""
Set any parameters for interpretation.
@ -184,12 +190,6 @@ class IOComponent(Component, Serializable):
"""
pass
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.
"""
return y
def style(
self,
*,
@ -273,12 +273,13 @@ class IOComponent(Component, Serializable):
class FormComponent:
expected_parent = Form
def get_expected_parent(self) -> Type[Form]:
return Form
@document("change", "submit", "blur", "style")
class Textbox(
Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable, FormComponent
FormComponent, Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable
):
"""
Creates a textarea for user to enter string input or display string output.
@ -459,7 +460,7 @@ class Textbox(
@document("change", "submit", "style")
class Number(
Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable, FormComponent
FormComponent, Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable
):
"""
Creates a numeric field for user to enter numbers as input or display numeric output.
@ -630,7 +631,7 @@ class Number(
@document("change", "style")
class Slider(Changeable, IOComponent, SimpleSerializable, FormComponent):
class Slider(FormComponent, Changeable, IOComponent, SimpleSerializable):
"""
Creates a slider that ranges from `minimum` to `maximum` with a step size of `step`.
Preprocessing: passes slider value as a {float} into the function.
@ -792,7 +793,7 @@ class Slider(Changeable, IOComponent, SimpleSerializable, FormComponent):
@document("change", "style")
class Checkbox(Changeable, IOComponent, SimpleSerializable, FormComponent):
class Checkbox(FormComponent, Changeable, IOComponent, SimpleSerializable):
"""
Creates a checkbox that can be set to `True` or `False`.
@ -886,7 +887,7 @@ class Checkbox(Changeable, IOComponent, SimpleSerializable, FormComponent):
@document("change", "style")
class CheckboxGroup(Changeable, IOComponent, SimpleSerializable, FormComponent):
class CheckboxGroup(FormComponent, Changeable, IOComponent, SimpleSerializable):
"""
Creates a set of checkboxes of which a subset can be checked.
Preprocessing: passes the list of checked checkboxes as a {List[str]} or their indices as a {List[int]} into the function, depending on `type`.
@ -1056,7 +1057,7 @@ class CheckboxGroup(Changeable, IOComponent, SimpleSerializable, FormComponent):
@document("change", "style")
class Radio(Changeable, IOComponent, SimpleSerializable, FormComponent):
class Radio(FormComponent, Changeable, IOComponent, SimpleSerializable):
"""
Creates a set of radio buttons of which only one can be selected.
Preprocessing: passes the value of the selected radio button as a {str} or its index as an {int} into the function, depending on `type`.
@ -1615,7 +1616,7 @@ class Image(
)
def as_example(self, input_data: str | None) -> str:
return os.path.abspath(input_data)
return str(Path(input_data).resolve())
@document("change", "clear", "play", "pause", "stop", "style")
@ -1758,7 +1759,7 @@ class Video(
output_file_name = str(
file_name.with_name(f"{file_name.stem}{flip_suffix}{format}")
)
if os.path.exists(output_file_name):
if Path(output_file_name).exists():
return output_file_name
ff = FFmpeg(
inputs={str(file_name): None},
@ -2025,7 +2026,7 @@ class Audio(
out_data = processing_utils.encode_file_to_base64(file.name)
leave_one_out_sets.append(out_data)
file.close()
os.unlink(file.name)
Path(file.name).unlink()
# Handle the tokens
token = np.copy(data)
@ -2035,7 +2036,7 @@ class Audio(
processing_utils.audio_to_file(sample_rate, token, file.name)
token_data = processing_utils.encode_file_to_base64(file.name)
file.close()
os.unlink(file.name)
Path(file.name).unlink()
tokens.append(token_data)
tokens = [{"name": "token.wav", "data": token} for token in tokens]
@ -2066,7 +2067,7 @@ class Audio(
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
masked_data = processing_utils.encode_file_to_base64(file.name)
file.close()
os.unlink(file.name)
Path(file.name).unlink()
masked_inputs.append(masked_data)
return masked_inputs
@ -3656,11 +3657,11 @@ class Gallery(IOComponent, TempFileManager):
img, caption = img
if isinstance(img, np.ndarray):
file = processing_utils.save_array_to_file(img)
file_path = os.path.abspath(file.name)
file_path = str(Path(file.name).resolve())
self.temp_files.add(file_path)
elif isinstance(img, PIL.Image.Image):
file = processing_utils.save_pil_to_file(img)
file_path = os.path.abspath(file.name)
file_path = str(Path(file.name).resolve())
self.temp_files.add(file_path)
elif isinstance(img, str):
if utils.validate_url(img):
@ -3706,8 +3707,8 @@ class Gallery(IOComponent, TempFileManager):
) -> None | str:
if x is None:
return None
gallery_path = os.path.join(save_dir, str(uuid.uuid4()))
os.makedirs(gallery_path)
gallery_path = Path(save_dir) / str(uuid.uuid4())
gallery_path.mkdir(exist_ok=True, parents=True)
captions = {}
for img_data in x:
if isinstance(img_data, list) or isinstance(img_data, tuple):
@ -3716,15 +3717,15 @@ class Gallery(IOComponent, TempFileManager):
caption = None
name = FileSerializable.deserialize(self, img_data, gallery_path)
captions[name] = caption
captions_file = os.path.join(gallery_path, "captions.json")
with open(captions_file, "w") as captions_json:
captions_file = gallery_path / "captions.json"
with captions_file.open("w") as captions_json:
json.dump(captions, captions_json)
return os.path.abspath(gallery_path)
return str(gallery_path.resolve())
def serialize(self, x: Any, load_dir: str = "", called_directly: bool = False):
files = []
captions_file = os.path.join(x, "captions.json")
with open(captions_file) as captions_json:
captions_file = Path(x) / "captions.json"
with captions_file.open("r") as captions_json:
captions = json.load(captions_json)
for file_name, caption in captions.items():
img = FileSerializable.serialize(self, file_name)
@ -4978,9 +4979,6 @@ class Interpretation(Component):
def style(self):
return self
def postprocess(self, y: Any) -> Any:
return y
class StatusTracker(Component):
"""

View File

@ -9,8 +9,6 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
class Context:
root_block: Blocks = None # The current root block that holds all blocks.
block: BlockContext = (
None # The current block that all children should be added to.
)
id = 0 # Running id to uniquely refer to any block that gets defined
root_block: Blocks | None = None # The current root block that holds all blocks.
block: BlockContext | None = None # The current block that children are added to.
id: int = 0 # Running id to uniquely refer to any block that gets defined

View File

@ -291,7 +291,7 @@ class Examples:
if self.batch:
processed_input = [[value] for value in processed_input]
prediction = await Context.root_block.process_api(
fn_index=fn_index, inputs=processed_input, request=None
fn_index=fn_index, inputs=processed_input, request=None, state={}
)
output = prediction["data"]
if self.batch:

View File

@ -33,7 +33,11 @@ if TYPE_CHECKING:
def load_blocks_from_repo(
name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs
name: str,
src: str | None = None,
api_key: str | None = None,
alias: str | None = None,
**kwargs,
) -> Blocks:
"""Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
if src is None:

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Type
from gradio.blocks import BlockContext
from gradio.documentation import document, set_documentation_group
@ -184,8 +184,6 @@ class Tabs(BlockContext):
class TabItem(BlockContext):
expected_parent = Tabs
def __init__(
self,
label: str,
@ -221,6 +219,9 @@ class TabItem(BlockContext):
"""
self.set_event_trigger("select", fn, inputs, outputs)
def get_expected_parent(self) -> Type[Tabs]:
return Tabs
@document()
class Tab(TabItem):

View File

@ -9,7 +9,7 @@ import socket
import threading
import time
import warnings
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Tuple
import fastapi
import requests
@ -87,12 +87,12 @@ def configure_app(app: fastapi.FastAPI, blocks: Blocks) -> fastapi.FastAPI:
def start_server(
blocks: Blocks,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_keyfile_password: Optional[str] = None,
) -> Tuple[int, str, App, Server]:
server_name: str | None = None,
server_port: int | None = None,
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
) -> Tuple[str, int, str, App, Server]:
"""Launches a local server running the provided Interface
Parameters:
blocks: The Blocks object to run on the server

View File

@ -10,7 +10,7 @@ from typing import Any, Deque, Dict, List, Optional, Tuple
import fastapi
from pydantic import BaseModel
from gradio.dataclasses import PredictBody
from gradio.data_classes import PredictBody
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
@ -46,7 +46,7 @@ class Queue:
self,
live_updates: bool,
concurrency_count: int,
update_intervals: int,
update_intervals: float,
max_size: Optional[int],
blocks_dependencies: List,
):

View File

@ -37,10 +37,10 @@ from starlette.websockets import WebSocketState
import gradio
from gradio import encryptor, utils
from gradio.dataclasses import PredictBody, ResetBody
from gradio.data_classes import PredictBody, ResetBody
from gradio.documentation import document, set_documentation_group
from gradio.exceptions import Error
from gradio.queue import Estimation, Event
from gradio.queueing import Estimation, Event
from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
mimetypes.init()
@ -125,7 +125,7 @@ class App(FastAPI):
self.tokens = {}
@staticmethod
def create_app(blocks: gradio.Blocks) -> FastAPI:
def create_app(blocks: gradio.Blocks) -> App:
app = App(default_response_class=ORJSONResponse)
app.configure_app(blocks)
@ -322,7 +322,6 @@ class App(FastAPI):
fn_index=fn_index,
inputs=raw_input,
request=request,
username=username,
state=session_state,
iterators=iterators,
)

View File

@ -118,7 +118,7 @@ class FileSerializable(Serializable):
def deserialize(
self,
x: str | Dict | None,
save_dir: str | None = None,
save_dir: Path | str | None = None,
encryption_key: bytes | None = None,
):
"""
@ -131,6 +131,8 @@ class FileSerializable(Serializable):
"""
if x is None:
return None
if isinstance(save_dir, Path):
save_dir = str(save_dir)
if isinstance(x, str):
file_name = processing_utils.decode_base64_to_file(
x, dir=save_dir, encryption_key=encryption_key

View File

@ -33,6 +33,7 @@ from typing import (
NewType,
Tuple,
Type,
TypeVar,
)
import aiohttp
@ -58,6 +59,8 @@ analytics_url = "https://api.gradio.app/"
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
T = TypeVar("T")
def version_check():
try:
@ -292,7 +295,7 @@ def format_ner_list(input_string: str, ner_groups: Dict[str : str | int]):
return output
def delete_none(_dict, skip_value=False):
def delete_none(_dict: T, skip_value: bool = False) -> T:
"""
Delete None values recursively from all of the dictionaries, tuples, lists, sets.
Credit: https://stackoverflow.com/a/66127889/5209347
@ -319,7 +322,7 @@ def resolve_singleton(_list: List[Any] | Any) -> Any:
return _list
def component_or_layout_class(cls_name: str) -> Component | BlockContext:
def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockContext]:
"""
Returns the component, template, or layout class with the given class name, or
raises a ValueError if not found.
@ -357,7 +360,7 @@ def component_or_layout_class(cls_name: str) -> Component | BlockContext:
raise ValueError(f"No such component or layout: {cls_name}")
def synchronize_async(func: Callable, *args, **kwargs):
def synchronize_async(func: Callable, *args, **kwargs) -> Any:
"""
Runs async functions in sync scopes.
@ -711,10 +714,10 @@ def validate_url(possible_url: str) -> bool:
def is_update(val):
return type(val) is dict and "update" in val.get("__type__", "")
return isinstance(val, dict) and "update" in val.get("__type__", "")
def get_continuous_fn(fn, every):
def get_continuous_fn(fn: Callable, every: float) -> Callable:
def continuous_fn(*args):
while True:
output = fn(*args)

View File

@ -20,3 +20,4 @@ fsspec
httpx
pydantic
websockets>=10.0
typing_extensions

View File

@ -0,0 +1,8 @@
cd "$(dirname ${0})/.."
source scripts/helpers.sh
pip_required
pip install --upgrade pip
pip install pyright
pyright gradio/context.py gradio/blocks.py

View File

@ -179,7 +179,7 @@ class TestBlocksMethods:
btn.click(greet, {first, last}, greeting)
result = await demo.process_api(inputs=["huggy", "face"], fn_index=0)
result = await demo.process_api(inputs=["huggy", "face"], fn_index=0, state={})
assert result["data"] == ["Hello huggy face"]
@pytest.mark.asyncio
@ -194,7 +194,7 @@ class TestBlocksMethods:
button.click(wait, [text], [text])
start = time.time()
result = await demo.process_api(inputs=[1], fn_index=0)
result = await demo.process_api(inputs=[1], fn_index=0, state={})
end = time.time()
difference = end - start
assert difference >= 0.01
@ -395,7 +395,7 @@ class TestComponentsInBlocks:
share_button = gr.Button("share", visible=False)
run_button.click(infer, prompt, [image, share_button], postprocess=False)
output = await demo.process_api(0, ["test"])
output = await demo.process_api(0, ["test"], state={})
assert output["data"][0] == gr.media_data.BASE64_IMAGE
assert output["data"][1] == {"__type__": "update", "visible": True}
@ -412,7 +412,7 @@ class TestComponentsInBlocks:
run_button = gr.Button()
run_button.click(infer, [prompt], [image], postprocess=False)
output = await demo.process_api(0, ["test"])
output = await demo.process_api(0, ["test"], state={})
assert output["data"][0] == {
"__type__": "update",
"value": gr.media_data.BASE64_IMAGE,
@ -439,7 +439,7 @@ class TestComponentsInBlocks:
run.click(generic_update, None, [image, textbox])
for fn_index in range(2):
output = await demo.process_api(fn_index, [])
output = await demo.process_api(fn_index, [], state={})
assert output["data"][0] == {
"interactive": True,
"__type__": "update",
@ -686,7 +686,7 @@ class TestBatchProcessing:
btn = gr.Button()
btn.click(batch_fn, inputs=text, outputs=text, batch=True)
await demo.process_api(0, [["Adam", "Yahya"]])
await demo.process_api(0, [["Adam", "Yahya"]], state={})
@pytest.mark.asyncio
async def test_exceeds_max_batch_size(self):
@ -705,7 +705,7 @@ class TestBatchProcessing:
batch_fn, inputs=text, outputs=text, batch=True, max_batch_size=2
)
await demo.process_api(0, [["A", "B", "C"]])
await demo.process_api(0, [["A", "B", "C"]], state={})
@pytest.mark.asyncio
async def test_unequal_batch_sizes(self):
@ -723,7 +723,7 @@ class TestBatchProcessing:
btn = gr.Button()
btn.click(batch_fn, inputs=[t1, t2], outputs=t1, batch=True)
await demo.process_api(0, [["A", "B", "C"], ["D", "E"]])
await demo.process_api(0, [["A", "B", "C"], ["D", "E"]], state={})
class TestSpecificUpdate:
@ -802,13 +802,17 @@ class TestSpecificUpdate:
inputs=None,
outputs=[accordion],
)
result = await demo.process_api(fn_index=0, inputs=[None], request=None)
result = await demo.process_api(
fn_index=0, inputs=[None], request=None, state={}
)
assert result["data"][0] == {
"open": True,
"label": "Open Accordion",
"__type__": "update",
}
result = await demo.process_api(fn_index=1, inputs=[None], request=None)
result = await demo.process_api(
fn_index=1, inputs=[None], request=None, state={}
)
assert result["data"][0] == {
"open": False,
"label": "Closed Accordion",

View File

@ -671,7 +671,7 @@ class TestPlot:
return fig
iface = gr.Interface(plot, "slider", "plot")
output = await iface.process_api(fn_index=0, inputs=[10])
output = await iface.process_api(fn_index=0, inputs=[10], state={})
assert output["data"][0]["type"] == "matplotlib"
assert output["data"][0]["plot"].startswith("data:image/png;base64")
@ -1244,7 +1244,7 @@ class TestVideo:
)
assert processing_utils.video_is_playable(str(full_path_to_output))
@patch("os.path.exists", MagicMock(return_value=False))
@patch("pathlib.Path.exists", MagicMock(return_value=False))
@patch("gradio.components.FFmpeg")
def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
# Ensures that the cached temp video file is not used so that ffmpeg is called for each test
@ -1658,7 +1658,9 @@ class TestJSON:
["F", 30],
]
assert (
await iface.process_api(0, [{"data": y_data, "headers": ["gender", "age"]}])
await iface.process_api(
0, [{"data": y_data, "headers": ["gender", "age"]}], state={}
)
)["data"][0] == {
"M": 35,
"F": 25,

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from gradio.queue import Event, Queue
from gradio.queueing import Event, Queue
from gradio.utils import AsyncRequest
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -142,7 +142,7 @@ class TestQueueProcessEvents:
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
queue.gather_event_data = AsyncMock()
queue.gather_event_data.return_value = True
@ -284,7 +284,7 @@ class TestQueueProcessEvents:
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
async def test_process_event_handles_exception_during_disconnect(
self, mock_request, queue: Queue, mock_event: Event
):