Mirror of https://github.com/gradio-app/gradio.git, synced 2025-04-12 12:40:29 +08:00
More Ruff rules (#4038)
* Bump ruff to 0.0.264
* Enable Ruff Naming rules and fix most errors
* Move `clean_html` to utils (to fix an N lint error)
* Changelog
* Clean up possibly leaking file handles
* Enable and autofix Ruff SIM
* Fix remaining Ruff SIMs
* Enable and autofix Ruff UP issues
* Fix misordered import from #4048
* Fix bare except from #4048

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
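Most of the churn in the diff below is mechanical: the Ruff UP (pyupgrade) rules enabled per the bullets above rewrite `typing.Dict`/`List`/`Tuple` annotations as builtin generics, which the `from __future__ import annotations` imports already present in these modules keep compatible with older interpreters. A minimal before/after sketch of that pattern (illustrative only, not copied from the Gradio sources):

```python
from __future__ import annotations  # postponed evaluation keeps builtin generics usable in annotations on Python < 3.9

# Before (flagged by pyupgrade-style rules such as UP006):
#   from typing import Dict, List, Tuple
#   def render(info: Dict[str, List[Dict[str, str]]]) -> Tuple[str, ...]: ...

# After: the same signature written with builtin generics.
def render(info: dict[str, list[dict[str, str]]]) -> tuple[str, ...]:
    """Toy stand-in for the client's endpoint helpers: one summary line per endpoint."""
    return tuple(f"{name}: {len(params)} parameter(s)" for name, params in info.items())


print(render({"/predict": [{"label": "age", "type": "int"}]}))
# -> ('/predict: 1 parameter(s)',)
```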
This commit is contained in:
parent 71f1e654ab
commit d1853625fd
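The Ruff SIM (flake8-simplify) rules account for most of the remaining hand edits: nested `with` blocks are merged into one (SIM117), `not x == y` becomes `x != y` (SIM201), small `if`/`else` assignments collapse into ternaries (SIM108), and or-ed `isinstance` checks become a single call with a tuple (SIM101). A hedged sketch bundling those patterns into one hypothetical helper; the function name and behaviour are illustrative, not part of the Gradio API:

```python
from __future__ import annotations

import shutil
import tempfile

import requests


def download_to_tmp(url: str, label: str | int | None = None) -> str:
    """Download `url` to a temp file, combining the SIM-style rewrites shown in the diff."""
    file_obj = tempfile.NamedTemporaryFile(delete=False)  # caller removes the file when done
    # SIM117: a single `with` statement instead of two nested ones.
    with requests.get(url, stream=True) as r, open(file_obj.name, "wb") as f:
        # SIM201: `!=` rather than `not r.status_code == 200`.
        if r.status_code != 200:
            raise RuntimeError(f"download failed with status {r.status_code}")
        shutil.copyfileobj(r.raw, f)

    # SIM108: a ternary instead of a four-line if/else assignment;
    # SIM101: one `isinstance` call with a tuple instead of two or-ed calls.
    caption = str(label) if isinstance(label, (str, int)) else "untitled"
    return f"{caption}: {file_obj.name}"
```

Where a handle genuinely has to stay open, for example the upload helpers that hand open file objects straight to `requests.post`, the commit keeps the bare `open(...)` call and silences the check with `# noqa: SIM115` instead.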
@ -120,6 +120,7 @@ No changes to highlight.

- CI: Simplified Python CI workflow by [@akx](https://github.com/akx) in [PR 3982](https://github.com/gradio-app/gradio/pull/3982)
- Upgrade pyright to 1.1.305 by [@akx](https://github.com/akx) in [PR 4042](https://github.com/gradio-app/gradio/pull/4042)
- More Ruff rules are enabled and lint errors fixed by [@akx](https://github.com/akx) in [PR 4038](https://github.com/gradio-app/gradio/pull/4038)

## Breaking Changes:
@ -13,7 +13,7 @@ from concurrent.futures import Future, TimeoutError
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable

import huggingface_hub
import requests
@ -132,7 +132,7 @@ class Client:
|
||||
hf_token: str | None = None,
|
||||
private: bool = True,
|
||||
hardware: str | None = None,
|
||||
secrets: Dict[str, str] | None = None,
|
||||
secrets: dict[str, str] | None = None,
|
||||
sleep_timeout: int = 5,
|
||||
max_workers: int = 40,
|
||||
verbose: bool = True,
|
||||
@ -216,7 +216,7 @@ class Client:
|
||||
current_info.hardware or huggingface_hub.SpaceHardware.CPU_BASIC
|
||||
)
|
||||
hardware = hardware or original_info.hardware
|
||||
if not current_hardware == hardware:
|
||||
if current_hardware != hardware:
|
||||
huggingface_hub.request_space_hardware(space_id, hardware) # type: ignore
|
||||
print(
|
||||
f"-------\nNOTE: this Space uses upgraded hardware: {hardware}... see billing info at https://huggingface.co/settings/billing\n-------"
|
||||
@ -262,7 +262,7 @@ class Client:
|
||||
*args,
|
||||
api_name: str | None = None,
|
||||
fn_index: int | None = None,
|
||||
result_callbacks: Callable | List[Callable] | None = None,
|
||||
result_callbacks: Callable | list[Callable] | None = None,
|
||||
) -> Job:
|
||||
"""
|
||||
Creates and returns a Job object which calls the Gradio API in a background thread. The job can be used to retrieve the status and result of the remote API call.
|
||||
@ -323,7 +323,7 @@ class Client:
|
||||
all_endpoints: bool | None = None,
|
||||
print_info: bool = True,
|
||||
return_format: Literal["dict", "str"] | None = None,
|
||||
) -> Dict | str | None:
|
||||
) -> dict | str | None:
|
||||
"""
|
||||
Prints the usage info for the API. If the Gradio app has multiple API endpoints, the usage info for each endpoint will be printed separately. If return_format="dict" the info is returned in dictionary format, as shown in the example below.
|
||||
|
||||
@ -449,7 +449,7 @@ class Client:
|
||||
def _render_endpoints_info(
|
||||
self,
|
||||
name_or_index: str | int,
|
||||
endpoints_info: Dict[str, List[Dict[str, str]]],
|
||||
endpoints_info: dict[str, list[dict[str, str]]],
|
||||
) -> str:
|
||||
parameter_names = [p["label"] for p in endpoints_info["parameters"]]
|
||||
parameter_names = [utils.sanitize_parameter_names(p) for p in parameter_names]
|
||||
@ -542,7 +542,7 @@ class Client:
|
||||
def _space_name_to_src(self, space) -> str | None:
|
||||
return huggingface_hub.space_info(space, token=self.hf_token).host # type: ignore
|
||||
|
||||
def _get_config(self) -> Dict:
|
||||
def _get_config(self) -> dict:
|
||||
r = requests.get(
|
||||
urllib.parse.urljoin(self.src, utils.CONFIG_URL), headers=self.headers
|
||||
)
|
||||
@ -568,7 +568,7 @@ class Client:
|
||||
class Endpoint:
|
||||
"""Helper class for storing all the information about a single API endpoint."""
|
||||
|
||||
def __init__(self, client: Client, fn_index: int, dependency: Dict):
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict):
|
||||
self.client: Client = client
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
@ -615,7 +615,7 @@ class Endpoint:
|
||||
return _inner
|
||||
|
||||
def make_predict(self, helper: Communicator | None = None):
|
||||
def _predict(*data) -> Tuple:
|
||||
def _predict(*data) -> tuple:
|
||||
data = json.dumps(
|
||||
{
|
||||
"data": data,
|
||||
@ -669,8 +669,8 @@ class Endpoint:
|
||||
return outputs
|
||||
|
||||
def _upload(
|
||||
self, file_paths: List[str | List[str]]
|
||||
) -> List[str | List[str]] | List[Dict[str, Any] | List[Dict[str, Any]]]:
|
||||
self, file_paths: list[str | list[str]]
|
||||
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
|
||||
if not file_paths:
|
||||
return []
|
||||
# Put all the filepaths in one file
|
||||
@ -683,7 +683,7 @@ class Endpoint:
|
||||
if not isinstance(fs, list):
|
||||
fs = [fs]
|
||||
for f in fs:
|
||||
files.append(("files", (Path(f).name, open(f, "rb"))))
|
||||
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
|
||||
indices.append(i)
|
||||
r = requests.post(
|
||||
self.client.upload_url, headers=self.client.headers, files=files
|
||||
@ -718,8 +718,8 @@ class Endpoint:
|
||||
|
||||
def _add_uploaded_files_to_data(
|
||||
self,
|
||||
files: List[str | List[str]] | List[Dict[str, Any] | List[Dict[str, Any]]],
|
||||
data: List[Any],
|
||||
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
|
||||
data: list[Any],
|
||||
) -> None:
|
||||
"""Helper function to modify the input data with the uploaded files."""
|
||||
file_counter = 0
|
||||
@ -728,7 +728,7 @@ class Endpoint:
|
||||
data[i] = files[file_counter]
|
||||
file_counter += 1
|
||||
|
||||
def serialize(self, *data) -> Tuple:
|
||||
def serialize(self, *data) -> tuple:
|
||||
data = list(data)
|
||||
for i, input_component_type in enumerate(self.input_component_types):
|
||||
if input_component_type == utils.STATE_COMPONENT:
|
||||
@ -748,7 +748,7 @@ class Endpoint:
|
||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||
return o
|
||||
|
||||
def deserialize(self, *data) -> Tuple | Any:
|
||||
def deserialize(self, *data) -> tuple | Any:
|
||||
assert len(data) == len(
|
||||
self.deserializers
|
||||
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
@ -758,7 +758,7 @@ class Endpoint:
|
||||
for s, d, oct in zip(
|
||||
self.deserializers, data, self.output_component_types
|
||||
)
|
||||
if not oct == utils.STATE_COMPONENT
|
||||
if oct != utils.STATE_COMPONENT
|
||||
]
|
||||
)
|
||||
if (
|
||||
@ -766,7 +766,7 @@ class Endpoint:
|
||||
[
|
||||
oct
|
||||
for oct in self.output_component_types
|
||||
if not oct == utils.STATE_COMPONENT
|
||||
if oct != utils.STATE_COMPONENT
|
||||
]
|
||||
)
|
||||
== 1
|
||||
@ -776,7 +776,7 @@ class Endpoint:
|
||||
output = outputs
|
||||
return output
|
||||
|
||||
def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
|
||||
def _setup_serializers(self) -> tuple[list[Serializable], list[Serializable]]:
|
||||
inputs = self.dependency["inputs"]
|
||||
serializers = []
|
||||
|
||||
@ -820,7 +820,7 @@ class Endpoint:
|
||||
|
||||
return serializers, deserializers
|
||||
|
||||
def _use_websocket(self, dependency: Dict) -> bool:
|
||||
def _use_websocket(self, dependency: dict) -> bool:
|
||||
queue_enabled = self.client.config.get("enable_queue", False)
|
||||
queue_uses_websocket = version.parse(
|
||||
self.client.config.get("version", "2.0")
|
||||
@ -873,7 +873,7 @@ class Job(Future):
|
||||
def __iter__(self) -> Job:
|
||||
return self
|
||||
|
||||
def __next__(self) -> Tuple | Any:
|
||||
def __next__(self) -> tuple | Any:
|
||||
if not self.communicator:
|
||||
raise StopIteration()
|
||||
|
||||
@ -925,7 +925,7 @@ class Job(Future):
|
||||
else:
|
||||
return super().result(timeout=timeout)
|
||||
|
||||
def outputs(self) -> List[Tuple | Any]:
|
||||
def outputs(self) -> list[tuple | Any]:
|
||||
"""
|
||||
Returns a list containing the latest outputs from the Job.
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
from typing import Callable
|
||||
|
||||
classes_to_document = {}
|
||||
classes_inherit_documentation = {}
|
||||
@ -61,7 +61,7 @@ def document(*fns, inherit=False):
|
||||
return inner_doc
|
||||
|
||||
|
||||
def document_fn(fn: Callable, cls) -> Tuple[str, List[Dict], Dict, str | None]:
|
||||
def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
|
||||
"""
|
||||
Generates documentation for any function.
|
||||
Parameters:
|
||||
|
@ -4,21 +4,21 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
from gradio_client import media_data, utils
|
||||
from gradio_client.data_classes import FileData
|
||||
|
||||
|
||||
class Serializable:
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
The typing information for this component as a dictionary whose values are a list of 2 strings: [Python type, language-agnostic description].
|
||||
Keys of the dictionary are: raw_input, raw_output, serialized_input, serialized_output
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
"""
|
||||
The example inputs for this component as a dictionary whose values are example inputs compatible with this component.
|
||||
Keys of the dictionary are: raw, serialized
|
||||
@ -26,12 +26,12 @@ class Serializable:
|
||||
raise NotImplementedError()
|
||||
|
||||
# For backwards compatibility
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
def input_api_info(self) -> tuple[str, str]:
|
||||
api_info = self.api_info()
|
||||
return (api_info["serialized_input"][0], api_info["serialized_input"][1])
|
||||
|
||||
# For backwards compatibility
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
def output_api_info(self) -> tuple[str, str]:
|
||||
api_info = self.api_info()
|
||||
return (api_info["serialized_output"][0], api_info["serialized_output"][1])
|
||||
|
||||
@ -57,7 +57,7 @@ class Serializable:
|
||||
class SimpleSerializable(Serializable):
|
||||
"""General class that does not perform any serialization or deserialization."""
|
||||
|
||||
def api_info(self) -> Dict[str, str | List[str]]:
|
||||
def api_info(self) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
"raw_input": ["Any", ""],
|
||||
"raw_output": ["Any", ""],
|
||||
@ -65,7 +65,7 @@ class SimpleSerializable(Serializable):
|
||||
"serialized_output": ["Any", ""],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": None,
|
||||
"serialized": None,
|
||||
@ -75,7 +75,7 @@ class SimpleSerializable(Serializable):
|
||||
class StringSerializable(Serializable):
|
||||
"""Expects a string as input/output but performs no serialization."""
|
||||
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": ["str", "string value"],
|
||||
"raw_output": ["str", "string value"],
|
||||
@ -83,7 +83,7 @@ class StringSerializable(Serializable):
|
||||
"serialized_output": ["str", "string value"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": "Howdy!",
|
||||
"serialized": "Howdy!",
|
||||
@ -93,7 +93,7 @@ class StringSerializable(Serializable):
|
||||
class ListStringSerializable(Serializable):
|
||||
"""Expects a list of strings as input/output but performs no serialization."""
|
||||
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": ["List[str]", "list of string values"],
|
||||
"raw_output": ["List[str]", "list of string values"],
|
||||
@ -101,7 +101,7 @@ class ListStringSerializable(Serializable):
|
||||
"serialized_output": ["List[str]", "list of string values"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": ["Howdy!", "Merhaba"],
|
||||
"serialized": ["Howdy!", "Merhaba"],
|
||||
@ -111,7 +111,7 @@ class ListStringSerializable(Serializable):
|
||||
class BooleanSerializable(Serializable):
|
||||
"""Expects a boolean as input/output but performs no serialization."""
|
||||
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": ["bool", "boolean value"],
|
||||
"raw_output": ["bool", "boolean value"],
|
||||
@ -119,7 +119,7 @@ class BooleanSerializable(Serializable):
|
||||
"serialized_output": ["bool", "boolean value"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": True,
|
||||
"serialized": True,
|
||||
@ -129,7 +129,7 @@ class BooleanSerializable(Serializable):
|
||||
class NumberSerializable(Serializable):
|
||||
"""Expects a number (int/float) as input/output but performs no serialization."""
|
||||
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": ["int | float", "numeric value"],
|
||||
"raw_output": ["int | float", "numeric value"],
|
||||
@ -137,7 +137,7 @@ class NumberSerializable(Serializable):
|
||||
"serialized_output": ["int | float", "numeric value"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": 5,
|
||||
"serialized": 5,
|
||||
@ -147,7 +147,7 @@ class NumberSerializable(Serializable):
|
||||
class ImgSerializable(Serializable):
|
||||
"""Expects a base64 string as input/output which is serialized to a filepath."""
|
||||
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": ["str", "base64 representation of image"],
|
||||
"raw_output": ["str", "base64 representation of image"],
|
||||
@ -155,7 +155,7 @@ class ImgSerializable(Serializable):
|
||||
"serialized_output": ["str", "filepath or URL to image"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": media_data.BASE64_IMAGE,
|
||||
"serialized": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
|
||||
@ -204,7 +204,7 @@ class ImgSerializable(Serializable):
|
||||
class FileSerializable(Serializable):
|
||||
"""Expects a dict with base64 representation of object as input/output which is serialized to a filepath."""
|
||||
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": [
|
||||
"str | Dict",
|
||||
@ -218,7 +218,7 @@ class FileSerializable(Serializable):
|
||||
"serialized_output": ["str", "filepath or URL to file"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": {"is_file": False, "data": media_data.BASE64_FILE},
|
||||
"serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf",
|
||||
@ -280,9 +280,9 @@ class FileSerializable(Serializable):
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
x: str | FileData | None | List[str | FileData | None],
|
||||
x: str | FileData | None | list[str | FileData | None],
|
||||
load_dir: str | Path = "",
|
||||
) -> FileData | None | List[FileData | None]:
|
||||
) -> FileData | None | list[FileData | None]:
|
||||
"""
|
||||
Convert from human-friendly version of a file (string filepath) to a
|
||||
seralized representation (base64)
|
||||
@ -299,11 +299,11 @@ class FileSerializable(Serializable):
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: str | FileData | None | List[str | FileData | None],
|
||||
x: str | FileData | None | list[str | FileData | None],
|
||||
save_dir: Path | str | None = None,
|
||||
root_url: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
) -> str | None | List[str | None]:
|
||||
) -> str | None | list[str | None]:
|
||||
"""
|
||||
Convert from serialized representation of a file (base64) to a human-friendly
|
||||
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
|
||||
@ -331,7 +331,7 @@ class FileSerializable(Serializable):
|
||||
|
||||
|
||||
class VideoSerializable(FileSerializable):
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": [
|
||||
"str | Dict",
|
||||
@ -345,7 +345,7 @@ class VideoSerializable(FileSerializable):
|
||||
"serialized_output": ["str", "filepath or URL to file"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": {"is_file": False, "data": media_data.BASE64_VIDEO},
|
||||
"serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/video_sample.mp4",
|
||||
@ -353,16 +353,16 @@ class VideoSerializable(FileSerializable):
|
||||
|
||||
def serialize(
|
||||
self, x: str | None, load_dir: str | Path = ""
|
||||
) -> Tuple[FileData | None, None]:
|
||||
) -> tuple[FileData | None, None]:
|
||||
return (super().serialize(x, load_dir), None) # type: ignore
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: Tuple[FileData | None, FileData | None] | None,
|
||||
x: tuple[FileData | None, FileData | None] | None,
|
||||
save_dir: Path | str | None = None,
|
||||
root_url: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
) -> str | Tuple[str | None, str | None] | None:
|
||||
) -> str | tuple[str | None, str | None] | None:
|
||||
"""
|
||||
Convert from serialized representation of a file (base64) to a human-friendly
|
||||
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
|
||||
@ -378,7 +378,7 @@ class VideoSerializable(FileSerializable):
|
||||
|
||||
|
||||
class JSONSerializable(Serializable):
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": ["str | Dict | List", "JSON-serializable object or a string"],
|
||||
"raw_output": ["Dict | List", "dictionary- or list-like object"],
|
||||
@ -386,7 +386,7 @@ class JSONSerializable(Serializable):
|
||||
"serialized_output": ["str", "filepath to JSON file"],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": {"a": 1, "b": 2},
|
||||
"serialized": None,
|
||||
@ -396,7 +396,7 @@ class JSONSerializable(Serializable):
|
||||
self,
|
||||
x: str | None,
|
||||
load_dir: str | Path = "",
|
||||
) -> Dict | List | None:
|
||||
) -> dict | list | None:
|
||||
"""
|
||||
Convert from a a human-friendly version (string path to json file) to a
|
||||
serialized representation (json string)
|
||||
@ -410,7 +410,7 @@ class JSONSerializable(Serializable):
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: str | Dict | List,
|
||||
x: str | dict | list,
|
||||
save_dir: str | Path | None = None,
|
||||
root_url: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
@ -430,7 +430,7 @@ class JSONSerializable(Serializable):
|
||||
|
||||
|
||||
class GallerySerializable(Serializable):
|
||||
def api_info(self) -> Dict[str, List[str]]:
|
||||
def api_info(self) -> dict[str, list[str]]:
|
||||
return {
|
||||
"raw_input": [
|
||||
"List[List[str | None]]",
|
||||
@ -450,7 +450,7 @@ class GallerySerializable(Serializable):
|
||||
],
|
||||
}
|
||||
|
||||
def example_inputs(self) -> Dict[str, Any]:
|
||||
def example_inputs(self) -> dict[str, Any]:
|
||||
return {
|
||||
"raw": [media_data.BASE64_IMAGE] * 2,
|
||||
"serialized": [
|
||||
@ -461,7 +461,7 @@ class GallerySerializable(Serializable):
|
||||
|
||||
def serialize(
|
||||
self, x: str | None, load_dir: str | Path = ""
|
||||
) -> List[List[str | None]] | None:
|
||||
) -> list[list[str | None]] | None:
|
||||
if x is None or x == "":
|
||||
return None
|
||||
files = []
|
||||
@ -475,7 +475,7 @@ class GallerySerializable(Serializable):
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: List[List[str | None]] | None,
|
||||
x: list[list[str | None]] | None,
|
||||
save_dir: str = "",
|
||||
root_url: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
@ -486,7 +486,7 @@ class GallerySerializable(Serializable):
|
||||
gallery_path.mkdir(exist_ok=True, parents=True)
|
||||
captions = {}
|
||||
for img_data in x:
|
||||
if isinstance(img_data, list) or isinstance(img_data, tuple):
|
||||
if isinstance(img_data, (list, tuple)):
|
||||
img_data, caption = img_data
|
||||
else:
|
||||
caption = None
|
||||
@ -510,7 +510,7 @@ SERIALIZER_MAPPING["Serializable"] = SimpleSerializable
|
||||
SERIALIZER_MAPPING["File"] = FileSerializable
|
||||
SERIALIZER_MAPPING["UploadButton"] = FileSerializable
|
||||
|
||||
COMPONENT_MAPPING: Dict[str, type] = {
|
||||
COMPONENT_MAPPING: dict[str, type] = {
|
||||
"textbox": StringSerializable,
|
||||
"number": NumberSerializable,
|
||||
"slider": NumberSerializable,
|
||||
|
@ -14,7 +14,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import fsspec.asyn
|
||||
import httpx
|
||||
@ -84,7 +84,7 @@ class Status(Enum):
|
||||
CANCELLED = "CANCELLED"
|
||||
|
||||
@staticmethod
|
||||
def ordering(status: "Status") -> int:
|
||||
def ordering(status: Status) -> int:
|
||||
"""Order of messages. Helpful for testing."""
|
||||
order = [
|
||||
Status.STARTING,
|
||||
@ -100,11 +100,11 @@ class Status(Enum):
|
||||
]
|
||||
return order.index(status)
|
||||
|
||||
def __lt__(self, other: "Status"):
|
||||
def __lt__(self, other: Status):
|
||||
return self.ordering(self) < self.ordering(other)
|
||||
|
||||
@staticmethod
|
||||
def msg_to_status(msg: str) -> "Status":
|
||||
def msg_to_status(msg: str) -> Status:
|
||||
"""Map the raw message from the backend to the status code presented to users."""
|
||||
return {
|
||||
"send_hash": Status.JOINING_QUEUE,
|
||||
@ -127,7 +127,7 @@ class ProgressUnit:
|
||||
desc: Optional[str]
|
||||
|
||||
@classmethod
|
||||
def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]:
|
||||
def from_ws_msg(cls, data: list[dict]) -> list[ProgressUnit]:
|
||||
return [
|
||||
cls(
|
||||
index=d.get("index"),
|
||||
@ -150,7 +150,7 @@ class StatusUpdate:
|
||||
eta: float | None
|
||||
success: bool | None
|
||||
time: datetime | None
|
||||
progress_data: List[ProgressUnit] | None
|
||||
progress_data: list[ProgressUnit] | None
|
||||
|
||||
|
||||
def create_initial_status_update():
|
||||
@ -173,7 +173,7 @@ class JobStatus:
|
||||
"""
|
||||
|
||||
latest_status: StatusUpdate = field(default_factory=create_initial_status_update)
|
||||
outputs: List[Any] = field(default_factory=list)
|
||||
outputs: list[Any] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -182,7 +182,7 @@ class Communicator:
|
||||
|
||||
lock: Lock
|
||||
job: JobStatus
|
||||
deserialize: Callable[..., Tuple]
|
||||
deserialize: Callable[..., tuple]
|
||||
reset_url: str
|
||||
should_cancel: bool = False
|
||||
|
||||
@ -208,7 +208,7 @@ async def get_pred_from_ws(
|
||||
data: str,
|
||||
hash_data: str,
|
||||
helper: Communicator | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
completed = False
|
||||
resp = {}
|
||||
while not completed:
|
||||
@ -285,9 +285,10 @@ def download_tmp_copy_of_file(
|
||||
suffix=suffix,
|
||||
dir=dir,
|
||||
)
|
||||
with requests.get(url_path, headers=headers, stream=True) as r:
|
||||
with open(file_obj.name, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
with requests.get(url_path, headers=headers, stream=True) as r, open(
|
||||
file_obj.name, "wb"
|
||||
) as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
return file_obj
|
||||
|
||||
|
||||
@ -360,7 +361,7 @@ def encode_url_or_file_to_base64(path: str | Path):
|
||||
return encode_file_to_base64(path)
|
||||
|
||||
|
||||
def decode_base64_to_binary(encoding: str) -> Tuple[bytes, str | None]:
|
||||
def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]:
|
||||
extension = get_extension(encoding)
|
||||
data = encoding.rsplit(",", 1)[-1]
|
||||
return base64.b64decode(data), extension
|
||||
@ -421,7 +422,7 @@ def decode_base64_to_file(
|
||||
return file_obj
|
||||
|
||||
|
||||
def dict_or_str_to_json_file(jsn: str | Dict | List, dir: str | Path | None = None):
|
||||
def dict_or_str_to_json_file(jsn: str | dict | list, dir: str | Path | None = None):
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
@ -435,7 +436,7 @@ def dict_or_str_to_json_file(jsn: str | Dict | List, dir: str | Path | None = No
|
||||
return file_obj
|
||||
|
||||
|
||||
def file_to_json(file_path: str | Path) -> Dict | List:
|
||||
def file_to_json(file_path: str | Path) -> dict | list:
|
||||
with open(file_path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
@ -1,6 +1,6 @@
black==22.6.0
pytest-asyncio
pytest==7.1.2
ruff==0.0.263
ruff==0.0.264
pyright==1.1.305
gradio
@ -48,8 +48,8 @@ class TestPredictionsFromSpaces:
|
||||
@pytest.mark.flaky
|
||||
def test_numerical_to_label_space(self):
|
||||
client = Client("gradio-tests/titanic-survival")
|
||||
output = client.predict("male", 77, 10, api_name="/predict")
|
||||
assert json.load(open(output))["label"] == "Perishes"
|
||||
with open(client.predict("male", 77, 10, api_name="/predict")) as f:
|
||||
assert json.load(f)["label"] == "Perishes"
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`",
|
||||
@ -258,7 +258,8 @@ class TestPredictionsFromSpaces:
|
||||
f.write("Hello from private space!")
|
||||
|
||||
output = client.submit(1, "foo", f.name, api_name="/file_upload").result()
|
||||
assert open(output).read() == "Hello from private space!"
|
||||
with open(output) as f:
|
||||
assert f.read() == "Hello from private space!"
|
||||
upload.assert_called_once()
|
||||
|
||||
with patch.object(
|
||||
@ -267,24 +268,28 @@ class TestPredictionsFromSpaces:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
|
||||
f.write("Hello from private space!")
|
||||
|
||||
output = client.submit(f.name, api_name="/upload_btn").result()
|
||||
assert open(output).read() == "Hello from private space!"
|
||||
with open(client.submit(f.name, api_name="/upload_btn").result()) as f:
|
||||
assert f.read() == "Hello from private space!"
|
||||
upload.assert_called_once()
|
||||
|
||||
with patch.object(
|
||||
client.endpoints[2], "_upload", wraps=client.endpoints[0]._upload
|
||||
) as upload:
|
||||
# `delete=False` is required for Windows compat
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2:
|
||||
|
||||
f1.write("File1")
|
||||
f2.write("File2")
|
||||
|
||||
output = client.submit(
|
||||
3, [f1.name, f2.name], "hello", api_name="/upload_multiple"
|
||||
r1, r2 = client.submit(
|
||||
3,
|
||||
[f1.name, f2.name],
|
||||
"hello",
|
||||
api_name="/upload_multiple",
|
||||
).result()
|
||||
assert open(output[0]).read() == "File1"
|
||||
assert open(output[1]).read() == "File2"
|
||||
with open(r1) as f:
|
||||
assert f.read() == "File1"
|
||||
with open(r2) as f:
|
||||
assert f.read() == "File2"
|
||||
upload.assert_called_once()
|
||||
|
||||
@pytest.mark.flaky
|
||||
@ -588,10 +593,8 @@ class TestAPIInfo:
|
||||
def test_serializable_in_mapping(self, calculator_demo):
|
||||
with connect(calculator_demo) as client:
|
||||
assert all(
|
||||
[
|
||||
isinstance(c, SimpleSerializable)
|
||||
for c in client.endpoints[0].serializers
|
||||
]
|
||||
isinstance(c, SimpleSerializable)
|
||||
for c in client.endpoints[0].serializers
|
||||
)
|
||||
|
||||
@pytest.mark.flaky
|
||||
|
@ -34,8 +34,10 @@ def test_file_serializing():
|
||||
assert serializing.serialize(output) == output
|
||||
|
||||
files = serializing.deserialize(output)
|
||||
assert open(files[0]).read() == "Hello World!"
|
||||
assert open(files[1]).read() == "Greetings!"
|
||||
with open(files[0]) as f:
|
||||
assert f.read() == "Hello World!"
|
||||
with open(files[1]) as f:
|
||||
assert f.read() == "Greetings!"
|
||||
finally:
|
||||
os.remove(f1.name)
|
||||
os.remove(f2.name)
|
||||
|
@ -65,7 +65,7 @@ from gradio.flagging import (
|
||||
SimpleCSVLogger,
|
||||
)
|
||||
from gradio.helpers import EventData, Progress, make_waveform, skip, update
|
||||
from gradio.helpers import create_examples as Examples
|
||||
from gradio.helpers import create_examples as Examples # noqa: N812
|
||||
from gradio.interface import Interface, TabbedInterface, close_all
|
||||
from gradio.ipython_ext import load_ipython_extension
|
||||
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
|
||||
|
gradio/blocks.py
@ -12,7 +12,7 @@ import warnings
|
||||
import webbrowser
|
||||
from abc import abstractmethod
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterator
|
||||
|
||||
import anyio
|
||||
import requests
|
||||
@ -34,7 +34,7 @@ from gradio import (
|
||||
)
|
||||
from gradio.context import Context
|
||||
from gradio.deprecation import check_deprecated_parameters
|
||||
from gradio.exceptions import DuplicateBlockError, InvalidApiName
|
||||
from gradio.exceptions import DuplicateBlockError, InvalidApiNameError
|
||||
from gradio.helpers import EventData, create_tracker, skip, special_args
|
||||
from gradio.themes import Default as DefaultTheme
|
||||
from gradio.themes import ThemeClass as Theme
|
||||
@ -56,7 +56,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
|
||||
from gradio.components import Component
|
||||
|
||||
BUILT_IN_THEMES: Dict[str, Theme] = {
|
||||
BUILT_IN_THEMES: dict[str, Theme] = {
|
||||
t.name: t
|
||||
for t in [
|
||||
themes.Base(),
|
||||
@ -74,7 +74,7 @@ class Block:
|
||||
*,
|
||||
render: bool = True,
|
||||
elem_id: str | None = None,
|
||||
elem_classes: List[str] | str | None = None,
|
||||
elem_classes: list[str] | str | None = None,
|
||||
visible: bool = True,
|
||||
root_url: str | None = None, # URL that is prepended to all file paths
|
||||
_skip_init_processing: bool = False, # Used for loading from Spaces
|
||||
@ -145,15 +145,15 @@ class Block:
|
||||
else self.__class__.__name__.lower()
|
||||
)
|
||||
|
||||
def get_expected_parent(self) -> Type[BlockContext] | None:
|
||||
def get_expected_parent(self) -> type[BlockContext] | None:
|
||||
return None
|
||||
|
||||
def set_event_trigger(
|
||||
self,
|
||||
event_name: str,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None,
|
||||
outputs: Component | List[Component] | None,
|
||||
inputs: Component | list[Component] | set[Component] | None,
|
||||
outputs: Component | list[Component] | None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
scroll_to_output: bool = False,
|
||||
@ -164,12 +164,12 @@ class Block:
|
||||
queue: bool | None = None,
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
cancels: List[int] | None = None,
|
||||
cancels: list[int] | None = None,
|
||||
every: float | None = None,
|
||||
collects_event_data: bool | None = None,
|
||||
trigger_after: int | None = None,
|
||||
trigger_only_on_success: bool = False,
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
"""
|
||||
Adds an event to the component's dependencies.
|
||||
Parameters:
|
||||
@ -251,7 +251,7 @@ class Block:
|
||||
api_name_ = utils.append_unique_suffix(
|
||||
api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
|
||||
)
|
||||
if not (api_name == api_name_):
|
||||
if api_name != api_name_:
|
||||
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
|
||||
api_name = api_name_
|
||||
|
||||
@ -295,11 +295,11 @@ class Block:
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def update(**kwargs) -> Dict:
|
||||
def update(**kwargs) -> dict:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
|
||||
def get_specific_update(cls, generic_update: dict[str, Any]) -> dict:
|
||||
generic_update = generic_update.copy()
|
||||
del generic_update["__type__"]
|
||||
specific_update = cls.update(**generic_update)
|
||||
@ -318,7 +318,7 @@ class BlockContext(Block):
|
||||
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
|
||||
render: If False, this will not be included in the Blocks config file at all.
|
||||
"""
|
||||
self.children: List[Block] = []
|
||||
self.children: list[Block] = []
|
||||
Block.__init__(self, visible=visible, render=render, **kwargs)
|
||||
|
||||
def __enter__(self):
|
||||
@ -368,8 +368,8 @@ class BlockFunction:
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable | None,
|
||||
inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
inputs: list[Component],
|
||||
outputs: list[Component],
|
||||
preprocess: bool,
|
||||
postprocess: bool,
|
||||
inputs_as_dict: bool,
|
||||
@ -399,13 +399,13 @@ class BlockFunction:
|
||||
return str(self)
|
||||
|
||||
|
||||
class class_or_instancemethod(classmethod):
|
||||
class class_or_instancemethod(classmethod): # noqa: N801
|
||||
def __get__(self, instance, type_):
|
||||
descr_get = super().__get__ if instance is None else self.__func__.__get__
|
||||
return descr_get(instance, type_)
|
||||
|
||||
|
||||
def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
|
||||
def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True):
|
||||
"""
|
||||
Converts a dictionary of updates into a format that can be sent to the frontend.
|
||||
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
|
||||
@ -433,15 +433,15 @@ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool =
|
||||
|
||||
|
||||
def convert_component_dict_to_list(
|
||||
outputs_ids: List[int], predictions: Dict
|
||||
) -> List | Dict:
|
||||
outputs_ids: list[int], predictions: dict
|
||||
) -> list | dict:
|
||||
"""
|
||||
Converts a dictionary of component updates into a list of updates in the order of
|
||||
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
|
||||
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
|
||||
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
|
||||
"""
|
||||
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
|
||||
keys_are_blocks = [isinstance(key, Block) for key in predictions]
|
||||
if all(keys_are_blocks):
|
||||
reordered_predictions = [skip() for _ in outputs_ids]
|
||||
for component, value in predictions.items():
|
||||
@ -459,7 +459,7 @@ def convert_component_dict_to_list(
|
||||
return predictions
|
||||
|
||||
|
||||
def get_api_info(config: Dict, serialize: bool = True):
|
||||
def get_api_info(config: dict, serialize: bool = True):
|
||||
"""
|
||||
Gets the information needed to generate the API docs from a Blocks config.
|
||||
Parameters:
|
||||
@ -662,8 +662,8 @@ class Blocks(BlockContext):
|
||||
if not self.analytics_enabled:
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
|
||||
super().__init__(render=False, **kwargs)
|
||||
self.blocks: Dict[int, Block] = {}
|
||||
self.fns: List[BlockFunction] = []
|
||||
self.blocks: dict[int, Block] = {}
|
||||
self.fns: list[BlockFunction] = []
|
||||
self.dependencies = []
|
||||
self.mode = mode
|
||||
|
||||
@ -674,7 +674,7 @@ class Blocks(BlockContext):
|
||||
self.height = None
|
||||
self.api_open = True
|
||||
|
||||
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
|
||||
self.is_space = os.getenv("SYSTEM") == "spaces"
|
||||
self.favicon_path = None
|
||||
self.auth = None
|
||||
self.dev_mode = True
|
||||
@ -713,7 +713,7 @@ class Blocks(BlockContext):
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict,
|
||||
fns: List[Callable],
|
||||
fns: list[Callable],
|
||||
root_url: str | None = None,
|
||||
) -> Blocks:
|
||||
"""
|
||||
@ -727,7 +727,7 @@ class Blocks(BlockContext):
|
||||
config = copy.deepcopy(config)
|
||||
components_config = config["components"]
|
||||
theme = config.get("theme", "default")
|
||||
original_mapping: Dict[int, Block] = {}
|
||||
original_mapping: dict[int, Block] = {}
|
||||
|
||||
def get_block_instance(id: int) -> Block:
|
||||
for block_config in components_config:
|
||||
@ -862,7 +862,7 @@ class Blocks(BlockContext):
|
||||
api_name,
|
||||
[dep["api_name"] for dep in Context.root_block.dependencies],
|
||||
)
|
||||
if not (api_name == api_name_):
|
||||
if api_name != api_name_:
|
||||
warnings.warn(
|
||||
f"api_name {api_name} already exists, using {api_name_}"
|
||||
)
|
||||
@ -937,7 +937,9 @@ class Blocks(BlockContext):
|
||||
None,
|
||||
)
|
||||
if inferred_fn_index is None:
|
||||
raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
|
||||
raise InvalidApiNameError(
|
||||
f"Cannot find a function with api_name {api_name}"
|
||||
)
|
||||
fn_index = inferred_fn_index
|
||||
if not (self.is_callable(fn_index)):
|
||||
raise ValueError(
|
||||
@ -970,9 +972,9 @@ class Blocks(BlockContext):
|
||||
async def call_function(
|
||||
self,
|
||||
fn_index: int,
|
||||
processed_input: List[Any],
|
||||
processed_input: list[Any],
|
||||
iterator: Iterator[Any] | None = None,
|
||||
requests: routes.Request | List[routes.Request] | None = None,
|
||||
requests: routes.Request | list[routes.Request] | None = None,
|
||||
event_id: str | None = None,
|
||||
event_data: EventData | None = None,
|
||||
):
|
||||
@ -993,10 +995,7 @@ class Blocks(BlockContext):
|
||||
if block_fn.inputs_as_dict:
|
||||
processed_input = [dict(zip(block_fn.inputs, processed_input))]
|
||||
|
||||
if isinstance(requests, list):
|
||||
request = requests[0]
|
||||
else:
|
||||
request = requests
|
||||
request = requests[0] if isinstance(requests, list) else requests
|
||||
processed_input, progress_index, _ = special_args(
|
||||
block_fn.fn, processed_input, request, event_data
|
||||
)
|
||||
@ -1054,7 +1053,7 @@ class Blocks(BlockContext):
|
||||
"iterator": iterator,
|
||||
}
|
||||
|
||||
def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
|
||||
def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
|
||||
dependency = self.dependencies[fn_index]
|
||||
processed_input = []
|
||||
|
||||
@ -1068,7 +1067,7 @@ class Blocks(BlockContext):
|
||||
|
||||
return processed_input
|
||||
|
||||
def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
|
||||
def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
|
||||
dependency = self.dependencies[fn_index]
|
||||
predictions = []
|
||||
|
||||
@ -1084,7 +1083,7 @@ class Blocks(BlockContext):
|
||||
|
||||
return predictions
|
||||
|
||||
def validate_inputs(self, fn_index: int, inputs: List[Any]):
|
||||
def validate_inputs(self, fn_index: int, inputs: list[Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
@ -1106,10 +1105,7 @@ class Blocks(BlockContext):
|
||||
block = self.blocks[input_id]
|
||||
wanted_args.append(str(block))
|
||||
for inp in inputs:
|
||||
if isinstance(inp, str):
|
||||
v = f'"{inp}"'
|
||||
else:
|
||||
v = str(inp)
|
||||
v = f'"{inp}"' if isinstance(inp, str) else str(inp)
|
||||
received_args.append(v)
|
||||
|
||||
wanted = ", ".join(wanted_args)
|
||||
@ -1125,7 +1121,7 @@ Received inputs:
|
||||
[{received}]"""
|
||||
)
|
||||
|
||||
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
|
||||
def preprocess_data(self, fn_index: int, inputs: list[Any], state: dict[int, Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
@ -1146,7 +1142,7 @@ Received inputs:
|
||||
processed_input = inputs
|
||||
return processed_input
|
||||
|
||||
def validate_outputs(self, fn_index: int, predictions: Any | List[Any]):
|
||||
def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
@ -1168,10 +1164,7 @@ Received inputs:
|
||||
block = self.blocks[output_id]
|
||||
wanted_args.append(str(block))
|
||||
for pred in predictions:
|
||||
if isinstance(pred, str):
|
||||
v = f'"{pred}"'
|
||||
else:
|
||||
v = str(pred)
|
||||
v = f'"{pred}"' if isinstance(pred, str) else str(pred)
|
||||
received_args.append(v)
|
||||
|
||||
wanted = ", ".join(wanted_args)
|
||||
@ -1186,7 +1179,7 @@ Received outputs:
|
||||
)
|
||||
|
||||
def postprocess_data(
|
||||
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
|
||||
self, fn_index: int, predictions: list | dict, state: dict[int, Any]
|
||||
):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
@ -1241,13 +1234,13 @@ Received outputs:
|
||||
async def process_api(
|
||||
self,
|
||||
fn_index: int,
|
||||
inputs: List[Any],
|
||||
state: Dict[int, Any],
|
||||
request: routes.Request | List[routes.Request] | None = None,
|
||||
iterators: Dict[int, Any] | None = None,
|
||||
inputs: list[Any],
|
||||
state: dict[int, Any],
|
||||
request: routes.Request | list[routes.Request] | None = None,
|
||||
iterators: dict[int, Any] | None = None,
|
||||
event_id: str | None = None,
|
||||
event_data: EventData | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Processes API calls from the frontend. First preprocesses the data,
|
||||
then runs the relevant function, then postprocesses the output.
|
||||
@ -1342,15 +1335,15 @@ Received outputs:
|
||||
"theme": self.theme.name,
|
||||
}
|
||||
|
||||
def getLayout(block):
|
||||
def get_layout(block):
|
||||
if not isinstance(block, BlockContext):
|
||||
return {"id": block._id}
|
||||
children_layout = []
|
||||
for child in block.children:
|
||||
children_layout.append(getLayout(child))
|
||||
children_layout.append(get_layout(child))
|
||||
return {"id": block._id, "children": children_layout}
|
||||
|
||||
config["layout"] = getLayout(self)
|
||||
config["layout"] = get_layout(self)
|
||||
|
||||
for _id, block in self.blocks.items():
|
||||
props = block.get_config() if hasattr(block, "get_config") else {}
|
||||
@ -1393,10 +1386,10 @@ Received outputs:
|
||||
|
||||
@class_or_instancemethod
|
||||
def load(
|
||||
self_or_cls,
|
||||
self_or_cls, # noqa: N805
|
||||
fn: Callable | None = None,
|
||||
inputs: List[Component] | None = None,
|
||||
outputs: List[Component] | None = None,
|
||||
inputs: list[Component] | None = None,
|
||||
outputs: list[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
scroll_to_output: bool = False,
|
||||
show_progress: bool = True,
|
||||
@ -1413,7 +1406,7 @@ Received outputs:
|
||||
api_key: str | None = None,
|
||||
alias: str | None = None,
|
||||
**kwargs,
|
||||
) -> Blocks | Dict[str, Any] | None:
|
||||
) -> Blocks | dict[str, Any] | None:
|
||||
"""
|
||||
For reverse compatibility reasons, this is both a class method and an instance
|
||||
method, the two of which, confusingly, do two completely different things.
|
||||
@ -1571,7 +1564,7 @@ Received outputs:
|
||||
debug: bool = False,
|
||||
enable_queue: bool | None = None,
|
||||
max_threads: int = 40,
|
||||
auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
|
||||
auth: Callable | tuple[str, str] | list[tuple[str, str]] | None = None,
|
||||
auth_message: str | None = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = False,
|
||||
@ -1588,11 +1581,11 @@ Received outputs:
|
||||
ssl_verify: bool = True,
|
||||
quiet: bool = False,
|
||||
show_api: bool = True,
|
||||
file_directories: List[str] | None = None,
|
||||
allowed_paths: List[str] | None = None,
|
||||
blocked_paths: List[str] | None = None,
|
||||
file_directories: list[str] | None = None,
|
||||
allowed_paths: list[str] | None = None,
|
||||
blocked_paths: list[str] | None = None,
|
||||
_frontend: bool = True,
|
||||
) -> Tuple[FastAPI, str, str]:
|
||||
) -> tuple[FastAPI, str, str]:
|
||||
"""
|
||||
Launches a simple web server that serves the demo. Can also be used to create a
|
||||
public link used by anyone to access the demo from their browser by setting share=True.
|
||||
|
File diff suppressed because it is too large
@ -4,7 +4,7 @@ of the on-page-load event, which is defined in gr.Blocks().load()."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
@ -19,7 +19,7 @@ set_documentation_group("events")
|
||||
|
||||
|
||||
def set_cancel_events(
|
||||
block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
|
||||
block: Block, event_name: str, cancels: None | dict[str, Any] | list[dict[str, Any]]
|
||||
):
|
||||
if cancels:
|
||||
if not isinstance(cancels, list):
|
||||
@ -91,8 +91,8 @@ class EventListenerMethod:
|
||||
def __call__(
|
||||
self,
|
||||
fn: Callable | None,
|
||||
inputs: Component | List[Component] | Set[Component] | None = None,
|
||||
outputs: Component | List[Component] | None = None,
|
||||
inputs: Component | list[Component] | set[Component] | None = None,
|
||||
outputs: Component | list[Component] | None = None,
|
||||
api_name: str | None = None,
|
||||
status_tracker: StatusTracker | None = None,
|
||||
scroll_to_output: bool = False,
|
||||
@ -102,7 +102,7 @@ class EventListenerMethod:
|
||||
max_batch_size: int = 4,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
|
||||
cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
|
||||
every: float | None = None,
|
||||
_js: str | None = None,
|
||||
) -> Dependency:
|
||||
@ -290,7 +290,7 @@ class Selectable(EventListener):
|
||||
class SelectData(EventData):
|
||||
def __init__(self, target: Block | None, data: Any):
|
||||
super().__init__(target, data)
|
||||
self.index: int | Tuple[int, int] = data["index"]
|
||||
self.index: int | tuple[int, int] = data["index"]
|
||||
"""
|
||||
The index of the selected item. Is a tuple if the component is two dimensional or selection is a range.
|
||||
"""
|
||||
|
@ -15,10 +15,13 @@ class TooManyRequestsError(Exception):
pass


class InvalidApiName(ValueError):
class InvalidApiNameError(ValueError):
pass


InvalidApiName = InvalidApiNameError  # backwards compatibility


@document()
class Error(Exception):
"""
@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Callable, Dict
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import requests
|
||||
from gradio_client import Client
|
||||
@ -87,7 +87,7 @@ def load_blocks_from_repo(
|
||||
src = tokens[0]
|
||||
name = "/".join(tokens[1:])
|
||||
|
||||
factory_methods: Dict[str, Callable] = {
|
||||
factory_methods: dict[str, Callable] = {
|
||||
# for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
|
||||
"huggingface": from_model,
|
||||
"models": from_model,
|
||||
@ -393,7 +393,7 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
|
||||
data.update({"options": {"wait_for_model": True}})
|
||||
data = json.dumps(data)
|
||||
response = requests.request("POST", api_url, headers=headers, data=data)
|
||||
if not (response.status_code == 200):
|
||||
if response.status_code != 200:
|
||||
errors_json = response.json()
|
||||
errors, warns = "", ""
|
||||
if errors_json.get("error"):
|
||||
@ -494,7 +494,7 @@ def from_spaces_blocks(space: str, api_key: str | None) -> Blocks:
|
||||
|
||||
def from_spaces_interface(
|
||||
model_name: str,
|
||||
config: Dict,
|
||||
config: dict,
|
||||
alias: str | None,
|
||||
api_key: str | None,
|
||||
iframe_url: str,
|
||||
|
@ -10,7 +10,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from distutils.version import StrictVersion
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import filelock
|
||||
import huggingface_hub
|
||||
@ -33,7 +33,7 @@ class FlaggingCallback(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def setup(self, components: List[IOComponent], flagging_dir: str):
|
||||
def setup(self, components: list[IOComponent], flagging_dir: str):
|
||||
"""
|
||||
This method should be overridden and ensure that everything is set up correctly for flag().
|
||||
This method gets called once at the beginning of the Interface.launch() method.
|
||||
@ -46,7 +46,7 @@ class FlaggingCallback(ABC):
|
||||
@abstractmethod
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
@ -81,14 +81,14 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup(self, components: List[IOComponent], flagging_dir: str | Path):
|
||||
def setup(self, components: list[IOComponent], flagging_dir: str | Path):
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
@ -112,7 +112,7 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(utils.sanitize_list_for_csv(csv_data))
|
||||
|
||||
with open(log_filepath, "r") as csvfile:
|
||||
with open(log_filepath) as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
|
||||
@ -136,7 +136,7 @@ class CSVLogger(FlaggingCallback):
|
||||
|
||||
def setup(
|
||||
self,
|
||||
components: List[IOComponent],
|
||||
components: list[IOComponent],
|
||||
flagging_dir: str | Path,
|
||||
):
|
||||
self.components = components
|
||||
@ -145,7 +145,7 @@ class CSVLogger(FlaggingCallback):
|
||||
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback):
|
||||
writer.writerow(utils.sanitize_list_for_csv(headers))
|
||||
writer.writerow(utils.sanitize_list_for_csv(csv_data))
|
||||
|
||||
with open(log_filepath, "r", encoding="utf-8") as csvfile:
|
||||
with open(log_filepath, encoding="utf-8") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
|
||||
@ -235,7 +235,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
self.info_filename = info_filename
|
||||
self.separate_dirs = separate_dirs
|
||||
|
||||
def setup(self, components: List[IOComponent], flagging_dir: str):
|
||||
def setup(self, components: list[IOComponent], flagging_dir: str):
|
||||
"""
|
||||
Params:
|
||||
flagging_dir (str): local directory where the dataset is cloned,
|
||||
@ -286,7 +286,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
pass
|
||||
|
||||
def flag(self, flag_data: List[Any], flag_option: str = "") -> int:
|
||||
def flag(self, flag_data: list[Any], flag_option: str = "") -> int:
|
||||
if self.separate_dirs:
|
||||
# JSONL files to support dataset preview on the Hub
|
||||
unique_id = str(uuid.uuid4())
|
||||
@ -312,7 +312,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
data_file: Path,
|
||||
components_dir: Path,
|
||||
path_in_repo: str | None,
|
||||
flag_data: List[Any],
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
) -> int:
|
||||
# Deserialize components (write images/audio to files)
|
||||
@ -368,7 +368,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
return sample_nb
|
||||
|
||||
@staticmethod
|
||||
def _save_as_csv(data_file: Path, headers: List[str], row: List[Any]) -> int:
|
||||
def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
|
||||
"""Save data as CSV and return the sample name (row number)."""
|
||||
is_new = not data_file.exists()
|
||||
|
||||
@ -386,7 +386,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
return sum(1 for _ in csv.reader(csvfile)) - 1
|
||||
|
||||
@staticmethod
|
||||
def _save_as_jsonl(data_file: Path, headers: List[str], row: List[Any]) -> str:
|
||||
def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
|
||||
"""Save data as JSONL and return the sample name (uuid)."""
|
||||
Path.mkdir(data_file.parent, parents=True, exist_ok=True)
|
||||
with open(data_file, "w") as f:
|
||||
@ -394,15 +394,15 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
return data_file.parent.name
|
||||
|
||||
def _deserialize_components(
|
||||
self, data_dir: Path, flag_data: List[Any], flag_option: str = ""
|
||||
) -> Tuple[Dict[Any, Any], List[Any]]:
|
||||
self, data_dir: Path, flag_data: list[Any], flag_option: str = ""
|
||||
) -> tuple[dict[Any, Any], list[Any]]:
|
||||
"""Deserialize components and return the corresponding row for the flagged sample.
|
||||
|
||||
Images/audio are saved to disk as individual files.
|
||||
"""
|
||||
# Components that can have a preview on dataset repos
|
||||
# NOTE: not at root level to avoid circular imports
|
||||
FILE_PREVIEW_TYPES = {gr.Audio: "Audio", gr.Image: "Image"}
|
||||
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
|
||||
|
||||
# Generate the row corresponding to the flagged sample
|
||||
features = {}
|
||||
@ -418,8 +418,8 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
row.append(Path(deserialized).name)
|
||||
|
||||
# If component is eligible for a preview, add the URL of the file
|
||||
if isinstance(component, tuple(FILE_PREVIEW_TYPES)): # type: ignore
|
||||
for _component, _type in FILE_PREVIEW_TYPES.items():
|
||||
if isinstance(component, tuple(file_preview_types)): # type: ignore
|
||||
for _component, _type in file_preview_types.items():
|
||||
if isinstance(component, _component):
|
||||
features[label + " file"] = {"_type": _type}
|
||||
break
|
||||
|
@ -12,7 +12,7 @@ import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@ -36,9 +36,9 @@ set_documentation_group("helpers")
|
||||
|
||||
|
||||
def create_examples(
|
||||
examples: List[Any] | List[List[Any]] | str,
|
||||
inputs: IOComponent | List[IOComponent],
|
||||
outputs: IOComponent | List[IOComponent] | None = None,
|
||||
examples: list[Any] | list[list[Any]] | str,
|
||||
inputs: IOComponent | list[IOComponent],
|
||||
outputs: IOComponent | list[IOComponent] | None = None,
|
||||
fn: Callable | None = None,
|
||||
cache_examples: bool = False,
|
||||
examples_per_page: int = 10,
|
||||
@ -85,9 +85,9 @@ class Examples:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
examples: List[Any] | List[List[Any]] | str,
|
||||
inputs: IOComponent | List[IOComponent],
|
||||
outputs: IOComponent | List[IOComponent] | None = None,
|
||||
examples: list[Any] | list[list[Any]] | str,
|
||||
inputs: IOComponent | list[IOComponent],
|
||||
outputs: IOComponent | list[IOComponent] | None = None,
|
||||
fn: Callable | None = None,
|
||||
cache_examples: bool = False,
|
||||
examples_per_page: int = 10,
|
||||
@ -323,7 +323,7 @@ class Examples:
|
||||
Context.root_block.dependencies.remove(dependency)
|
||||
Context.root_block.fns.pop(fn_index)
|
||||
|
||||
async def load_from_cache(self, example_id: int) -> List[Any]:
|
||||
async def load_from_cache(self, example_id: int) -> list[Any]:
|
||||
"""Loads a particular cached example for the interface.
|
||||
Parameters:
|
||||
example_id: The id of the example to process (zero-indexed).
|
||||
@ -396,7 +396,7 @@ class Progress(Iterable):
|
||||
self.track_tqdm = track_tqdm
|
||||
self._callback = _callback
|
||||
self._event_id = _event_id
|
||||
self.iterables: List[TrackedIterable] = []
|
||||
self.iterables: list[TrackedIterable] = []
|
||||
|
||||
def __len__(self):
|
||||
return self.iterables[-1].length
|
||||
@ -431,7 +431,7 @@ class Progress(Iterable):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
progress: float | Tuple[int, int | None] | None,
|
||||
progress: float | tuple[int, int | None] | None,
|
||||
desc: str | None = None,
|
||||
total: int | None = None,
|
||||
unit: str = "steps",
|
||||
@ -541,7 +541,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm):
|
||||
if self._progress is not None:
|
||||
self._progress.event_id = event_id
|
||||
self._progress.tqdm(iterable, desc, _tqdm=self)
|
||||
kwargs["file"] = open(os.devnull, "w")
|
||||
kwargs["file"] = open(os.devnull, "w") # noqa: SIM115
|
||||
self.__init__orig__(iterable, desc, *args, **kwargs)
|
||||
|
||||
def iter_tqdm(self):
|
||||
@ -595,7 +595,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm):
|
||||
|
||||
def special_args(
|
||||
fn: Callable,
|
||||
inputs: List[Any] | None = None,
|
||||
inputs: list[Any] | None = None,
|
||||
request: routes.Request | None = None,
|
||||
event_data: EventData | None = None,
|
||||
):
|
||||
@ -632,9 +632,10 @@ def special_args(
|
||||
event_data_index = i
|
||||
if inputs is not None and event_data is not None:
|
||||
inputs.insert(i, param.annotation(event_data.target, event_data._data))
|
||||
elif param.default is not param.empty:
|
||||
if inputs is not None and len(inputs) <= i:
|
||||
inputs.insert(i, param.default)
|
||||
elif (
|
||||
param.default is not param.empty and inputs is not None and len(inputs) <= i
|
||||
):
|
||||
inputs.insert(i, param.default)
|
||||
if inputs is not None:
|
||||
while len(inputs) < len(positional_args):
|
||||
i = len(inputs)
|
||||
@ -696,12 +697,12 @@ def skip() -> dict:
|
||||
|
||||
@document()
|
||||
def make_waveform(
|
||||
audio: str | Tuple[int, np.ndarray],
|
||||
audio: str | tuple[int, np.ndarray],
|
||||
*,
|
||||
bg_color: str = "#f3f4f6",
|
||||
bg_image: str | None = None,
|
||||
fg_alpha: float = 0.75,
|
||||
bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
|
||||
bars_color: str | tuple[str, str] = ("#fbbf24", "#ea580c"),
|
||||
bar_count: int = 50,
|
||||
bar_width: float = 0.6,
|
||||
):
|
||||
@ -728,13 +729,13 @@ def make_waveform(
|
||||
duration = round(len(audio[1]) / audio[0], 4)
|
||||
|
||||
# Helper methods to create waveform
|
||||
def hex_to_RGB(hex_str):
|
||||
def hex_to_rgb(hex_str):
|
||||
return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
|
||||
|
||||
def get_color_gradient(c1, c2, n):
|
||||
assert n > 1
|
||||
c1_rgb = np.array(hex_to_RGB(c1)) / 255
|
||||
c2_rgb = np.array(hex_to_RGB(c2)) / 255
|
||||
c1_rgb = np.array(hex_to_rgb(c1)) / 255
|
||||
c2_rgb = np.array(hex_to_rgb(c2)) / 255
|
||||
mix_pcts = [x / (n - 1) for x in range(n)]
|
||||
rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
|
||||
return [
|
||||
@ -770,7 +771,7 @@ def make_waveform(
|
||||
plt.axis("off")
|
||||
plt.margins(x=0)
|
||||
tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
||||
savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"}
|
||||
savefig_kwargs: dict[str, Any] = {"bbox_inches": "tight"}
|
||||
if bg_image is not None:
|
||||
savefig_kwargs["transparent"] = True
|
||||
else:
|
||||
|
@ -8,7 +8,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
from gradio import components
|
||||
|
||||
@ -132,8 +132,8 @@ class CheckboxGroup(components.CheckboxGroup):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
choices: List[str],
|
||||
default: List[str] | None = None,
|
||||
choices: list[str],
|
||||
default: list[str] | None = None,
|
||||
type: str = "value",
|
||||
label: Optional[str] = None,
|
||||
optional: bool = False,
|
||||
@ -168,7 +168,7 @@ class Radio(components.Radio):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
choices: List[str],
|
||||
choices: list[str],
|
||||
type: str = "value",
|
||||
default: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
@ -202,7 +202,7 @@ class Dropdown(components.Dropdown):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
choices: List[str],
|
||||
choices: list[str],
|
||||
type: str = "value",
|
||||
default: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
@ -236,7 +236,7 @@ class Image(components.Image):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shape: Tuple[int, int] = None,
|
||||
shape: tuple[int, int] = None,
|
||||
image_mode: str = "RGB",
|
||||
invert_colors: bool = False,
|
||||
source: str = "upload",
|
||||
@ -366,12 +366,12 @@ class Dataframe(components.Dataframe):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: Optional[List[str]] = None,
|
||||
headers: Optional[list[str]] = None,
|
||||
row_count: int = 3,
|
||||
col_count: Optional[int] = 3,
|
||||
datatype: str | List[str] = "str",
|
||||
col_width: int | List[int] = None,
|
||||
default: Optional[List[List[Any]]] = None,
|
||||
datatype: str | list[str] = "str",
|
||||
col_width: int | list[int] = None,
|
||||
default: Optional[list[list[Any]]] = None,
|
||||
type: str = "pandas",
|
||||
label: Optional[str] = None,
|
||||
optional: bool = False,
|
||||
@ -413,7 +413,7 @@ class Timeseries(components.Timeseries):
|
||||
def __init__(
|
||||
self,
|
||||
x: Optional[str] = None,
|
||||
y: str | List[str] = None,
|
||||
y: str | list[str] = None,
|
||||
label: Optional[str] = None,
|
||||
optional: bool = False,
|
||||
):
|
||||
|
@ -8,10 +8,9 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
@ -63,7 +62,7 @@ class Interface(Blocks):
|
||||
instances: weakref.WeakSet = weakref.WeakSet()
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls) -> List[Interface]:
|
||||
def get_instances(cls) -> list[Interface]:
|
||||
"""
|
||||
:return: list of all current instances.
|
||||
"""
|
||||
@ -119,9 +118,9 @@ class Interface(Blocks):
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: str | IOComponent | List[str | IOComponent] | None,
|
||||
outputs: str | IOComponent | List[str | IOComponent] | None,
|
||||
examples: List[Any] | List[List[Any]] | str | None = None,
|
||||
inputs: str | IOComponent | list[str | IOComponent] | None,
|
||||
outputs: str | IOComponent | list[str | IOComponent] | None,
|
||||
examples: list[Any] | list[list[Any]] | str | None = None,
|
||||
cache_examples: bool | None = None,
|
||||
examples_per_page: int = 10,
|
||||
live: bool = False,
|
||||
@ -134,7 +133,7 @@ class Interface(Blocks):
|
||||
theme: Theme | str | None = None,
|
||||
css: str | None = None,
|
||||
allow_flagging: str | None = None,
|
||||
flagging_options: List[str] | List[Tuple[str, str]] | None = None,
|
||||
flagging_options: list[str] | list[tuple[str, str]] | None = None,
|
||||
flagging_dir: str = "flagged",
|
||||
flagging_callback: FlaggingCallback = CSVLogger(),
|
||||
analytics_enabled: bool | None = None,
|
||||
@ -287,17 +286,11 @@ class Interface(Blocks):
|
||||
self.live = live
|
||||
self.title = title
|
||||
|
||||
CLEANER = re.compile("<.*?>")
|
||||
|
||||
def clean_html(raw_html):
|
||||
cleantext = re.sub(CLEANER, "", raw_html)
|
||||
return cleantext
|
||||
|
||||
md = utils.get_markdown_parser()
|
||||
simple_description = None
|
||||
simple_description: str | None = None
|
||||
if description is not None:
|
||||
description = md.render(description)
|
||||
simple_description = clean_html(description)
|
||||
simple_description = utils.remove_html_tags(description)
|
||||
self.simple_description = simple_description
|
||||
self.description = description
|
||||
if article is not None:
|
||||
@ -466,19 +459,19 @@ class Interface(Blocks):
|
||||
if self.description:
|
||||
Markdown(self.description)
|
||||
|
||||
def render_flag_btns(self) -> List[Button]:
|
||||
def render_flag_btns(self) -> list[Button]:
|
||||
return [Button(label) for label, _ in self.flagging_options]
|
||||
|
||||
def render_input_column(
|
||||
self,
|
||||
) -> Tuple[
|
||||
) -> tuple[
|
||||
Button | None,
|
||||
Button | None,
|
||||
Button | None,
|
||||
List[Button] | None,
|
||||
list[Button] | None,
|
||||
Column,
|
||||
Column | None,
|
||||
List[Interpretation] | None,
|
||||
list[Interpretation] | None,
|
||||
]:
|
||||
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
|
||||
interpret_component_column, interpretation_set = None, None
|
||||
@ -531,7 +524,7 @@ class Interface(Blocks):
|
||||
def render_output_column(
|
||||
self,
|
||||
submit_btn_in: Button | None,
|
||||
) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
|
||||
) -> tuple[Button | None, Button | None, Button | None, list | None, Button | None]:
|
||||
submit_btn = submit_btn_in
|
||||
interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
|
||||
|
||||
@ -699,7 +692,7 @@ class Interface(Blocks):
|
||||
def attach_interpretation_events(
|
||||
self,
|
||||
interpretation_btn: Button | None,
|
||||
interpretation_set: List[Interpretation] | None,
|
||||
interpretation_set: list[Interpretation] | None,
|
||||
input_component_column: Column | None,
|
||||
interpret_component_column: Column | None,
|
||||
):
|
||||
@ -711,53 +704,58 @@ class Interface(Blocks):
|
||||
preprocess=False,
|
||||
)
|
||||
|
||||
def attach_flagging_events(self, flag_btns: List[Button] | None, clear_btn: Button):
|
||||
if flag_btns:
|
||||
if self.interface_type in [
|
||||
def attach_flagging_events(self, flag_btns: list[Button] | None, clear_btn: Button):
|
||||
if not (
|
||||
flag_btns
|
||||
and self.interface_type
|
||||
in (
|
||||
InterfaceTypes.STANDARD,
|
||||
InterfaceTypes.OUTPUT_ONLY,
|
||||
InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
if self.allow_flagging == "auto":
|
||||
flag_method = FlagMethod(
|
||||
self.flagging_callback, "", "", visual_feedback=False
|
||||
)
|
||||
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
|
||||
flag_method,
|
||||
inputs=self.input_components,
|
||||
outputs=None,
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
return
|
||||
)
|
||||
):
|
||||
return
|
||||
|
||||
if self.interface_type == InterfaceTypes.UNIFIED:
|
||||
flag_components = self.input_components
|
||||
else:
|
||||
flag_components = self.input_components + self.output_components
|
||||
if self.allow_flagging == "auto":
|
||||
flag_method = FlagMethod(
|
||||
self.flagging_callback, "", "", visual_feedback=False
|
||||
)
|
||||
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
|
||||
flag_method,
|
||||
inputs=self.input_components,
|
||||
outputs=None,
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
return
|
||||
|
||||
for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
|
||||
assert isinstance(value, str)
|
||||
flag_method = FlagMethod(self.flagging_callback, label, value)
|
||||
flag_btn.click(
|
||||
lambda: Button.update(value="Saving...", interactive=False),
|
||||
None,
|
||||
flag_btn,
|
||||
queue=False,
|
||||
)
|
||||
flag_btn.click(
|
||||
flag_method,
|
||||
inputs=flag_components,
|
||||
outputs=flag_btn,
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
clear_btn.click(
|
||||
flag_method.reset,
|
||||
None,
|
||||
flag_btn,
|
||||
queue=False,
|
||||
)
|
||||
if self.interface_type == InterfaceTypes.UNIFIED:
|
||||
flag_components = self.input_components
|
||||
else:
|
||||
flag_components = self.input_components + self.output_components
|
||||
|
||||
for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
|
||||
assert isinstance(value, str)
|
||||
flag_method = FlagMethod(self.flagging_callback, label, value)
|
||||
flag_btn.click(
|
||||
lambda: Button.update(value="Saving...", interactive=False),
|
||||
None,
|
||||
flag_btn,
|
||||
queue=False,
|
||||
)
|
||||
flag_btn.click(
|
||||
flag_method,
|
||||
inputs=flag_components,
|
||||
outputs=flag_btn,
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
clear_btn.click(
|
||||
flag_method.reset,
|
||||
None,
|
||||
flag_btn,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def render_examples(self):
|
||||
if self.examples:
|
||||
@ -798,7 +796,7 @@ class Interface(Blocks):
|
||||
Column.update(visible=True),
|
||||
]
|
||||
|
||||
async def interpret(self, raw_input: List[Any]) -> List[Any]:
|
||||
async def interpret(self, raw_input: list[Any]) -> list[Any]:
|
||||
return [
|
||||
{"original": raw_value, "interpretation": interpretation}
|
||||
for interpretation, raw_value in zip(
|
||||
@ -823,8 +821,8 @@ class TabbedInterface(Blocks):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
interface_list: List[Interface],
|
||||
tab_names: List[str] | None = None,
|
||||
interface_list: list[Interface],
|
||||
tab_names: list[str] | None = None,
|
||||
title: str | None = None,
|
||||
theme: Theme | None = None,
|
||||
analytics_enabled: bool | None = None,
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import copy
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
from gradio_client import utils as client_utils
|
||||
@ -28,8 +28,8 @@ class Interpretable(ABC): # noqa: B024
|
||||
pass
|
||||
|
||||
def get_interpretation_scores(
|
||||
self, x: Any, neighbors: List[Any] | None, scores: List[float], **kwargs
|
||||
) -> List:
|
||||
self, x: Any, neighbors: list[Any] | None, scores: list[float], **kwargs
|
||||
) -> list:
|
||||
"""
|
||||
Arrange the output values from the neighbors into interpretation scores for the interface to render.
|
||||
Parameters:
|
||||
@ -44,7 +44,7 @@ class Interpretable(ABC): # noqa: B024
|
||||
|
||||
class TokenInterpretable(Interpretable, ABC):
|
||||
@abstractmethod
|
||||
def tokenize(self, x: Any) -> Tuple[List, List, None]:
|
||||
def tokenize(self, x: Any) -> tuple[list, list, None]:
|
||||
"""
|
||||
Interprets an input data point x by splitting it into a list of tokens (e.g
|
||||
a string into words or an image into super-pixels).
|
||||
@ -52,13 +52,13 @@ class TokenInterpretable(Interpretable, ABC):
|
||||
return [], [], None
|
||||
|
||||
@abstractmethod
|
||||
def get_masked_inputs(self, tokens: List, binary_mask_matrix: List[List]) -> List:
|
||||
def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list:
|
||||
return []
|
||||
|
||||
|
||||
class NeighborInterpretable(Interpretable, ABC):
|
||||
@abstractmethod
|
||||
def get_interpretation_neighbors(self, x: Any) -> Tuple[List, Dict]:
|
||||
def get_interpretation_neighbors(self, x: Any) -> tuple[list, dict]:
|
||||
"""
|
||||
Generates values similar to input to be used to interpret the significance of the input in the final output.
|
||||
Parameters:
|
||||
@ -70,7 +70,7 @@ class NeighborInterpretable(Interpretable, ABC):
|
||||
return [], {}
|
||||
|
||||
|
||||
async def run_interpret(interface: Interface, raw_input: List):
|
||||
async def run_interpret(interface: Interface, raw_input: list):
|
||||
"""
|
||||
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
||||
interpretation for a certain set of UI component types, as well as the custom interpretation case.
|
||||
@ -265,12 +265,12 @@ def diff(original: Any, perturbed: Any) -> int | float:
|
||||
try: # try computing numerical difference
|
||||
score = float(original) - float(perturbed)
|
||||
except ValueError: # otherwise, look at strict difference in label
|
||||
score = int(not (original == perturbed))
|
||||
score = int(original != perturbed)
|
||||
return score
|
||||
|
||||
|
||||
def quantify_difference_in_label(
|
||||
interface: Interface, original_output: List, perturbed_output: List
|
||||
interface: Interface, original_output: list, perturbed_output: list
|
||||
) -> int | float:
|
||||
output_component = interface.output_components[0]
|
||||
post_original_output = output_component.postprocess(original_output[0])
|
||||
@ -300,7 +300,7 @@ def quantify_difference_in_label(
|
||||
|
||||
|
||||
def get_regression_or_classification_value(
|
||||
interface: Interface, original_output: List, perturbed_output: List
|
||||
interface: Interface, original_output: list, perturbed_output: list
|
||||
) -> int | float:
|
||||
"""Used to combine regression/classification for Shap interpretation method."""
|
||||
output_component = interface.output_components[0]
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Type
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
@ -216,7 +215,7 @@ class Tab(BlockContext, Selectable):
|
||||
**super(BlockContext, self).get_config(),
|
||||
}
|
||||
|
||||
def get_expected_parent(self) -> Type[Tabs]:
|
||||
def get_expected_parent(self) -> type[Tabs]:
|
||||
return Tabs
|
||||
|
||||
def get_block_name(self):
|
||||
|
@ -9,7 +9,7 @@ import socket
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
import uvicorn
|
||||
@ -89,7 +89,7 @@ def start_server(
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
) -> Tuple[str, int, str, App, Server]:
|
||||
) -> tuple[str, int, str, App, Server]:
|
||||
"""Launches a local server running the provided Interface
|
||||
Parameters:
|
||||
blocks: The Blocks object to run on the server
|
||||
|
@ -8,7 +8,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from gradio import components
|
||||
|
||||
@ -109,7 +109,7 @@ class Dataframe(components.Dataframe):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
headers: Optional[List[str]] = None,
|
||||
headers: Optional[list[str]] = None,
|
||||
max_rows: Optional[int] = 20,
|
||||
max_cols: Optional[int] = None,
|
||||
overflow_row_behaviour: str = "paginate",
|
||||
@ -145,7 +145,7 @@ class Timeseries(components.Timeseries):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, x: str = None, y: str | List[str] = None, label: Optional[str] = None
|
||||
self, x: str = None, y: str | list[str] = None, label: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -227,7 +227,7 @@ class HighlightedText(components.HighlightedText):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
color_map: Dict[str, str] = None,
|
||||
color_map: dict[str, str] = None,
|
||||
label: Optional[str] = None,
|
||||
show_legend: bool = False,
|
||||
):
|
||||
@ -281,7 +281,7 @@ class Carousel(components.Carousel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
components: components.Component | List[components.Component],
|
||||
components: components.Component | list[components.Component],
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
|
@ -3,7 +3,7 @@ please use the `gr.Interface.from_pipeline()` function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from gradio import components
|
||||
|
||||
@ -11,7 +11,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from transformers import pipelines
|
||||
|
||||
|
||||
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
|
||||
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
|
||||
"""
|
||||
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
|
||||
|
@ -8,7 +8,6 @@ import tempfile
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from ffmpy import FFmpeg, FFprobe, FFRuntimeError
|
||||
@ -25,7 +24,7 @@ with warnings.catch_warnings():
|
||||
#########################
|
||||
|
||||
|
||||
def to_binary(x: str | Dict) -> bytes:
|
||||
def to_binary(x: str | dict) -> bytes:
|
||||
"""Converts a base64 string or dictionary to a binary string that can be sent in a POST."""
|
||||
if isinstance(x, dict):
|
||||
if x.get("data"):
|
||||
@ -362,10 +361,7 @@ def _convert(image, dtype, force_copy=False, uniform=False):
|
||||
|
||||
image = np.asarray(image)
|
||||
dtypeobj_in = image.dtype
|
||||
if dtype is np.floating:
|
||||
dtypeobj_out = np.dtype("float64")
|
||||
else:
|
||||
dtypeobj_out = np.dtype(dtype)
|
||||
dtypeobj_out = np.dtype("float64") if dtype is np.floating else np.dtype(dtype)
|
||||
dtype_in = dtypeobj_in.type
|
||||
dtype_out = dtypeobj_out.type
|
||||
kind_in = dtypeobj_in.kind
|
||||
|
@ -6,7 +6,7 @@ import sys
|
||||
import time
|
||||
from asyncio import TimeoutError as AsyncTimeOutError
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Tuple
|
||||
from typing import Any, Deque
|
||||
|
||||
import fastapi
|
||||
import httpx
|
||||
@ -44,14 +44,14 @@ class Queue:
|
||||
concurrency_count: int,
|
||||
update_intervals: float,
|
||||
max_size: int | None,
|
||||
blocks_dependencies: List,
|
||||
blocks_dependencies: list,
|
||||
):
|
||||
self.event_queue: Deque[Event] = deque()
|
||||
self.events_pending_reconnection = []
|
||||
self.stopped = False
|
||||
self.max_thread_count = concurrency_count
|
||||
self.update_intervals = update_intervals
|
||||
self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
|
||||
self.active_jobs: list[None | list[Event]] = [None] * concurrency_count
|
||||
self.delete_lock = asyncio.Lock()
|
||||
self.server_path = None
|
||||
self.duration_history_total = 0
|
||||
@ -96,7 +96,7 @@ class Queue:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
|
||||
def get_events_in_batch(self) -> tuple[list[Event] | None, bool]:
|
||||
if not (self.event_queue):
|
||||
return None, False
|
||||
|
||||
@ -158,7 +158,7 @@ class Queue:
|
||||
def set_progress(
|
||||
self,
|
||||
event_id: str,
|
||||
iterables: List[TrackedIterable] | None,
|
||||
iterables: list[TrackedIterable] | None,
|
||||
):
|
||||
if iterables is None:
|
||||
return
|
||||
@ -167,7 +167,7 @@ class Queue:
|
||||
continue
|
||||
for evt in job:
|
||||
if evt._id == event_id:
|
||||
progress_data: List[ProgressUnit] = []
|
||||
progress_data: list[ProgressUnit] = []
|
||||
for iterable in iterables:
|
||||
progress_unit = ProgressUnit(
|
||||
index=iterable.index,
|
||||
@ -303,7 +303,7 @@ class Queue:
|
||||
queue_eta=self.queue_duration,
|
||||
)
|
||||
|
||||
def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
|
||||
def get_request_params(self, websocket: fastapi.WebSocket) -> dict[str, Any]:
|
||||
return {
|
||||
"url": str(websocket.url),
|
||||
"headers": dict(websocket.headers),
|
||||
@ -312,7 +312,7 @@ class Queue:
|
||||
"client": {"host": websocket.client.host, "port": websocket.client.port}, # type: ignore
|
||||
}
|
||||
|
||||
async def call_prediction(self, events: List[Event], batch: bool):
|
||||
async def call_prediction(self, events: list[Event], batch: bool):
|
||||
data = events[0].data
|
||||
assert data is not None, "No event data"
|
||||
token = events[0].token
|
||||
@ -340,8 +340,8 @@ class Queue:
|
||||
)
|
||||
return response
|
||||
|
||||
async def process_events(self, events: List[Event], batch: bool) -> None:
|
||||
awake_events: List[Event] = []
|
||||
async def process_events(self, events: list[Event], batch: bool) -> None:
|
||||
awake_events: list[Event] = []
|
||||
try:
|
||||
for event in events:
|
||||
client_awake = await self.gather_event_data(event)
|
||||
@ -438,7 +438,7 @@ class Queue:
|
||||
# to start "from scratch"
|
||||
await self.reset_iterators(event.session_hash, event.fn_index)
|
||||
|
||||
async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool:
|
||||
async def send_message(self, event, data: dict, timeout: float | int = 1) -> bool:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
event.websocket.send_json(data=data), timeout=timeout
|
||||
@ -448,7 +448,7 @@ class Queue:
|
||||
await self.clean_event(event)
|
||||
return False
|
||||
|
||||
async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]:
|
||||
async def get_message(self, event, timeout=5) -> tuple[PredictBody | None, bool]:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
event.websocket.receive_json(), timeout=timeout
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import stat
|
||||
from typing import Dict, NamedTuple
|
||||
from typing import NamedTuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiofiles
|
||||
@ -36,7 +36,7 @@ class OpenRange(NamedTuple):
|
||||
|
||||
def clamp(self, start: int, end: int) -> ClosedRange:
|
||||
begin = max(self.start, start)
|
||||
end = min((x for x in (self.end, end) if x))
|
||||
end = min(x for x in (self.end, end) if x)
|
||||
|
||||
begin = min(begin, end)
|
||||
end = max(begin, end)
|
||||
@ -51,7 +51,7 @@ class RangedFileResponse(Response):
|
||||
self,
|
||||
path: str | os.PathLike,
|
||||
range: OpenRange,
|
||||
headers: Dict[str, str] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
media_type: str | None = None,
|
||||
filename: str | None = None,
|
||||
stat_result: os.stat_result | None = None,
|
||||
|
@ -18,10 +18,7 @@ def run_in_reload_mode():
|
||||
args = sys.argv[1:]
|
||||
if len(args) == 0:
|
||||
raise ValueError("No file specified.")
|
||||
if len(args) == 1:
|
||||
demo_name = "demo"
|
||||
else:
|
||||
demo_name = args[1]
|
||||
demo_name = "demo" if len(args) == 1 else args[1]
|
||||
|
||||
original_path = args[0]
|
||||
abs_original_path = utils.abspath(original_path)
|
||||
|
@ -309,10 +309,8 @@ class App(FastAPI):
|
||||
)
|
||||
abs_path = utils.abspath(path_or_url)
|
||||
in_blocklist = any(
|
||||
(
|
||||
utils.is_in_or_equal(abs_path, blocked_path)
|
||||
for blocked_path in blocks.blocked_paths
|
||||
)
|
||||
utils.is_in_or_equal(abs_path, blocked_path)
|
||||
for blocked_path in blocks.blocked_paths
|
||||
)
|
||||
if in_blocklist:
|
||||
raise HTTPException(403, f"File not allowed: {path_or_url}.")
|
||||
@ -320,10 +318,8 @@ class App(FastAPI):
|
||||
in_app_dir = utils.abspath(app.cwd) in abs_path.parents
|
||||
created_by_app = str(abs_path) in set().union(*blocks.temp_file_sets)
|
||||
in_file_dir = any(
|
||||
(
|
||||
utils.is_in_or_equal(abs_path, allowed_path)
|
||||
for allowed_path in blocks.allowed_paths
|
||||
)
|
||||
utils.is_in_or_equal(abs_path, allowed_path)
|
||||
for allowed_path in blocks.allowed_paths
|
||||
)
|
||||
was_uploaded = utils.abspath(app.uploaded_file_dir) in abs_path.parents
|
||||
|
||||
@ -464,14 +460,15 @@ class App(FastAPI):
|
||||
)
|
||||
else:
|
||||
fn_index_inferred = body.fn_index
|
||||
if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn(
|
||||
fn_index_inferred
|
||||
if (
|
||||
not app.get_blocks().api_open
|
||||
and app.get_blocks().queue_enabled_for_fn(fn_index_inferred)
|
||||
and f"Bearer {app.queue_token}" != request.headers.get("Authorization")
|
||||
):
|
||||
if f"Bearer {app.queue_token}" != request.headers.get("Authorization"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authorized to skip the queue",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authorized to skip the queue",
|
||||
)
|
||||
|
||||
# If this fn_index cancels jobs, then the only input we need is the
|
||||
# current session hash
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, Callable, Tuple
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
@ -55,7 +54,7 @@ class Webcam(components.Image):
|
||||
self,
|
||||
value: str | Image | np.ndarray | None = None,
|
||||
*,
|
||||
shape: Tuple[int, int] | None = None,
|
||||
shape: tuple[int, int] | None = None,
|
||||
image_mode: str = "RGB",
|
||||
invert_colors: bool = False,
|
||||
source: str = "webcam",
|
||||
@ -102,7 +101,7 @@ class Sketchpad(components.Image):
|
||||
self,
|
||||
value: str | Image | np.ndarray | None = None,
|
||||
*,
|
||||
shape: Tuple[int, int] = (28, 28),
|
||||
shape: tuple[int, int] = (28, 28),
|
||||
image_mode: str = "L",
|
||||
invert_colors: bool = True,
|
||||
source: str = "canvas",
|
||||
@ -149,7 +148,7 @@ class Paint(components.Image):
|
||||
self,
|
||||
value: str | Image | np.ndarray | None = None,
|
||||
*,
|
||||
shape: Tuple[int, int] | None = None,
|
||||
shape: tuple[int, int] | None = None,
|
||||
image_mode: str = "RGB",
|
||||
invert_colors: bool = False,
|
||||
source: str = "canvas",
|
||||
@ -196,7 +195,7 @@ class ImageMask(components.Image):
|
||||
self,
|
||||
value: str | Image | np.ndarray | None = None,
|
||||
*,
|
||||
shape: Tuple[int, int] | None = None,
|
||||
shape: tuple[int, int] | None = None,
|
||||
image_mode: str = "RGB",
|
||||
invert_colors: bool = False,
|
||||
source: str = "upload",
|
||||
@ -243,7 +242,7 @@ class ImagePaint(components.Image):
|
||||
self,
|
||||
value: str | Image | np.ndarray | None = None,
|
||||
*,
|
||||
shape: Tuple[int, int] | None = None,
|
||||
shape: tuple[int, int] | None = None,
|
||||
image_mode: str = "RGB",
|
||||
invert_colors: bool = False,
|
||||
source: str = "upload",
|
||||
@ -290,7 +289,7 @@ class Pil(components.Image):
|
||||
self,
|
||||
value: str | Image | np.ndarray | None = None,
|
||||
*,
|
||||
shape: Tuple[int, int] | None = None,
|
||||
shape: tuple[int, int] | None = None,
|
||||
image_mode: str = "RGB",
|
||||
invert_colors: bool = False,
|
||||
source: str = "upload",
|
||||
@ -372,7 +371,7 @@ class Microphone(components.Audio):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: str | Tuple[int, np.ndarray] | Callable | None = None,
|
||||
value: str | tuple[int, np.ndarray] | Callable | None = None,
|
||||
*,
|
||||
source: str = "microphone",
|
||||
type: str = "numpy",
|
||||
@ -407,7 +406,7 @@ class Files(components.File):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: str | typing.List[str] | Callable | None = None,
|
||||
value: str | list[str] | Callable | None = None,
|
||||
*,
|
||||
file_count: str = "multiple",
|
||||
type: str = "file",
|
||||
@ -440,12 +439,12 @@ class Numpy(components.Dataframe):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: typing.List[typing.List[Any]] | Callable | None = None,
|
||||
value: list[list[Any]] | Callable | None = None,
|
||||
*,
|
||||
headers: typing.List[str] | None = None,
|
||||
row_count: int | Tuple[int, str] = (1, "dynamic"),
|
||||
col_count: int | Tuple[int, str] | None = None,
|
||||
datatype: str | typing.List[str] = "str",
|
||||
headers: list[str] | None = None,
|
||||
row_count: int | tuple[int, str] = (1, "dynamic"),
|
||||
col_count: int | tuple[int, str] | None = None,
|
||||
datatype: str | list[str] = "str",
|
||||
type: str = "numpy",
|
||||
max_rows: int | None = 20,
|
||||
max_cols: int | None = None,
|
||||
@ -487,12 +486,12 @@ class Matrix(components.Dataframe):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: typing.List[typing.List[Any]] | Callable | None = None,
|
||||
value: list[list[Any]] | Callable | None = None,
|
||||
*,
|
||||
headers: typing.List[str] | None = None,
|
||||
row_count: int | Tuple[int, str] = (1, "dynamic"),
|
||||
col_count: int | Tuple[int, str] | None = None,
|
||||
datatype: str | typing.List[str] = "str",
|
||||
headers: list[str] | None = None,
|
||||
row_count: int | tuple[int, str] = (1, "dynamic"),
|
||||
col_count: int | tuple[int, str] | None = None,
|
||||
datatype: str | list[str] = "str",
|
||||
type: str = "array",
|
||||
max_rows: int | None = 20,
|
||||
max_cols: int | None = None,
|
||||
@ -534,12 +533,12 @@ class List(components.Dataframe):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: typing.List[typing.List[Any]] | Callable | None = None,
|
||||
value: list[list[Any]] | Callable | None = None,
|
||||
*,
|
||||
headers: typing.List[str] | None = None,
|
||||
row_count: int | Tuple[int, str] = (1, "dynamic"),
|
||||
col_count: int | Tuple[int, str] = 1,
|
||||
datatype: str | typing.List[str] = "str",
|
||||
headers: list[str] | None = None,
|
||||
row_count: int | tuple[int, str] = (1, "dynamic"),
|
||||
col_count: int | tuple[int, str] = 1,
|
||||
datatype: str | list[str] = "str",
|
||||
type: str = "array",
|
||||
max_rows: int | None = 20,
|
||||
max_cols: int | None = None,
|
||||
|
@ -5,7 +5,7 @@ import re
|
||||
import tempfile
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable
|
||||
from typing import Iterable
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
@ -108,17 +108,17 @@ class ThemeClass:
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> "ThemeClass":
|
||||
def load(cls, path: str) -> ThemeClass:
|
||||
"""Load a theme from a json file.
|
||||
|
||||
Parameters:
|
||||
path: The filepath to read.
|
||||
"""
|
||||
theme = json.load(open(path), object_hook=fonts.as_font)
|
||||
return cls.from_dict(theme)
|
||||
with open(path) as fp:
|
||||
return cls.from_dict(json.load(fp, object_hook=fonts.as_font))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, theme: Dict[str, Dict[str, str]]) -> "ThemeClass":
|
||||
def from_dict(cls, theme: dict[str, dict[str, str]]) -> ThemeClass:
|
||||
"""Create a theme instance from a dictionary representation.
|
||||
|
||||
Parameters:
|
||||
@ -142,8 +142,7 @@ class ThemeClass:
|
||||
Parameters:
|
||||
filename: The path to write the theme too
|
||||
"""
|
||||
as_dict = self.to_dict()
|
||||
json.dump(as_dict, open(Path(filename), "w"), cls=fonts.FontEncoder)
|
||||
Path(filename).write_text(json.dumps(self.to_dict(), cls=fonts.FontEncoder))
|
||||
|
||||
@classmethod
|
||||
def from_hub(cls, repo_name: str, hf_token: str | None = None):
|
||||
@ -248,10 +247,7 @@ class ThemeClass:
|
||||
|
||||
# If no version, set the version to next patch release
|
||||
if not version:
|
||||
if space_exists:
|
||||
version = self._get_next_version(space_info)
|
||||
else:
|
||||
version = "0.0.1"
|
||||
version = self._get_next_version(space_info) if space_exists else "0.0.1"
|
||||
else:
|
||||
_ = semver.Version(version)
|
||||
|
||||
@ -279,7 +275,7 @@ class ThemeClass:
|
||||
)
|
||||
readme_file.write(textwrap.dedent(readme_content))
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as app_file:
|
||||
contents = open(str(Path(__file__).parent / "app.py")).read()
|
||||
contents = (Path(__file__).parent / "app.py").read_text()
|
||||
contents = re.sub(
|
||||
r"theme=gr.themes.Default\(\)",
|
||||
f"theme='{space_id}'",
|
||||
|
@ -76,7 +76,11 @@ css = """
|
||||
}
|
||||
"""
|
||||
|
||||
with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as demo:
|
||||
with gr.Blocks( # noqa: SIM117
|
||||
theme=gr.themes.Base(),
|
||||
css=css,
|
||||
title="Gradio Theme Builder",
|
||||
) as demo:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="controls", min_width=400):
|
||||
with gr.Row():
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import huggingface_hub
|
||||
import semantic_version
|
||||
@ -17,7 +16,7 @@ class ThemeAsset:
|
||||
self.version = semver.Version(self.filename.split("@")[1].replace(".json", ""))
|
||||
|
||||
|
||||
def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> List[ThemeAsset]:
|
||||
def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> list[ThemeAsset]:
|
||||
if "gradio-theme" not in getattr(space_info, "tags", []):
|
||||
raise ValueError(f"{space_info.id} is not a valid gradio-theme space!")
|
||||
|
||||
@ -29,7 +28,7 @@ def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> List[Theme
|
||||
|
||||
|
||||
def get_matching_version(
|
||||
assets: List[ThemeAsset], expression: str | None
|
||||
assets: list[ThemeAsset], expression: str | None
|
||||
) -> ThemeAsset | None:
|
||||
|
||||
expression = expression or "*"
|
||||
|
@ -27,11 +27,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -104,10 +100,10 @@ def get_local_ip_address() -> str:
|
||||
return ip_address
|
||||
|
||||
|
||||
def initiated_analytics(data: Dict[str, Any]) -> None:
|
||||
def initiated_analytics(data: dict[str, Any]) -> None:
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def initiated_analytics_thread(data: Dict[str, Any]) -> None:
|
||||
def initiated_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-initiated-analytics/", data=data, timeout=5
|
||||
@ -118,10 +114,10 @@ def initiated_analytics(data: Dict[str, Any]) -> None:
|
||||
threading.Thread(target=initiated_analytics_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def launch_analytics(data: Dict[str, Any]) -> None:
|
||||
def launch_analytics(data: dict[str, Any]) -> None:
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def launch_analytics_thread(data: Dict[str, Any]) -> None:
|
||||
def launch_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-launched-analytics/", data=data, timeout=5
|
||||
@ -132,7 +128,7 @@ def launch_analytics(data: Dict[str, Any]) -> None:
|
||||
threading.Thread(target=launch_analytics_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None:
|
||||
def launched_telemetry(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
blocks_telemetry, inputs_telemetry, outputs_telemetry, targets_telemetry = (
|
||||
[],
|
||||
[],
|
||||
@ -180,7 +176,7 @@ def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None:
|
||||
data.update(additional_data)
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def launched_telemtry_thread(data: Dict[str, Any]) -> None:
|
||||
def launched_telemtry_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-launched-telemetry/", data=data, timeout=5
|
||||
@ -191,10 +187,10 @@ def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None:
|
||||
threading.Thread(target=launched_telemtry_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def integration_analytics(data: Dict[str, Any]) -> None:
|
||||
def integration_analytics(data: dict[str, Any]) -> None:
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def integration_analytics_thread(data: Dict[str, Any]) -> None:
|
||||
def integration_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-integration-analytics/", data=data, timeout=5
|
||||
@ -213,7 +209,7 @@ def error_analytics(message: str) -> None:
|
||||
"""
|
||||
data = {"ip_address": get_local_ip_address(), "error": message}
|
||||
|
||||
def error_analytics_thread(data: Dict[str, Any]) -> None:
|
||||
def error_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-error-analytics/", data=data, timeout=5
|
||||
@ -320,7 +316,7 @@ def launch_counter() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def get_default_args(func: Callable) -> List[Any]:
|
||||
def get_default_args(func: Callable) -> list[Any]:
|
||||
signature = inspect.signature(func)
|
||||
return [
|
||||
v.default if v.default is not inspect.Parameter.empty else None
|
||||
@ -329,7 +325,7 @@ def get_default_args(func: Callable) -> List[Any]:
|
||||
|
||||
|
||||
def assert_configs_are_equivalent_besides_ids(
|
||||
config1: Dict, config2: Dict, root_keys: Tuple = ("mode",)
|
||||
config1: dict, config2: dict, root_keys: tuple = ("mode",)
|
||||
):
|
||||
"""Allows you to test if two different Blocks configs produce the same demo.
|
||||
|
||||
@ -382,7 +378,7 @@ def assert_configs_are_equivalent_besides_ids(
|
||||
return True
|
||||
|
||||
|
||||
def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]):
|
||||
def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
|
||||
if len(ner_groups) == 0:
|
||||
return [(input_string, None)]
|
||||
|
||||
@ -400,7 +396,7 @@ def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]):
|
||||
return output
|
||||
|
||||
|
||||
def delete_none(_dict: Dict, skip_value: bool = False) -> Dict:
|
||||
def delete_none(_dict: dict, skip_value: bool = False) -> dict:
|
||||
"""
|
||||
Delete keys whose values are None from a dictionary
|
||||
"""
|
||||
@ -412,14 +408,14 @@ def delete_none(_dict: Dict, skip_value: bool = False) -> Dict:
|
||||
return _dict
|
||||
|
||||
|
||||
def resolve_singleton(_list: List[Any] | Any) -> Any:
|
||||
def resolve_singleton(_list: list[Any] | Any) -> Any:
|
||||
if len(_list) == 1:
|
||||
return _list[0]
|
||||
else:
|
||||
return _list
|
||||
|
||||
|
||||
def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockContext]:
|
||||
def component_or_layout_class(cls_name: str) -> type[Component] | type[BlockContext]:
|
||||
"""
|
||||
Returns the component, template, or layout class with the given class name, or
|
||||
raises a ValueError if not found.
|
||||
@ -536,9 +532,9 @@ class AsyncRequest:
|
||||
method: Method,
|
||||
url: str,
|
||||
*,
|
||||
validation_model: Type[BaseModel] | None = None,
|
||||
validation_model: type[BaseModel] | None = None,
|
||||
validation_function: Union[Callable, None] = None,
|
||||
exception_type: Type[Exception] = Exception,
|
||||
exception_type: type[Exception] = Exception,
|
||||
raise_for_status: bool = False,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
**kwargs,
|
||||
@ -565,7 +561,7 @@ class AsyncRequest:
|
||||
self._request = self._create_request(method, url, **kwargs)
|
||||
self.client_ = client or self.client
|
||||
|
||||
def __await__(self) -> Generator[None, Any, "AsyncRequest"]:
|
||||
def __await__(self) -> Generator[None, Any, AsyncRequest]:
|
||||
"""
|
||||
Wrap Request's __await__ magic function to create request calls which are executed in one line.
|
||||
"""
|
||||
@ -740,7 +736,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number:
|
||||
return value
|
||||
|
||||
|
||||
def sanitize_list_for_csv(values: List[Any]) -> List[Any]:
|
||||
def sanitize_list_for_csv(values: list[Any]) -> list[Any]:
|
||||
"""
|
||||
Sanitizes a list of values (or a list of list of values) that is being written to a
|
||||
CSV file to prevent CSV injection attacks.
|
||||
@ -756,7 +752,7 @@ def sanitize_list_for_csv(values: List[Any]) -> List[Any]:
|
||||
return sanitized_values
|
||||
|
||||
|
||||
def append_unique_suffix(name: str, list_of_names: List[str]):
|
||||
def append_unique_suffix(name: str, list_of_names: list[str]):
|
||||
"""Appends a numerical suffix to `name` so that it does not appear in `list_of_names`."""
|
||||
set_of_names: set[str] = set(list_of_names) # for O(1) lookup
|
||||
if name not in set_of_names:
|
||||
@ -815,8 +811,8 @@ def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
|
||||
|
||||
|
||||
def get_cancel_function(
|
||||
dependencies: List[Dict[str, Any]]
|
||||
) -> Tuple[Callable, List[int]]:
|
||||
dependencies: list[dict[str, Any]]
|
||||
) -> tuple[Callable, list[int]]:
|
||||
fn_to_comp = {}
|
||||
for dep in dependencies:
|
||||
if Context.root_block:
|
||||
@ -858,7 +854,7 @@ def is_special_typed_parameter(name, parameter_types):
|
||||
return is_request or is_event_data
|
||||
|
||||
|
||||
def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool):
|
||||
def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool):
|
||||
"""
|
||||
Checks if the input component set matches the function
|
||||
Returns: None if valid, a string error message if mismatch
|
||||
@ -878,9 +874,8 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
|
||||
max_args += 1
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
max_args = infinity
|
||||
elif param.kind == param.KEYWORD_ONLY:
|
||||
if not has_default:
|
||||
return f"Keyword-only args must have default values for function {fn}"
|
||||
elif param.kind == param.KEYWORD_ONLY and not has_default:
|
||||
return f"Keyword-only args must have default values for function {fn}"
|
||||
arg_count = 1 if inputs_as_dict else len(inputs)
|
||||
if min_args == max_args and max_args != arg_count:
|
||||
warnings.warn(
|
||||
@ -918,15 +913,15 @@ def tex2svg(formula, *args):
|
||||
with MatplotlibBackendMananger():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
FONTSIZE = 20
|
||||
DPI = 300
|
||||
fontsize = 20
|
||||
dpi = 300
|
||||
plt.rc("mathtext", fontset="cm")
|
||||
fig = plt.figure(figsize=(0.01, 0.01))
|
||||
fig.text(0, 0, rf"${formula}$", fontsize=FONTSIZE)
|
||||
fig.text(0, 0, rf"${formula}$", fontsize=fontsize)
|
||||
output = BytesIO()
|
||||
fig.savefig(
|
||||
output,
|
||||
dpi=DPI,
|
||||
dpi=dpi,
|
||||
transparent=True,
|
||||
format="svg",
|
||||
bbox_inches="tight",
|
||||
@ -942,7 +937,7 @@ def tex2svg(formula, *args):
|
||||
height_match = re.search(r'height="([\d.]+)pt"', svg_code)
|
||||
if height_match:
|
||||
height = float(height_match.group(1))
|
||||
new_height = height / FONTSIZE # conversion from pt to em
|
||||
new_height = height / fontsize # conversion from pt to em
|
||||
svg_code = re.sub(
|
||||
r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code
|
||||
)
|
||||
@ -1016,7 +1011,7 @@ def get_serializer_name(block: Block) -> str | None:
|
||||
def highlight_code(code, name, attrs):
|
||||
try:
|
||||
lexer = get_lexer_by_name(name)
|
||||
except:
|
||||
except Exception:
|
||||
lexer = get_lexer_by_name("text")
|
||||
formatter = HtmlFormatter()
|
||||
|
||||
@ -1047,3 +1042,10 @@ def get_markdown_parser() -> MarkdownIt:
|
||||
md.add_render_rule("link_open", render_blank_link)
|
||||
|
||||
return md
|
||||
|
||||
|
||||
HTML_TAG_RE = re.compile("<.*?>")
|
||||
|
||||
|
||||
def remove_html_tags(raw_html: str | None) -> str:
|
||||
return re.sub(HTML_TAG_RE, "", raw_html or "")
|
||||
|
@ -67,10 +67,9 @@ extend-select = [
|
||||
"B",
|
||||
"C",
|
||||
"I",
|
||||
# Formatting-related UP rules
|
||||
"UP030",
|
||||
"UP031",
|
||||
"UP032",
|
||||
"N",
|
||||
"SIM",
|
||||
"UP",
|
||||
]
|
||||
ignore = [
|
||||
"C901", # function is too complex (TODO: un-ignore this)
|
||||
@ -79,6 +78,9 @@ ignore = [
|
||||
"B017", # pytest.raises considered evil
|
||||
"B028", # explicit stacklevel for warnings
|
||||
"E501", # from scripts/lint_backend.sh
|
||||
"SIM105", # contextlib.suppress (has a performance cost)
|
||||
"SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc)
|
||||
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
@ -91,3 +93,6 @@ ignore = [
|
||||
"gradio/__init__.py" = [
|
||||
"F401", # "Imported but unused" (TODO: it would be better to be explicit and use __all__)
|
||||
]
|
||||
"gradio/routes.py" = [
|
||||
"UP006", # Pydantic on Python 3.7 requires old-style type annotations (TODO: drop when Python 3.7 is dropped)
|
||||
]
|
||||
|
@ -197,7 +197,7 @@ respx==0.19.2
|
||||
# via -r requirements.in
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
ruff==0.0.263
|
||||
ruff==0.0.264
|
||||
# via -r requirements.in
|
||||
s3transfer==0.6.0
|
||||
# via boto3
|
||||
|
@ -185,7 +185,7 @@ requests==2.28.1
|
||||
# transformers
|
||||
respx==0.19.2
|
||||
# via -r requirements.in
|
||||
ruff==0.0.263
|
||||
ruff==0.0.264
|
||||
# via -r requirements.in
|
||||
rfc3986[idna2008]==1.5.0
|
||||
# via httpx
|
||||
|
@ -279,8 +279,7 @@ class TestBlocksMethods:
|
||||
return 42
|
||||
|
||||
def generator_function():
|
||||
for index in range(10):
|
||||
yield index
|
||||
yield from range(10)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
|
||||
@ -670,8 +669,7 @@ class TestCallFunction:
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_generator(self):
|
||||
def generator(x):
|
||||
for i in range(x):
|
||||
yield i
|
||||
yield from range(x)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
inp = gr.Number()
|
||||
@ -1368,11 +1366,10 @@ class TestProgressBar:
|
||||
await ws.send(json.dumps({"data": [0], "fn_index": 0}))
|
||||
if msg["msg"] == "send_hash":
|
||||
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
|
||||
if msg["msg"] == "progress":
|
||||
if msg[
|
||||
"progress_data"
|
||||
]: # Ignore empty lists which sometimes appear on Windows
|
||||
progress_updates.append(msg["progress_data"])
|
||||
if (
|
||||
msg["msg"] == "progress" and msg["progress_data"]
|
||||
): # Ignore empty lists which sometimes appear on Windows
|
||||
progress_updates.append(msg["progress_data"])
|
||||
if msg["msg"] == "process_completed":
|
||||
completed = True
|
||||
break
|
||||
|
@ -1489,7 +1489,8 @@ class TestLabel:
|
||||
label_output = gr.Label()
|
||||
label = label_output.postprocess(y)
|
||||
assert label == {"label": "happy"}
|
||||
assert json.load(open(label_output.deserialize(label))) == label
|
||||
with open(label_output.deserialize(label)) as f:
|
||||
assert json.load(f) == label
|
||||
|
||||
y = {3: 0.7, 1: 0.2, 0: 0.1}
|
||||
label = label_output.postprocess(y)
|
||||
|
@ -11,7 +11,7 @@ from gradio_client import media_data
|
||||
|
||||
import gradio as gr
|
||||
from gradio.context import Context
|
||||
from gradio.exceptions import InvalidApiName
|
||||
from gradio.exceptions import InvalidApiNameError
|
||||
from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples
|
||||
|
||||
"""
|
||||
@ -190,16 +190,16 @@ class TestLoadInterface:
|
||||
def test_sentiment_model(self):
|
||||
io = gr.load("models/distilbert-base-uncased-finetuned-sst-2-english")
|
||||
try:
|
||||
output = io("I am happy, I love you")
|
||||
assert json.load(open(output))["label"] == "POSITIVE"
|
||||
with open(io("I am happy, I love you")) as f:
|
||||
assert json.load(f)["label"] == "POSITIVE"
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
def test_image_classification_model(self):
|
||||
io = gr.Blocks.load(name="models/google/vit-base-patch16-224")
|
||||
try:
|
||||
output = io("gradio/test_data/lion.jpg")
|
||||
assert json.load(open(output))["label"] == "lion"
|
||||
with open(io("gradio/test_data/lion.jpg")) as f:
|
||||
assert json.load(f)["label"] == "lion"
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
@ -214,9 +214,9 @@ class TestLoadInterface:
|
||||
def test_numerical_to_label_space(self):
|
||||
io = gr.load("spaces/abidlabs/titanic-survival")
|
||||
try:
|
||||
output = io("male", 77, 10)
|
||||
assert json.load(open(output))["label"] == "Perishes"
|
||||
assert io.theme.name == "soft"
|
||||
with open(io("male", 77, 10)) as f:
|
||||
assert json.load(f)["label"] == "Perishes"
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
@ -472,7 +472,7 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
|
||||
|
||||
|
||||
def test_raise_value_error_when_api_name_invalid():
|
||||
with pytest.raises(InvalidApiName):
|
||||
with pytest.raises(InvalidApiNameError):
|
||||
demo = gr.Blocks.load(name="spaces/gradio/hello_world")
|
||||
demo("freddy", api_name="route does not exist")
|
||||
|
||||
|
@ -26,7 +26,7 @@ class TestDefault:
|
||||
"interpretation"
|
||||
]
|
||||
assert interpretation[0][1] > 0 # Checks to see if the first word has >0 score.
|
||||
assert 0 == interpretation[-1][1] # Checks to see if the last word has 0 score.
|
||||
assert interpretation[-1][1] == 0 # Checks to see if the last word has 0 score.
|
||||
|
||||
|
||||
class TestShapley:
|
||||
@ -92,9 +92,9 @@ class TestHelperMethods:
|
||||
def test_quantify_difference_with_label(self):
|
||||
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"])
|
||||
assert -7 == diff
|
||||
assert diff == -7
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"])
|
||||
assert -100 == diff
|
||||
assert diff == -100
|
||||
|
||||
def test_quantify_difference_with_confidences(self):
|
||||
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
|
||||
@ -104,11 +104,11 @@ class TestHelperMethods:
|
||||
diff = gradio.interpretation.quantify_difference_in_label(
|
||||
iface, [output_1], [output_2]
|
||||
)
|
||||
assert 0.3 == pytest.approx(diff)
|
||||
assert pytest.approx(diff) == 0.3
|
||||
diff = gradio.interpretation.quantify_difference_in_label(
|
||||
iface, [output_1], [output_3]
|
||||
)
|
||||
assert 0.8 == pytest.approx(diff)
|
||||
assert pytest.approx(diff) == 0.8
|
||||
|
||||
def test_get_regression_value(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["label"])
|
||||
@ -118,19 +118,19 @@ class TestHelperMethods:
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, [output_1], [output_2]
|
||||
)
|
||||
assert 0 == diff
|
||||
assert diff == 0
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, [output_1], [output_3]
|
||||
)
|
||||
assert 0.1 == pytest.approx(diff)
|
||||
assert pytest.approx(diff) == 0.1
|
||||
|
||||
def test_get_classification_value(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["label"])
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, ["cat"], ["test"]
|
||||
)
|
||||
assert 1 == diff
|
||||
assert diff == 1
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(
|
||||
iface, ["test"], ["test"]
|
||||
)
|
||||
assert 0 == diff
|
||||
assert diff == 0
|
||||
|
@ -28,8 +28,8 @@ class TestSeries:
io2 = gr.load("spaces/abidlabs/image-classifier")
series = mix.Series(io1, io2)
try:
output = series("gradio/test_data/lion.jpg")
assert json.load(open(output))["label"] == "lion"
with open(series("gradio/test_data/lion.jpg")) as f:
assert json.load(f)["label"] == "lion"
except TooManyRequestsError:
pass

@ -187,15 +187,16 @@ class TestVideoProcessing:
def test_video_has_playable_codecs_catches_exceptions(
self, exception_to_raise, test_file_dir
):
with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise):
with tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"),
tmp_not_playable_vid.name,
)
assert processing_utils.video_is_playable(tmp_not_playable_vid.name)
with patch(
"ffmpy.FFprobe.run", side_effect=exception_to_raise
), tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"),
tmp_not_playable_vid.name,
)
assert processing_utils.video_is_playable(tmp_not_playable_vid.name)

def test_convert_video_to_playable_mp4(self, test_file_dir):
with tempfile.NamedTemporaryFile(

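Merging the patch(...) and NamedTemporaryFile(...) context managers into one with statement is Ruff's SIM117 fix for needlessly nested blocks. A minimal sketch of the same reshaping, using only standard-library context managers:

import tempfile
from unittest.mock import patch

# Before: one context manager nested directly inside another.
with patch("os.getcwd", return_value="/fake/cwd"):
    with tempfile.TemporaryDirectory() as tmpdir:
        print("nested form:", tmpdir)

# After: a single with statement and one less level of indentation.
with patch("os.getcwd", return_value="/fake/cwd"), tempfile.TemporaryDirectory() as tmpdir:
    print("combined form:", tmpdir)
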
@ -56,9 +56,8 @@ class TestRoutes:
assert response.status_code == 200

def test_upload_path(self, test_client):
response = test_client.post(
"/upload", files={"files": open("test/test_files/alphabet.txt", "r")}
)
with open("test/test_files/alphabet.txt") as f:
response = test_client.post("/upload", files={"files": f})
assert response.status_code == 200
file = response.json()[0]
assert "alphabet" in file

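The same file-handle rule applies to request bodies: the upload fixture is now opened inside a with block so the handle is released once the POST returns. A self-contained sketch of the shape, built on a tiny hypothetical FastAPI app and local file rather than gradio's real /upload route:

from pathlib import Path

from fastapi import FastAPI, File, UploadFile
from fastapi.testclient import TestClient

app = FastAPI()

@app.post("/upload")
async def upload(files: UploadFile = File(...)):
    # Echo the uploaded filename back, mimicking the list the test inspects.
    return [files.filename]

Path("alphabet.txt").write_text("abcdefghijklmnopqrstuvwxyz")  # illustrative fixture

client = TestClient(app)
with open("alphabet.txt", "rb") as f:  # the handle is closed after the request
    response = client.post("/upload", files={"files": f})

assert response.status_code == 200
assert "alphabet" in response.json()[0]
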
@ -72,9 +71,8 @@ class TestRoutes:
app, _, _ = io.launch(prevent_thread_lock=True)
test_client = TestClient(app)
try:
response = test_client.post(
"/upload", files={"files": open("test/test_files/alphabet.txt", "r")}
)
with open("test/test_files/alphabet.txt") as f:
response = test_client.post("/upload", files={"files": f})
assert response.status_code == 200
file = response.json()[0]
assert "alphabet" in file

@ -399,8 +397,7 @@ class TestRoutes:
class TestGeneratorRoutes:
def test_generator(self):
def generator(string):
for char in string:
yield char
yield from string

io = Interface(generator, "text", "text")
app, _, _ = io.queue().launch(prevent_thread_lock=True)

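Replacing the character loop with yield from is Ruff's UP028 fix: a for-loop whose only body is yield element can delegate to the iterable directly, producing identical output. A quick sketch:

def chars_with_loop(string):
    for char in string:  # the form UP028 flags
        yield char


def chars_with_delegate(string):
    yield from string    # the preferred, equivalent form


assert list(chars_with_loop("abc")) == list(chars_with_delegate("abc")) == ["a", "b", "c"]
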
@ -350,7 +350,7 @@ class TestRequest:
name: str
job: str
id: str
createdAt: str
createdAt: str # noqa: N815

client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,

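createdAt stays camelCase because it has to mirror the key in the mocked API payload, so rather than renaming the field the commit adds a targeted "# noqa: N815" (Ruff's mixedCase-in-class-scope check). A sketch of that trade-off using a plain dataclass as a stand-in for the pydantic model used in the test:

from dataclasses import dataclass


@dataclass
class Person:
    name: str
    job: str
    # The field name intentionally matches the external JSON key, so the
    # naming check is silenced for this one line instead of renaming it.
    createdAt: str  # noqa: N815


print(Person(name="Alex", job="Engineer", createdAt="2023-05-01T00:00:00Z"))
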
@ -434,7 +434,7 @@ async def test_validate_with_model(respx_mock):
name: str
job: str
id: str
createdAt: str
createdAt: str # noqa: N815

client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,

@ -469,7 +469,7 @@ async def test_validate_and_fail_with_model(respx_mock):
@mock.patch("gradio.utils.AsyncRequest._validate_response_data")
@pytest.mark.asyncio
async def test_exception_type(validate_response_data, respx_mock):
class ResponseValidationException(Exception):
class ResponseValidationError(Exception):
message = "Response object is not valid."

validate_response_data.side_effect = Exception()

@ -479,9 +479,9 @@ async def test_exception_type(validate_response_data, respx_mock):
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.GET,
url=MOCK_REQUEST_URL,
exception_type=ResponseValidationException,
exception_type=ResponseValidationError,
)
assert isinstance(client_response.exception, ResponseValidationException)
assert isinstance(client_response.exception, ResponseValidationError)


@pytest.mark.asyncio

@ -511,9 +511,8 @@ async def test_validate_with_function(respx_mock):
@pytest.mark.asyncio
async def test_validate_and_fail_with_function(respx_mock):
def has_name(response):
if response["name"] is not None:
if response["name"] == "Alex":
return response
if response["name"] is not None and response["name"] == "Alex":
return response
raise Exception

respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response({"name": "morpheus"}))

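Collapsing the nested if into a single condition joined with "and" is Ruff's SIM102 fix; short-circuit evaluation keeps the behaviour identical because the second test only runs when the first one passes. A minimal sketch of the validator shape (with a named error standing in for the bare Exception):

def has_name(response: dict) -> dict:
    # One condition instead of two nested ifs; `and` short-circuits exactly
    # like the nested form did.
    if response["name"] is not None and response["name"] == "Alex":
        return response
    raise ValueError("response does not contain the expected name")


print(has_name({"name": "Alex"}))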