mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
Enforcing typing in blocks.py
and context.py
(#2887)
* started pathlib * blocks.py * more changes * fixes * typing * formatting * typing * renaming files * changelog * script * changelog * lint * routes * renamed * state * formatting * state * type check script * remove strictness * switched to pyright * switched to pyright * fixed flaky tests * fixed test xray * fixed load test * fixed blocks tests * formatting * fixed components test * uncomment tests * fixed interpretation tests * formatting * last tests hopefully * argh lint * component * fixed based on review * refactor
This commit is contained in:
parent
e7cca92831
commit
de0c41c1c4
@ -55,6 +55,10 @@ jobs:
|
||||
command: |
|
||||
. venv/bin/activate
|
||||
bash scripts/lint_backend.sh
|
||||
- run:
|
||||
command: |
|
||||
. venv/bin/activate
|
||||
bash scripts/type_check_backend.sh
|
||||
- run:
|
||||
command: |
|
||||
. venv/bin/activate
|
||||
|
@ -21,6 +21,7 @@ No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
* The `default_enabled` parameter of the `Blocks.queue` method has no effect by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2876](https://github.com/gradio-app/gradio/pull/2876)
|
||||
* Added typing to several Python files in codebase by [@abidlabs](https://github.com/abidlabs) in [PR 2887](https://github.com/gradio-app/gradio/pull/2887)
|
||||
|
||||
## Contributors Shoutout:
|
||||
* @JaySmithWpg for making their first contribution to gradio!
|
||||
|
262
gradio/blocks.py
262
gradio/blocks.py
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
import getpass
|
||||
import inspect
|
||||
@ -12,30 +13,21 @@ import time
|
||||
import typing
|
||||
import warnings
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AnyStr,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
|
||||
|
||||
import anyio
|
||||
import requests
|
||||
from anyio import CapacityLimiter
|
||||
from typing_extensions import Literal
|
||||
|
||||
from gradio import (
|
||||
components,
|
||||
encryptor,
|
||||
external,
|
||||
networking,
|
||||
queue,
|
||||
queueing,
|
||||
routes,
|
||||
strings,
|
||||
utils,
|
||||
@ -59,11 +51,9 @@ set_documentation_group("blocks")
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import comet_ml
|
||||
import mlflow
|
||||
import wandb
|
||||
from fastapi.applications import FastAPI
|
||||
|
||||
from gradio.components import Component, IOComponent
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
class Block:
|
||||
@ -84,6 +74,8 @@ class Block:
|
||||
self.root_url = root_url
|
||||
self._skip_init_processing = _skip_init_processing
|
||||
self._style = {}
|
||||
self.parent: BlockContext | None = None
|
||||
|
||||
if render:
|
||||
self.render()
|
||||
check_deprecated_parameters(self.__class__.__name__, **kwargs)
|
||||
@ -135,6 +127,9 @@ class Block:
|
||||
else self.__class__.__name__.lower()
|
||||
)
|
||||
|
||||
def get_expected_parent(self) -> Type[BlockContext] | None:
|
||||
return None
|
||||
|
||||
def set_event_trigger(
|
||||
self,
|
||||
event_name: str,
|
||||
@ -145,7 +140,7 @@ class Block:
|
||||
postprocess: bool = True,
|
||||
scroll_to_output: bool = False,
|
||||
show_progress: bool = True,
|
||||
api_name: AnyStr | None = None,
|
||||
api_name: str | None = None,
|
||||
js: str | None = None,
|
||||
no_target: bool = False,
|
||||
queue: bool | None = None,
|
||||
@ -207,8 +202,10 @@ class Block:
|
||||
"Either batch is True or every is non-zero but not both."
|
||||
)
|
||||
|
||||
if every:
|
||||
if every and fn:
|
||||
fn = get_continuous_fn(fn, every)
|
||||
elif every:
|
||||
raise ValueError("Cannot set a value for `every` without a `fn`.")
|
||||
|
||||
Context.root_block.fns.append(
|
||||
BlockFunction(fn, inputs, outputs, preprocess, postprocess, inputs_as_dict)
|
||||
@ -250,11 +247,16 @@ class Block:
|
||||
"root_url": self.root_url,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def update(**kwargs) -> Dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_specific_update(cls, generic_update):
|
||||
def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
|
||||
del generic_update["__type__"]
|
||||
generic_update = cls.update(**generic_update)
|
||||
return generic_update
|
||||
specific_update = cls.update(**generic_update)
|
||||
return specific_update
|
||||
|
||||
|
||||
class BlockContext(Block):
|
||||
@ -269,7 +271,7 @@ class BlockContext(Block):
|
||||
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
|
||||
render: If False, this will not be included in the Blocks config file at all.
|
||||
"""
|
||||
self.children = []
|
||||
self.children: List[Block] = []
|
||||
super().__init__(visible=visible, render=render, **kwargs)
|
||||
|
||||
def __enter__(self):
|
||||
@ -277,7 +279,7 @@ class BlockContext(Block):
|
||||
Context.block = self
|
||||
return self
|
||||
|
||||
def add(self, child):
|
||||
def add(self, child: Block):
|
||||
child.parent = self
|
||||
self.children.append(child)
|
||||
|
||||
@ -285,7 +287,7 @@ class BlockContext(Block):
|
||||
children = []
|
||||
pseudo_parent = None
|
||||
for child in self.children:
|
||||
expected_parent = getattr(child.__class__, "expected_parent", False)
|
||||
expected_parent = child.get_expected_parent()
|
||||
if not expected_parent or isinstance(self, expected_parent):
|
||||
pseudo_parent = None
|
||||
children.append(child)
|
||||
@ -298,7 +300,8 @@ class BlockContext(Block):
|
||||
pseudo_parent = expected_parent(render=False)
|
||||
children.append(pseudo_parent)
|
||||
pseudo_parent.children = [child]
|
||||
Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
|
||||
if Context.root_block:
|
||||
Context.root_block.blocks[pseudo_parent._id] = pseudo_parent
|
||||
child.parent = pseudo_parent
|
||||
self.children = children
|
||||
|
||||
@ -317,7 +320,7 @@ class BlockContext(Block):
|
||||
class BlockFunction:
|
||||
def __init__(
|
||||
self,
|
||||
fn: Optional[Callable],
|
||||
fn: Callable | None,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
preprocess: bool,
|
||||
@ -393,7 +396,7 @@ def update(**kwargs) -> dict:
|
||||
),
|
||||
gr.Textbox(lines=2),
|
||||
live=True,
|
||||
).launch()
|
||||
).launch()abstrac
|
||||
"""
|
||||
kwargs["__type__"] = "generic_update"
|
||||
return kwargs
|
||||
@ -423,14 +426,19 @@ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool =
|
||||
update_dict.pop("value")
|
||||
prediction_value = delete_none(update_dict, skip_value=True)
|
||||
if "value" in prediction_value and postprocess:
|
||||
assert isinstance(
|
||||
block, components.IOComponent
|
||||
), f"Component {block.__class__} does not support value"
|
||||
prediction_value["value"] = block.postprocess(prediction_value["value"])
|
||||
return prediction_value
|
||||
|
||||
|
||||
def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) -> List:
|
||||
def convert_component_dict_to_list(
|
||||
outputs_ids: List[int], predictions: Dict
|
||||
) -> List | Dict:
|
||||
"""
|
||||
Converts a dictionary of component updates into a list of updates in the order of
|
||||
the outputs_ids and including every output component.
|
||||
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
|
||||
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
|
||||
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
|
||||
"""
|
||||
@ -453,7 +461,9 @@ def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) ->
|
||||
|
||||
|
||||
def add_request_to_inputs(
|
||||
fn: Callable, inputs: List[Any], request: routes.Request | List[routes.Request]
|
||||
fn: Callable,
|
||||
inputs: List[Any],
|
||||
request: routes.Request | List[routes.Request] | None,
|
||||
):
|
||||
"""
|
||||
Adds the FastAPI Request object to the inputs of a function if the type of the parameter is FastAPI.Request.
|
||||
@ -508,17 +518,17 @@ class Blocks(BlockContext):
|
||||
def __init__(
|
||||
self,
|
||||
theme: str = "default",
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
analytics_enabled: bool | None = None,
|
||||
mode: str = "blocks",
|
||||
title: str = "Gradio",
|
||||
css: Optional[str] = None,
|
||||
css: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
theme: which theme to use - right now, only "default" is supported.
|
||||
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
||||
mode: a human-friendly name for the kind of Blocks interface being created.
|
||||
mode: a human-friendly name for the kind of Blocks or Interface being created.
|
||||
title: The tab title to display when this is opened in a browser window.
|
||||
css: custom css or path to custom css file to apply to entire Blocks
|
||||
"""
|
||||
@ -531,7 +541,7 @@ class Blocks(BlockContext):
|
||||
self.enable_queue = None
|
||||
self.max_threads = 40
|
||||
self.show_error = True
|
||||
if css is not None and os.path.exists(css):
|
||||
if css is not None and Path(css).exists():
|
||||
with open(css) as css_file:
|
||||
self.css = css_file.read()
|
||||
else:
|
||||
@ -558,7 +568,7 @@ class Blocks(BlockContext):
|
||||
self.height = None
|
||||
self.api_open = True
|
||||
|
||||
self.ip_address = None
|
||||
self.ip_address = ""
|
||||
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
|
||||
self.favicon_path = None
|
||||
self.auth = None
|
||||
@ -568,6 +578,13 @@ class Blocks(BlockContext):
|
||||
self.title = title
|
||||
self.show_api = True
|
||||
|
||||
# Only used when an Interface is loaded from a config
|
||||
self.predict = None
|
||||
self.input_components = None
|
||||
self.output_components = None
|
||||
self.__name__ = None
|
||||
self.api_mode = None
|
||||
|
||||
if self.analytics_enabled:
|
||||
self.ip_address = utils.get_local_ip_address()
|
||||
data = {
|
||||
@ -575,7 +592,7 @@ class Blocks(BlockContext):
|
||||
"ip_address": self.ip_address,
|
||||
"custom_css": self.css is not None,
|
||||
"theme": self.theme,
|
||||
"version": pkgutil.get_data(__name__, "version.txt")
|
||||
"version": (pkgutil.get_data(__name__, "version.txt") or b"")
|
||||
.decode("ascii")
|
||||
.strip(),
|
||||
}
|
||||
@ -611,7 +628,7 @@ class Blocks(BlockContext):
|
||||
block_config["props"]["root_url"] = root_url + "/"
|
||||
# Any component has already processed its initial value, so we skip that step here
|
||||
block = cls(**block_config["props"], _skip_init_processing=True)
|
||||
if style:
|
||||
if style and isinstance(block, components.IOComponent):
|
||||
block.style(**style)
|
||||
return block
|
||||
|
||||
@ -619,10 +636,14 @@ class Blocks(BlockContext):
|
||||
for child_config in children_list:
|
||||
id = child_config["id"]
|
||||
block = get_block_instance(id)
|
||||
|
||||
original_mapping[id] = block
|
||||
|
||||
children = child_config.get("children")
|
||||
if children is not None:
|
||||
assert isinstance(
|
||||
block, BlockContext
|
||||
), f"Invalid config, Block with id {id} has children but is not a BlockContext."
|
||||
with block:
|
||||
iterate_over_children(children)
|
||||
|
||||
@ -665,7 +686,7 @@ class Blocks(BlockContext):
|
||||
first_dependency = dependency
|
||||
|
||||
# Allows some use of Interface-specific methods with loaded Spaces
|
||||
if first_dependency:
|
||||
if first_dependency and Context.root_block:
|
||||
blocks.predict = [fns[0]]
|
||||
blocks.input_components = [
|
||||
Context.root_block.blocks[i] for i in first_dependency["inputs"]
|
||||
@ -673,11 +694,8 @@ class Blocks(BlockContext):
|
||||
blocks.output_components = [
|
||||
Context.root_block.blocks[o] for o in first_dependency["outputs"]
|
||||
]
|
||||
|
||||
if config.get("mode", "blocks") == "interface":
|
||||
blocks.__name__ = "Interface"
|
||||
blocks.mode = "interface"
|
||||
blocks.api_mode = True
|
||||
blocks.__name__ = "Interface"
|
||||
blocks.api_mode = True
|
||||
|
||||
return blocks
|
||||
|
||||
@ -764,7 +782,7 @@ class Blocks(BlockContext):
|
||||
if inspect.isasyncgenfunction(block_fn.fn):
|
||||
return False
|
||||
if inspect.isgeneratorfunction(block_fn.fn):
|
||||
raise False
|
||||
return False
|
||||
for input_id in dependency["inputs"]:
|
||||
block = self.blocks[input_id]
|
||||
if getattr(block, "stateful", False):
|
||||
@ -776,7 +794,7 @@ class Blocks(BlockContext):
|
||||
|
||||
return True
|
||||
|
||||
def __call__(self, *inputs, fn_index: int = 0, api_name: str = None):
|
||||
def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
|
||||
"""
|
||||
Allows Blocks objects to be called as functions. Supply the parameters to the
|
||||
function as positional arguments. To choose which function to call, use the
|
||||
@ -788,7 +806,7 @@ class Blocks(BlockContext):
|
||||
api_name: The api_name of the dependency to call. Will take precedence over fn_index.
|
||||
"""
|
||||
if api_name is not None:
|
||||
fn_index = next(
|
||||
inferred_fn_index = next(
|
||||
(
|
||||
i
|
||||
for i, d in enumerate(self.dependencies)
|
||||
@ -796,8 +814,9 @@ class Blocks(BlockContext):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if fn_index is None:
|
||||
if inferred_fn_index is None:
|
||||
raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
|
||||
fn_index = inferred_fn_index
|
||||
if not (self.is_callable(fn_index)):
|
||||
raise ValueError(
|
||||
"This function is not callable because it is either stateful or is a generator. Please use the .launch() method instead to create an interactive user interface."
|
||||
@ -814,6 +833,7 @@ class Blocks(BlockContext):
|
||||
fn_index=fn_index,
|
||||
inputs=processed_inputs,
|
||||
request=None,
|
||||
state={},
|
||||
)
|
||||
outputs = outputs["data"]
|
||||
|
||||
@ -834,6 +854,7 @@ class Blocks(BlockContext):
|
||||
):
|
||||
"""Calls and times function with given index and preprocessed input."""
|
||||
block_fn = self.fns[fn_index]
|
||||
assert block_fn.fn, f"function with index {fn_index} not defined."
|
||||
is_generating = False
|
||||
|
||||
if block_fn.inputs_as_dict:
|
||||
@ -856,6 +877,8 @@ class Blocks(BlockContext):
|
||||
prediction = await anyio.to_thread.run_sync(
|
||||
block_fn.fn, *processed_input, limiter=self.limiter
|
||||
)
|
||||
else:
|
||||
prediction = None
|
||||
|
||||
if inspect.isasyncgenfunction(block_fn.fn):
|
||||
raise ValueError("Gradio does not support async generators.")
|
||||
@ -892,7 +915,10 @@ class Blocks(BlockContext):
|
||||
processed_input = []
|
||||
|
||||
for i, input_id in enumerate(dependency["inputs"]):
|
||||
block: IOComponent = self.blocks[input_id]
|
||||
block = self.blocks[input_id]
|
||||
assert isinstance(
|
||||
block, components.IOComponent
|
||||
), f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
serialized_input = block.serialize(inputs[i])
|
||||
processed_input.append(serialized_input)
|
||||
|
||||
@ -903,7 +929,10 @@ class Blocks(BlockContext):
|
||||
predictions = []
|
||||
|
||||
for o, output_id in enumerate(dependency["outputs"]):
|
||||
block: IOComponent = self.blocks[output_id]
|
||||
block = self.blocks[output_id]
|
||||
assert isinstance(
|
||||
block, components.IOComponent
|
||||
), f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
deserialized = block.deserialize(outputs[o])
|
||||
predictions.append(deserialized)
|
||||
|
||||
@ -916,7 +945,10 @@ class Blocks(BlockContext):
|
||||
if block_fn.preprocess:
|
||||
processed_input = []
|
||||
for i, input_id in enumerate(dependency["inputs"]):
|
||||
block: IOComponent = self.blocks[input_id]
|
||||
block = self.blocks[input_id]
|
||||
assert isinstance(
|
||||
block, components.Component
|
||||
), f"{block.__class__} Component with id {input_id} not a valid input component."
|
||||
if getattr(block, "stateful", False):
|
||||
processed_input.append(state.get(input_id))
|
||||
else:
|
||||
@ -926,7 +958,7 @@ class Blocks(BlockContext):
|
||||
return processed_input
|
||||
|
||||
def postprocess_data(
|
||||
self, fn_index: int, predictions: List[Any], state: Dict[int, Any]
|
||||
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
|
||||
):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
@ -938,7 +970,9 @@ class Blocks(BlockContext):
|
||||
)
|
||||
|
||||
if len(dependency["outputs"]) == 1 and not (batch):
|
||||
predictions = (predictions,)
|
||||
predictions = [
|
||||
predictions,
|
||||
]
|
||||
|
||||
output = []
|
||||
for i, output_id in enumerate(dependency["outputs"]):
|
||||
@ -953,12 +987,16 @@ class Blocks(BlockContext):
|
||||
else:
|
||||
prediction_value = predictions[i]
|
||||
if utils.is_update(prediction_value):
|
||||
assert isinstance(prediction_value, dict)
|
||||
prediction_value = postprocess_update_dict(
|
||||
block=block,
|
||||
update_dict=prediction_value,
|
||||
postprocess=block_fn.postprocess,
|
||||
)
|
||||
elif block_fn.postprocess:
|
||||
assert isinstance(
|
||||
block, components.Component
|
||||
), f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
prediction_value = block.postprocess(prediction_value)
|
||||
output.append(prediction_value)
|
||||
return output
|
||||
@ -967,9 +1005,8 @@ class Blocks(BlockContext):
|
||||
self,
|
||||
fn_index: int,
|
||||
inputs: List[Any],
|
||||
state: Dict[int, Any],
|
||||
request: routes.Request | List[routes.Request] | None = None,
|
||||
username: str = None,
|
||||
state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
|
||||
iterators: Dict[int, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -1003,10 +1040,16 @@ class Blocks(BlockContext):
|
||||
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
|
||||
)
|
||||
|
||||
inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
|
||||
result = await self.call_function(fn_index, zip(*inputs), None, request)
|
||||
inputs = [
|
||||
self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
|
||||
]
|
||||
result = await self.call_function(
|
||||
fn_index, list(zip(*inputs)), None, request
|
||||
)
|
||||
preds = result["prediction"]
|
||||
data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
|
||||
data = [
|
||||
self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
|
||||
]
|
||||
data = list(zip(*data))
|
||||
is_generating, iterator = None, None
|
||||
else:
|
||||
@ -1098,10 +1141,10 @@ class Blocks(BlockContext):
|
||||
@class_or_instancemethod
|
||||
def load(
|
||||
self_or_cls,
|
||||
fn: Optional[Callable] = None,
|
||||
inputs: Optional[List[Component]] = None,
|
||||
outputs: Optional[List[Component]] = None,
|
||||
api_name: AnyStr = None,
|
||||
fn: Callable | None = None,
|
||||
inputs: List[Component] | None = None,
|
||||
outputs: List[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
scroll_to_output: bool = False,
|
||||
show_progress: bool = True,
|
||||
queue=None,
|
||||
@ -1110,12 +1153,12 @@ class Blocks(BlockContext):
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
every: float | None = None,
|
||||
_js: Optional[str] = None,
|
||||
_js: str | None = None,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
src: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
alias: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
src: str | None = None,
|
||||
api_key: str | None = None,
|
||||
alias: str | None = None,
|
||||
**kwargs,
|
||||
) -> Blocks | Dict[str, Any] | None:
|
||||
"""
|
||||
@ -1192,7 +1235,7 @@ class Blocks(BlockContext):
|
||||
def queue(
|
||||
self,
|
||||
concurrency_count: int = 1,
|
||||
status_update_rate: float | str = "auto",
|
||||
status_update_rate: float | Literal["auto"] = "auto",
|
||||
client_position_to_load_data: int | None = None,
|
||||
default_enabled: bool | None = None,
|
||||
api_open: bool = True,
|
||||
@ -1221,7 +1264,7 @@ class Blocks(BlockContext):
|
||||
self.api_open = api_open
|
||||
if client_position_to_load_data is not None:
|
||||
warnings.warn("The client_position_to_load_data parameter is deprecated.")
|
||||
self._queue = queue.Queue(
|
||||
self._queue = queueing.Queue(
|
||||
live_updates=status_update_rate == "auto",
|
||||
concurrency_count=concurrency_count,
|
||||
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
|
||||
@ -1233,26 +1276,26 @@ class Blocks(BlockContext):
|
||||
|
||||
def launch(
|
||||
self,
|
||||
inline: bool = None,
|
||||
inline: bool | None = None,
|
||||
inbrowser: bool = False,
|
||||
share: Optional[bool] = None,
|
||||
share: bool | None = None,
|
||||
debug: bool = False,
|
||||
enable_queue: bool = None,
|
||||
enable_queue: bool | None = None,
|
||||
max_threads: int = 40,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
|
||||
auth_message: str | None = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = False,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
server_name: str | None = None,
|
||||
server_port: int | None = None,
|
||||
show_tips: bool = False,
|
||||
height: int = 500,
|
||||
width: int | str = "100%",
|
||||
encrypt: bool = False,
|
||||
favicon_path: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile_password: Optional[str] = None,
|
||||
favicon_path: str | None = None,
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
quiet: bool = False,
|
||||
show_api: bool = True,
|
||||
_frontend: bool = True,
|
||||
@ -1302,8 +1345,9 @@ class Blocks(BlockContext):
|
||||
and not isinstance(auth[0], tuple)
|
||||
and not isinstance(auth[0], list)
|
||||
):
|
||||
auth = [auth]
|
||||
self.auth = auth
|
||||
self.auth = [auth]
|
||||
else:
|
||||
self.auth = auth
|
||||
self.auth_message = auth_message
|
||||
self.show_tips = show_tips
|
||||
self.show_error = show_error
|
||||
@ -1352,7 +1396,9 @@ class Blocks(BlockContext):
|
||||
)
|
||||
|
||||
if self.is_running:
|
||||
self.server_app.launchable = self
|
||||
assert isinstance(
|
||||
self.local_url, str
|
||||
), f"Invalid local_url: {self.local_url}"
|
||||
if not (quiet):
|
||||
print(
|
||||
"Rerunning server... use `close()` to stop if you need to change `launch()` parameters.\n----"
|
||||
@ -1448,14 +1494,14 @@ class Blocks(BlockContext):
|
||||
self.share_url = None
|
||||
|
||||
if inbrowser:
|
||||
link = self.share_url if self.share else self.local_url
|
||||
link = self.share_url if self.share and self.share_url else self.local_url
|
||||
webbrowser.open(link)
|
||||
|
||||
# Check if running in a Python notebook in which case, display inline
|
||||
if inline is None:
|
||||
inline = utils.ipython_check() and (auth is None)
|
||||
inline = utils.ipython_check() and (self.auth is None)
|
||||
if inline:
|
||||
if auth is not None:
|
||||
if self.auth is not None:
|
||||
print(
|
||||
"Warning: authentication is not supported inline. Please"
|
||||
"click the link to access the interface in a new tab."
|
||||
@ -1463,7 +1509,7 @@ class Blocks(BlockContext):
|
||||
try:
|
||||
from IPython.display import HTML, Javascript, display # type: ignore
|
||||
|
||||
if self.share:
|
||||
if self.share and self.share_url:
|
||||
while not networking.url_ok(self.share_url):
|
||||
time.sleep(0.25)
|
||||
display(
|
||||
@ -1546,9 +1592,9 @@ class Blocks(BlockContext):
|
||||
|
||||
def integrate(
|
||||
self,
|
||||
comet_ml: comet_ml.Experiment = None,
|
||||
wandb: ModuleType("wandb") = None,
|
||||
mlflow: ModuleType("mlflow") = None,
|
||||
comet_ml: comet_ml.Experiment | None = None,
|
||||
wandb: ModuleType | None = None,
|
||||
mlflow: ModuleType | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
A catch-all method for integrating with other libraries. This method should be run after launch()
|
||||
@ -1564,9 +1610,11 @@ class Blocks(BlockContext):
|
||||
if self.share_url is not None:
|
||||
comet_ml.log_text("gradio: " + self.share_url)
|
||||
comet_ml.end()
|
||||
else:
|
||||
elif self.local_url:
|
||||
comet_ml.log_text("gradio: " + self.local_url)
|
||||
comet_ml.end()
|
||||
else:
|
||||
raise ValueError("Please run `launch()` first.")
|
||||
if wandb is not None:
|
||||
analytics_integration = "WandB"
|
||||
if self.share_url is not None:
|
||||
@ -1627,23 +1675,23 @@ class Blocks(BlockContext):
|
||||
|
||||
def attach_load_events(self):
|
||||
"""Add a load event for every component whose initial value should be randomized."""
|
||||
|
||||
for component in Context.root_block.blocks.values():
|
||||
if (
|
||||
isinstance(component, components.IOComponent)
|
||||
and component.load_event_to_attach
|
||||
):
|
||||
load_fn, every = component.load_event_to_attach
|
||||
# Use set_event_trigger to avoid ambiguity between load class/instance method
|
||||
self.set_event_trigger(
|
||||
"load",
|
||||
load_fn,
|
||||
None,
|
||||
component,
|
||||
no_target=True,
|
||||
queue=False,
|
||||
every=every,
|
||||
)
|
||||
if Context.root_block:
|
||||
for component in Context.root_block.blocks.values():
|
||||
if (
|
||||
isinstance(component, components.IOComponent)
|
||||
and component.load_event_to_attach
|
||||
):
|
||||
load_fn, every = component.load_event_to_attach
|
||||
# Use set_event_trigger to avoid ambiguity between load class/instance method
|
||||
self.set_event_trigger(
|
||||
"load",
|
||||
load_fn,
|
||||
None,
|
||||
component,
|
||||
no_target=True,
|
||||
queue=False,
|
||||
every=every,
|
||||
)
|
||||
|
||||
def startup_events(self):
|
||||
"""Events that should be run when the app containing this block starts up."""
|
||||
|
@ -18,7 +18,7 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import altair as alt
|
||||
import matplotlib.figure
|
||||
@ -28,7 +28,7 @@ import PIL
|
||||
import PIL.ImageOps
|
||||
from ffmpy import FFmpeg
|
||||
from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.dollarmath import dollarmath_plugin
|
||||
from mdit_py_plugins.dollarmath.index import dollarmath_plugin
|
||||
from pandas.api.types import is_numeric_dtype
|
||||
from PIL import Image as _Image # using _ to minimize namespace pollution
|
||||
|
||||
@ -94,6 +94,18 @@ class Component(Block):
|
||||
**super().get_config(),
|
||||
}
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
"""
|
||||
Any preprocessing needed to be performed on function input.
|
||||
"""
|
||||
return x
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
"""
|
||||
return y
|
||||
|
||||
|
||||
class IOComponent(Component, Serializable):
|
||||
"""
|
||||
@ -140,12 +152,6 @@ class IOComponent(Component, Serializable):
|
||||
**super().get_config(),
|
||||
}
|
||||
|
||||
def preprocess(self, x: Any) -> Any:
|
||||
"""
|
||||
Any preprocessing needed to be performed on function input.
|
||||
"""
|
||||
return x
|
||||
|
||||
def set_interpret_parameters(self):
|
||||
"""
|
||||
Set any parameters for interpretation.
|
||||
@ -184,12 +190,6 @@ class IOComponent(Component, Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
"""
|
||||
return y
|
||||
|
||||
def style(
|
||||
self,
|
||||
*,
|
||||
@ -273,12 +273,13 @@ class IOComponent(Component, Serializable):
|
||||
|
||||
|
||||
class FormComponent:
|
||||
expected_parent = Form
|
||||
def get_expected_parent(self) -> Type[Form]:
|
||||
return Form
|
||||
|
||||
|
||||
@document("change", "submit", "blur", "style")
|
||||
class Textbox(
|
||||
Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable, FormComponent
|
||||
FormComponent, Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable
|
||||
):
|
||||
"""
|
||||
Creates a textarea for user to enter string input or display string output.
|
||||
@ -459,7 +460,7 @@ class Textbox(
|
||||
|
||||
@document("change", "submit", "style")
|
||||
class Number(
|
||||
Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable, FormComponent
|
||||
FormComponent, Changeable, Submittable, Blurrable, IOComponent, SimpleSerializable
|
||||
):
|
||||
"""
|
||||
Creates a numeric field for user to enter numbers as input or display numeric output.
|
||||
@ -630,7 +631,7 @@ class Number(
|
||||
|
||||
|
||||
@document("change", "style")
|
||||
class Slider(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
class Slider(FormComponent, Changeable, IOComponent, SimpleSerializable):
|
||||
"""
|
||||
Creates a slider that ranges from `minimum` to `maximum` with a step size of `step`.
|
||||
Preprocessing: passes slider value as a {float} into the function.
|
||||
@ -792,7 +793,7 @@ class Slider(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
|
||||
|
||||
@document("change", "style")
|
||||
class Checkbox(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
class Checkbox(FormComponent, Changeable, IOComponent, SimpleSerializable):
|
||||
"""
|
||||
Creates a checkbox that can be set to `True` or `False`.
|
||||
|
||||
@ -886,7 +887,7 @@ class Checkbox(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
|
||||
|
||||
@document("change", "style")
|
||||
class CheckboxGroup(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
class CheckboxGroup(FormComponent, Changeable, IOComponent, SimpleSerializable):
|
||||
"""
|
||||
Creates a set of checkboxes of which a subset can be checked.
|
||||
Preprocessing: passes the list of checked checkboxes as a {List[str]} or their indices as a {List[int]} into the function, depending on `type`.
|
||||
@ -1056,7 +1057,7 @@ class CheckboxGroup(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
|
||||
|
||||
@document("change", "style")
|
||||
class Radio(Changeable, IOComponent, SimpleSerializable, FormComponent):
|
||||
class Radio(FormComponent, Changeable, IOComponent, SimpleSerializable):
|
||||
"""
|
||||
Creates a set of radio buttons of which only one can be selected.
|
||||
Preprocessing: passes the value of the selected radio button as a {str} or its index as an {int} into the function, depending on `type`.
|
||||
@ -1615,7 +1616,7 @@ class Image(
|
||||
)
|
||||
|
||||
def as_example(self, input_data: str | None) -> str:
|
||||
return os.path.abspath(input_data)
|
||||
return str(Path(input_data).resolve())
|
||||
|
||||
|
||||
@document("change", "clear", "play", "pause", "stop", "style")
|
||||
@ -1758,7 +1759,7 @@ class Video(
|
||||
output_file_name = str(
|
||||
file_name.with_name(f"{file_name.stem}{flip_suffix}{format}")
|
||||
)
|
||||
if os.path.exists(output_file_name):
|
||||
if Path(output_file_name).exists():
|
||||
return output_file_name
|
||||
ff = FFmpeg(
|
||||
inputs={str(file_name): None},
|
||||
@ -2025,7 +2026,7 @@ class Audio(
|
||||
out_data = processing_utils.encode_file_to_base64(file.name)
|
||||
leave_one_out_sets.append(out_data)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
Path(file.name).unlink()
|
||||
|
||||
# Handle the tokens
|
||||
token = np.copy(data)
|
||||
@ -2035,7 +2036,7 @@ class Audio(
|
||||
processing_utils.audio_to_file(sample_rate, token, file.name)
|
||||
token_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
Path(file.name).unlink()
|
||||
|
||||
tokens.append(token_data)
|
||||
tokens = [{"name": "token.wav", "data": token} for token in tokens]
|
||||
@ -2066,7 +2067,7 @@ class Audio(
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
||||
masked_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
Path(file.name).unlink()
|
||||
masked_inputs.append(masked_data)
|
||||
return masked_inputs
|
||||
|
||||
@ -3656,11 +3657,11 @@ class Gallery(IOComponent, TempFileManager):
|
||||
img, caption = img
|
||||
if isinstance(img, np.ndarray):
|
||||
file = processing_utils.save_array_to_file(img)
|
||||
file_path = os.path.abspath(file.name)
|
||||
file_path = str(Path(file.name).resolve())
|
||||
self.temp_files.add(file_path)
|
||||
elif isinstance(img, PIL.Image.Image):
|
||||
file = processing_utils.save_pil_to_file(img)
|
||||
file_path = os.path.abspath(file.name)
|
||||
file_path = str(Path(file.name).resolve())
|
||||
self.temp_files.add(file_path)
|
||||
elif isinstance(img, str):
|
||||
if utils.validate_url(img):
|
||||
@ -3706,8 +3707,8 @@ class Gallery(IOComponent, TempFileManager):
|
||||
) -> None | str:
|
||||
if x is None:
|
||||
return None
|
||||
gallery_path = os.path.join(save_dir, str(uuid.uuid4()))
|
||||
os.makedirs(gallery_path)
|
||||
gallery_path = Path(save_dir) / str(uuid.uuid4())
|
||||
gallery_path.mkdir(exist_ok=True, parents=True)
|
||||
captions = {}
|
||||
for img_data in x:
|
||||
if isinstance(img_data, list) or isinstance(img_data, tuple):
|
||||
@ -3716,15 +3717,15 @@ class Gallery(IOComponent, TempFileManager):
|
||||
caption = None
|
||||
name = FileSerializable.deserialize(self, img_data, gallery_path)
|
||||
captions[name] = caption
|
||||
captions_file = os.path.join(gallery_path, "captions.json")
|
||||
with open(captions_file, "w") as captions_json:
|
||||
captions_file = gallery_path / "captions.json"
|
||||
with captions_file.open("w") as captions_json:
|
||||
json.dump(captions, captions_json)
|
||||
return os.path.abspath(gallery_path)
|
||||
return str(gallery_path.resolve())
|
||||
|
||||
def serialize(self, x: Any, load_dir: str = "", called_directly: bool = False):
|
||||
files = []
|
||||
captions_file = os.path.join(x, "captions.json")
|
||||
with open(captions_file) as captions_json:
|
||||
captions_file = Path(x) / "captions.json"
|
||||
with captions_file.open("r") as captions_json:
|
||||
captions = json.load(captions_json)
|
||||
for file_name, caption in captions.items():
|
||||
img = FileSerializable.serialize(self, file_name)
|
||||
@ -4978,9 +4979,6 @@ class Interpretation(Component):
|
||||
def style(self):
|
||||
return self
|
||||
|
||||
def postprocess(self, y: Any) -> Any:
|
||||
return y
|
||||
|
||||
|
||||
class StatusTracker(Component):
|
||||
"""
|
||||
|
@ -9,8 +9,6 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
|
||||
|
||||
class Context:
|
||||
root_block: Blocks = None # The current root block that holds all blocks.
|
||||
block: BlockContext = (
|
||||
None # The current block that all children should be added to.
|
||||
)
|
||||
id = 0 # Running id to uniquely refer to any block that gets defined
|
||||
root_block: Blocks | None = None # The current root block that holds all blocks.
|
||||
block: BlockContext | None = None # The current block that children are added to.
|
||||
id: int = 0 # Running id to uniquely refer to any block that gets defined
|
||||
|
@ -291,7 +291,7 @@ class Examples:
|
||||
if self.batch:
|
||||
processed_input = [[value] for value in processed_input]
|
||||
prediction = await Context.root_block.process_api(
|
||||
fn_index=fn_index, inputs=processed_input, request=None
|
||||
fn_index=fn_index, inputs=processed_input, request=None, state={}
|
||||
)
|
||||
output = prediction["data"]
|
||||
if self.batch:
|
||||
|
@ -33,7 +33,11 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def load_blocks_from_repo(
|
||||
name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs
|
||||
name: str,
|
||||
src: str | None = None,
|
||||
api_key: str | None = None,
|
||||
alias: str | None = None,
|
||||
**kwargs,
|
||||
) -> Blocks:
|
||||
"""Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
|
||||
if src is None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Type
|
||||
|
||||
from gradio.blocks import BlockContext
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
@ -184,8 +184,6 @@ class Tabs(BlockContext):
|
||||
|
||||
|
||||
class TabItem(BlockContext):
|
||||
expected_parent = Tabs
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str,
|
||||
@ -221,6 +219,9 @@ class TabItem(BlockContext):
|
||||
"""
|
||||
self.set_event_trigger("select", fn, inputs, outputs)
|
||||
|
||||
def get_expected_parent(self) -> Type[Tabs]:
|
||||
return Tabs
|
||||
|
||||
|
||||
@document()
|
||||
class Tab(TabItem):
|
||||
|
@ -9,7 +9,7 @@ import socket
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import fastapi
|
||||
import requests
|
||||
@ -87,12 +87,12 @@ def configure_app(app: fastapi.FastAPI, blocks: Blocks) -> fastapi.FastAPI:
|
||||
|
||||
def start_server(
|
||||
blocks: Blocks,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile_password: Optional[str] = None,
|
||||
) -> Tuple[int, str, App, Server]:
|
||||
server_name: str | None = None,
|
||||
server_port: int | None = None,
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
) -> Tuple[str, int, str, App, Server]:
|
||||
"""Launches a local server running the provided Interface
|
||||
Parameters:
|
||||
blocks: The Blocks object to run on the server
|
||||
|
@ -10,7 +10,7 @@ from typing import Any, Deque, Dict, List, Optional, Tuple
|
||||
import fastapi
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gradio.dataclasses import PredictBody
|
||||
from gradio.data_classes import PredictBody
|
||||
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ class Queue:
|
||||
self,
|
||||
live_updates: bool,
|
||||
concurrency_count: int,
|
||||
update_intervals: int,
|
||||
update_intervals: float,
|
||||
max_size: Optional[int],
|
||||
blocks_dependencies: List,
|
||||
):
|
@ -37,10 +37,10 @@ from starlette.websockets import WebSocketState
|
||||
|
||||
import gradio
|
||||
from gradio import encryptor, utils
|
||||
from gradio.dataclasses import PredictBody, ResetBody
|
||||
from gradio.data_classes import PredictBody, ResetBody
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.exceptions import Error
|
||||
from gradio.queue import Estimation, Event
|
||||
from gradio.queueing import Estimation, Event
|
||||
from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
|
||||
|
||||
mimetypes.init()
|
||||
@ -125,7 +125,7 @@ class App(FastAPI):
|
||||
self.tokens = {}
|
||||
|
||||
@staticmethod
|
||||
def create_app(blocks: gradio.Blocks) -> FastAPI:
|
||||
def create_app(blocks: gradio.Blocks) -> App:
|
||||
app = App(default_response_class=ORJSONResponse)
|
||||
app.configure_app(blocks)
|
||||
|
||||
@ -322,7 +322,6 @@ class App(FastAPI):
|
||||
fn_index=fn_index,
|
||||
inputs=raw_input,
|
||||
request=request,
|
||||
username=username,
|
||||
state=session_state,
|
||||
iterators=iterators,
|
||||
)
|
||||
|
@ -118,7 +118,7 @@ class FileSerializable(Serializable):
|
||||
def deserialize(
|
||||
self,
|
||||
x: str | Dict | None,
|
||||
save_dir: str | None = None,
|
||||
save_dir: Path | str | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
):
|
||||
"""
|
||||
@ -131,6 +131,8 @@ class FileSerializable(Serializable):
|
||||
"""
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(save_dir, Path):
|
||||
save_dir = str(save_dir)
|
||||
if isinstance(x, str):
|
||||
file_name = processing_utils.decode_base64_to_file(
|
||||
x, dir=save_dir, encryption_key=encryption_key
|
||||
|
@ -33,6 +33,7 @@ from typing import (
|
||||
NewType,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
@ -58,6 +59,8 @@ analytics_url = "https://api.gradio.app/"
|
||||
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
|
||||
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def version_check():
|
||||
try:
|
||||
@ -292,7 +295,7 @@ def format_ner_list(input_string: str, ner_groups: Dict[str : str | int]):
|
||||
return output
|
||||
|
||||
|
||||
def delete_none(_dict, skip_value=False):
|
||||
def delete_none(_dict: T, skip_value: bool = False) -> T:
|
||||
"""
|
||||
Delete None values recursively from all of the dictionaries, tuples, lists, sets.
|
||||
Credit: https://stackoverflow.com/a/66127889/5209347
|
||||
@ -319,7 +322,7 @@ def resolve_singleton(_list: List[Any] | Any) -> Any:
|
||||
return _list
|
||||
|
||||
|
||||
def component_or_layout_class(cls_name: str) -> Component | BlockContext:
|
||||
def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockContext]:
|
||||
"""
|
||||
Returns the component, template, or layout class with the given class name, or
|
||||
raises a ValueError if not found.
|
||||
@ -357,7 +360,7 @@ def component_or_layout_class(cls_name: str) -> Component | BlockContext:
|
||||
raise ValueError(f"No such component or layout: {cls_name}")
|
||||
|
||||
|
||||
def synchronize_async(func: Callable, *args, **kwargs):
|
||||
def synchronize_async(func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Runs async functions in sync scopes.
|
||||
|
||||
@ -711,10 +714,10 @@ def validate_url(possible_url: str) -> bool:
|
||||
|
||||
|
||||
def is_update(val):
|
||||
return type(val) is dict and "update" in val.get("__type__", "")
|
||||
return isinstance(val, dict) and "update" in val.get("__type__", "")
|
||||
|
||||
|
||||
def get_continuous_fn(fn, every):
|
||||
def get_continuous_fn(fn: Callable, every: float) -> Callable:
|
||||
def continuous_fn(*args):
|
||||
while True:
|
||||
output = fn(*args)
|
||||
|
@ -20,3 +20,4 @@ fsspec
|
||||
httpx
|
||||
pydantic
|
||||
websockets>=10.0
|
||||
typing_extensions
|
8
scripts/type_check_backend.sh
Normal file
8
scripts/type_check_backend.sh
Normal file
@ -0,0 +1,8 @@
|
||||
cd "$(dirname ${0})/.."
|
||||
source scripts/helpers.sh
|
||||
|
||||
pip_required
|
||||
|
||||
pip install --upgrade pip
|
||||
pip install pyright
|
||||
pyright gradio/context.py gradio/blocks.py
|
@ -179,7 +179,7 @@ class TestBlocksMethods:
|
||||
|
||||
btn.click(greet, {first, last}, greeting)
|
||||
|
||||
result = await demo.process_api(inputs=["huggy", "face"], fn_index=0)
|
||||
result = await demo.process_api(inputs=["huggy", "face"], fn_index=0, state={})
|
||||
assert result["data"] == ["Hello huggy face"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -194,7 +194,7 @@ class TestBlocksMethods:
|
||||
button.click(wait, [text], [text])
|
||||
|
||||
start = time.time()
|
||||
result = await demo.process_api(inputs=[1], fn_index=0)
|
||||
result = await demo.process_api(inputs=[1], fn_index=0, state={})
|
||||
end = time.time()
|
||||
difference = end - start
|
||||
assert difference >= 0.01
|
||||
@ -395,7 +395,7 @@ class TestComponentsInBlocks:
|
||||
share_button = gr.Button("share", visible=False)
|
||||
run_button.click(infer, prompt, [image, share_button], postprocess=False)
|
||||
|
||||
output = await demo.process_api(0, ["test"])
|
||||
output = await demo.process_api(0, ["test"], state={})
|
||||
assert output["data"][0] == gr.media_data.BASE64_IMAGE
|
||||
assert output["data"][1] == {"__type__": "update", "visible": True}
|
||||
|
||||
@ -412,7 +412,7 @@ class TestComponentsInBlocks:
|
||||
run_button = gr.Button()
|
||||
run_button.click(infer, [prompt], [image], postprocess=False)
|
||||
|
||||
output = await demo.process_api(0, ["test"])
|
||||
output = await demo.process_api(0, ["test"], state={})
|
||||
assert output["data"][0] == {
|
||||
"__type__": "update",
|
||||
"value": gr.media_data.BASE64_IMAGE,
|
||||
@ -439,7 +439,7 @@ class TestComponentsInBlocks:
|
||||
run.click(generic_update, None, [image, textbox])
|
||||
|
||||
for fn_index in range(2):
|
||||
output = await demo.process_api(fn_index, [])
|
||||
output = await demo.process_api(fn_index, [], state={})
|
||||
assert output["data"][0] == {
|
||||
"interactive": True,
|
||||
"__type__": "update",
|
||||
@ -686,7 +686,7 @@ class TestBatchProcessing:
|
||||
btn = gr.Button()
|
||||
btn.click(batch_fn, inputs=text, outputs=text, batch=True)
|
||||
|
||||
await demo.process_api(0, [["Adam", "Yahya"]])
|
||||
await demo.process_api(0, [["Adam", "Yahya"]], state={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exceeds_max_batch_size(self):
|
||||
@ -705,7 +705,7 @@ class TestBatchProcessing:
|
||||
batch_fn, inputs=text, outputs=text, batch=True, max_batch_size=2
|
||||
)
|
||||
|
||||
await demo.process_api(0, [["A", "B", "C"]])
|
||||
await demo.process_api(0, [["A", "B", "C"]], state={})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unequal_batch_sizes(self):
|
||||
@ -723,7 +723,7 @@ class TestBatchProcessing:
|
||||
btn = gr.Button()
|
||||
btn.click(batch_fn, inputs=[t1, t2], outputs=t1, batch=True)
|
||||
|
||||
await demo.process_api(0, [["A", "B", "C"], ["D", "E"]])
|
||||
await demo.process_api(0, [["A", "B", "C"], ["D", "E"]], state={})
|
||||
|
||||
|
||||
class TestSpecificUpdate:
|
||||
@ -802,13 +802,17 @@ class TestSpecificUpdate:
|
||||
inputs=None,
|
||||
outputs=[accordion],
|
||||
)
|
||||
result = await demo.process_api(fn_index=0, inputs=[None], request=None)
|
||||
result = await demo.process_api(
|
||||
fn_index=0, inputs=[None], request=None, state={}
|
||||
)
|
||||
assert result["data"][0] == {
|
||||
"open": True,
|
||||
"label": "Open Accordion",
|
||||
"__type__": "update",
|
||||
}
|
||||
result = await demo.process_api(fn_index=1, inputs=[None], request=None)
|
||||
result = await demo.process_api(
|
||||
fn_index=1, inputs=[None], request=None, state={}
|
||||
)
|
||||
assert result["data"][0] == {
|
||||
"open": False,
|
||||
"label": "Closed Accordion",
|
||||
|
@ -671,7 +671,7 @@ class TestPlot:
|
||||
return fig
|
||||
|
||||
iface = gr.Interface(plot, "slider", "plot")
|
||||
output = await iface.process_api(fn_index=0, inputs=[10])
|
||||
output = await iface.process_api(fn_index=0, inputs=[10], state={})
|
||||
assert output["data"][0]["type"] == "matplotlib"
|
||||
assert output["data"][0]["plot"].startswith("data:image/png;base64")
|
||||
|
||||
@ -1244,7 +1244,7 @@ class TestVideo:
|
||||
)
|
||||
assert processing_utils.video_is_playable(str(full_path_to_output))
|
||||
|
||||
@patch("os.path.exists", MagicMock(return_value=False))
|
||||
@patch("pathlib.Path.exists", MagicMock(return_value=False))
|
||||
@patch("gradio.components.FFmpeg")
|
||||
def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
|
||||
# Ensures that the cached temp video file is not used so that ffmpeg is called for each test
|
||||
@ -1658,7 +1658,9 @@ class TestJSON:
|
||||
["F", 30],
|
||||
]
|
||||
assert (
|
||||
await iface.process_api(0, [{"data": y_data, "headers": ["gender", "age"]}])
|
||||
await iface.process_api(
|
||||
0, [{"data": y_data, "headers": ["gender", "age"]}], state={}
|
||||
)
|
||||
)["data"][0] == {
|
||||
"M": 35,
|
||||
"F": 25,
|
||||
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gradio.queue import Event, Queue
|
||||
from gradio.queueing import Event, Queue
|
||||
from gradio.utils import AsyncRequest
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -142,7 +142,7 @@ class TestQueueProcessEvents:
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
|
||||
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
|
||||
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
|
||||
queue.gather_event_data = AsyncMock()
|
||||
queue.gather_event_data.return_value = True
|
||||
@ -284,7 +284,7 @@ class TestQueueProcessEvents:
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
|
||||
@patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
|
||||
async def test_process_event_handles_exception_during_disconnect(
|
||||
self, mock_request, queue: Queue, mock_event: Event
|
||||
):
|
Loading…
Reference in New Issue
Block a user