From d1853625fd75247b2b2d4c81785cb77a51dad199 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 5 May 2023 05:54:23 +0300
Subject: [PATCH] More Ruff rules (#4038)

* Bump ruff to 0.0.264
* Enable Ruff Naming rules and fix most errors
* Move `clean_html` to utils (to fix an N lint error)
* Changelog
* Clean up possibly leaking file handles
* Enable and autofix Ruff SIM
* Fix remaining Ruff SIMs
* Enable and autofix Ruff UP issues
* Fix misordered import from #4048
* Fix bare except from #4048

---------

Co-authored-by: Abubakar Abid
---
 CHANGELOG.md                                 |   1 +
 client/python/gradio_client/client.py        |  44 +-
 client/python/gradio_client/documentation.py |   4 +-
 client/python/gradio_client/serializing.py   |  76 ++--
 client/python/gradio_client/utils.py         |  31 +-
 client/python/test/requirements.txt          |   2 +-
 client/python/test/test_client.py            |  33 +-
 client/python/test/test_serializing.py       |   6 +-
 gradio/__init__.py                           |   2 +-
 gradio/blocks.py                             | 123 +++---
 gradio/components.py                         | 439 +++++++++----------
 gradio/events.py                             |  12 +-
 gradio/exceptions.py                         |   5 +-
 gradio/external.py                           |   8 +-
 gradio/flagging.py                           |  38 +-
 gradio/helpers.py                            |  43 +-
 gradio/inputs.py                             |  22 +-
 gradio/interface.py                          | 130 +++---
 gradio/interpretation.py                     |  20 +-
 gradio/layouts.py                            |   3 +-
 gradio/networking.py                         |   4 +-
 gradio/outputs.py                            |  10 +-
 gradio/pipelines.py                          |   4 +-
 gradio/processing_utils.py                   |   8 +-
 gradio/queueing.py                           |  24 +-
 gradio/ranged_response.py                    |   6 +-
 gradio/reload.py                             |   5 +-
 gradio/routes.py                             |  27 +-
 gradio/templates.py                          |  49 +--
 gradio/themes/base.py                        |  20 +-
 gradio/themes/builder.py                     |   6 +-
 gradio/themes/utils/semver_match.py          |   5 +-
 gradio/utils.py                              |  74 ++--
 pyproject.toml                               |  13 +-
 test/requirements-37.txt                     |   2 +-
 test/requirements.txt                        |   2 +-
 test/test_blocks.py                          |  15 +-
 test/test_components.py                      |   3 +-
 test/test_external.py                        |  16 +-
 test/test_interpretation.py                  |  18 +-
 test/test_mix.py                             |   4 +-
 test/test_processing_utils.py                |  19 +-
 test/test_routes.py                          |  13 +-
 test/test_utils.py                           |  15 +-
 44 files changed, 694 insertions(+), 710 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d58dd2b7d8..310cfb6f68 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -120,6 +120,7 @@ No changes to highlight.
- CI: Simplified Python CI workflow by [@akx](https://github.com/akx) in [PR 3982](https://github.com/gradio-app/gradio/pull/3982) - Upgrade pyright to 1.1.305 by [@akx](https://github.com/akx) in [PR 4042](https://github.com/gradio-app/gradio/pull/4042) +- More Ruff rules are enabled and lint errors fixed by [@akx](https://github.com/akx) in [PR 4038](https://github.com/gradio-app/gradio/pull/4038) ## Breaking Changes: diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index e86d5f47d6..e44c0b7566 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -13,7 +13,7 @@ from concurrent.futures import Future, TimeoutError from datetime import datetime from pathlib import Path from threading import Lock -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable import huggingface_hub import requests @@ -132,7 +132,7 @@ class Client: hf_token: str | None = None, private: bool = True, hardware: str | None = None, - secrets: Dict[str, str] | None = None, + secrets: dict[str, str] | None = None, sleep_timeout: int = 5, max_workers: int = 40, verbose: bool = True, @@ -216,7 +216,7 @@ class Client: current_info.hardware or huggingface_hub.SpaceHardware.CPU_BASIC ) hardware = hardware or original_info.hardware - if not current_hardware == hardware: + if current_hardware != hardware: huggingface_hub.request_space_hardware(space_id, hardware) # type: ignore print( f"-------\nNOTE: this Space uses upgraded hardware: {hardware}... see billing info at https://huggingface.co/settings/billing\n-------" @@ -262,7 +262,7 @@ class Client: *args, api_name: str | None = None, fn_index: int | None = None, - result_callbacks: Callable | List[Callable] | None = None, + result_callbacks: Callable | list[Callable] | None = None, ) -> Job: """ Creates and returns a Job object which calls the Gradio API in a background thread. The job can be used to retrieve the status and result of the remote API call. @@ -323,7 +323,7 @@ class Client: all_endpoints: bool | None = None, print_info: bool = True, return_format: Literal["dict", "str"] | None = None, - ) -> Dict | str | None: + ) -> dict | str | None: """ Prints the usage info for the API. If the Gradio app has multiple API endpoints, the usage info for each endpoint will be printed separately. If return_format="dict" the info is returned in dictionary format, as shown in the example below. 
@@ -449,7 +449,7 @@ class Client: def _render_endpoints_info( self, name_or_index: str | int, - endpoints_info: Dict[str, List[Dict[str, str]]], + endpoints_info: dict[str, list[dict[str, str]]], ) -> str: parameter_names = [p["label"] for p in endpoints_info["parameters"]] parameter_names = [utils.sanitize_parameter_names(p) for p in parameter_names] @@ -542,7 +542,7 @@ class Client: def _space_name_to_src(self, space) -> str | None: return huggingface_hub.space_info(space, token=self.hf_token).host # type: ignore - def _get_config(self) -> Dict: + def _get_config(self) -> dict: r = requests.get( urllib.parse.urljoin(self.src, utils.CONFIG_URL), headers=self.headers ) @@ -568,7 +568,7 @@ class Client: class Endpoint: """Helper class for storing all the information about a single API endpoint.""" - def __init__(self, client: Client, fn_index: int, dependency: Dict): + def __init__(self, client: Client, fn_index: int, dependency: dict): self.client: Client = client self.fn_index = fn_index self.dependency = dependency @@ -615,7 +615,7 @@ class Endpoint: return _inner def make_predict(self, helper: Communicator | None = None): - def _predict(*data) -> Tuple: + def _predict(*data) -> tuple: data = json.dumps( { "data": data, @@ -669,8 +669,8 @@ class Endpoint: return outputs def _upload( - self, file_paths: List[str | List[str]] - ) -> List[str | List[str]] | List[Dict[str, Any] | List[Dict[str, Any]]]: + self, file_paths: list[str | list[str]] + ) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]: if not file_paths: return [] # Put all the filepaths in one file @@ -683,7 +683,7 @@ class Endpoint: if not isinstance(fs, list): fs = [fs] for f in fs: - files.append(("files", (Path(f).name, open(f, "rb")))) + files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115 indices.append(i) r = requests.post( self.client.upload_url, headers=self.client.headers, files=files @@ -718,8 +718,8 @@ class Endpoint: def _add_uploaded_files_to_data( self, - files: List[str | List[str]] | List[Dict[str, Any] | List[Dict[str, Any]]], - data: List[Any], + files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]], + data: list[Any], ) -> None: """Helper function to modify the input data with the uploaded files.""" file_counter = 0 @@ -728,7 +728,7 @@ class Endpoint: data[i] = files[file_counter] file_counter += 1 - def serialize(self, *data) -> Tuple: + def serialize(self, *data) -> tuple: data = list(data) for i, input_component_type in enumerate(self.input_component_types): if input_component_type == utils.STATE_COMPONENT: @@ -748,7 +748,7 @@ class Endpoint: o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)]) return o - def deserialize(self, *data) -> Tuple | Any: + def deserialize(self, *data) -> tuple | Any: assert len(data) == len( self.deserializers ), f"Expected {len(self.deserializers)} outputs, got {len(data)}" @@ -758,7 +758,7 @@ class Endpoint: for s, d, oct in zip( self.deserializers, data, self.output_component_types ) - if not oct == utils.STATE_COMPONENT + if oct != utils.STATE_COMPONENT ] ) if ( @@ -766,7 +766,7 @@ class Endpoint: [ oct for oct in self.output_component_types - if not oct == utils.STATE_COMPONENT + if oct != utils.STATE_COMPONENT ] ) == 1 @@ -776,7 +776,7 @@ class Endpoint: output = outputs return output - def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]: + def _setup_serializers(self) -> tuple[list[Serializable], list[Serializable]]: inputs = self.dependency["inputs"] serializers = [] @@ 
-820,7 +820,7 @@ class Endpoint: return serializers, deserializers - def _use_websocket(self, dependency: Dict) -> bool: + def _use_websocket(self, dependency: dict) -> bool: queue_enabled = self.client.config.get("enable_queue", False) queue_uses_websocket = version.parse( self.client.config.get("version", "2.0") @@ -873,7 +873,7 @@ class Job(Future): def __iter__(self) -> Job: return self - def __next__(self) -> Tuple | Any: + def __next__(self) -> tuple | Any: if not self.communicator: raise StopIteration() @@ -925,7 +925,7 @@ class Job(Future): else: return super().result(timeout=timeout) - def outputs(self) -> List[Tuple | Any]: + def outputs(self) -> list[tuple | Any]: """ Returns a list containing the latest outputs from the Job. diff --git a/client/python/gradio_client/documentation.py b/client/python/gradio_client/documentation.py index e6c60412c3..4d8d41ddf8 100644 --- a/client/python/gradio_client/documentation.py +++ b/client/python/gradio_client/documentation.py @@ -3,7 +3,7 @@ from __future__ import annotations import inspect -from typing import Callable, Dict, List, Tuple +from typing import Callable classes_to_document = {} classes_inherit_documentation = {} @@ -61,7 +61,7 @@ def document(*fns, inherit=False): return inner_doc -def document_fn(fn: Callable, cls) -> Tuple[str, List[Dict], Dict, str | None]: +def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]: """ Generates documentation for any function. Parameters: diff --git a/client/python/gradio_client/serializing.py b/client/python/gradio_client/serializing.py index 268ed31f3c..02c9507355 100644 --- a/client/python/gradio_client/serializing.py +++ b/client/python/gradio_client/serializing.py @@ -4,21 +4,21 @@ import json import os import uuid from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any from gradio_client import media_data, utils from gradio_client.data_classes import FileData class Serializable: - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: """ The typing information for this component as a dictionary whose values are a list of 2 strings: [Python type, language-agnostic description]. Keys of the dictionary are: raw_input, raw_output, serialized_input, serialized_output """ raise NotImplementedError() - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: """ The example inputs for this component as a dictionary whose values are example inputs compatible with this component. 
Keys of the dictionary are: raw, serialized @@ -26,12 +26,12 @@ class Serializable: raise NotImplementedError() # For backwards compatibility - def input_api_info(self) -> Tuple[str, str]: + def input_api_info(self) -> tuple[str, str]: api_info = self.api_info() return (api_info["serialized_input"][0], api_info["serialized_input"][1]) # For backwards compatibility - def output_api_info(self) -> Tuple[str, str]: + def output_api_info(self) -> tuple[str, str]: api_info = self.api_info() return (api_info["serialized_output"][0], api_info["serialized_output"][1]) @@ -57,7 +57,7 @@ class Serializable: class SimpleSerializable(Serializable): """General class that does not perform any serialization or deserialization.""" - def api_info(self) -> Dict[str, str | List[str]]: + def api_info(self) -> dict[str, str | list[str]]: return { "raw_input": ["Any", ""], "raw_output": ["Any", ""], @@ -65,7 +65,7 @@ class SimpleSerializable(Serializable): "serialized_output": ["Any", ""], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": None, "serialized": None, @@ -75,7 +75,7 @@ class SimpleSerializable(Serializable): class StringSerializable(Serializable): """Expects a string as input/output but performs no serialization.""" - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": ["str", "string value"], "raw_output": ["str", "string value"], @@ -83,7 +83,7 @@ class StringSerializable(Serializable): "serialized_output": ["str", "string value"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": "Howdy!", "serialized": "Howdy!", @@ -93,7 +93,7 @@ class StringSerializable(Serializable): class ListStringSerializable(Serializable): """Expects a list of strings as input/output but performs no serialization.""" - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": ["List[str]", "list of string values"], "raw_output": ["List[str]", "list of string values"], @@ -101,7 +101,7 @@ class ListStringSerializable(Serializable): "serialized_output": ["List[str]", "list of string values"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": ["Howdy!", "Merhaba"], "serialized": ["Howdy!", "Merhaba"], @@ -111,7 +111,7 @@ class ListStringSerializable(Serializable): class BooleanSerializable(Serializable): """Expects a boolean as input/output but performs no serialization.""" - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": ["bool", "boolean value"], "raw_output": ["bool", "boolean value"], @@ -119,7 +119,7 @@ class BooleanSerializable(Serializable): "serialized_output": ["bool", "boolean value"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": True, "serialized": True, @@ -129,7 +129,7 @@ class BooleanSerializable(Serializable): class NumberSerializable(Serializable): """Expects a number (int/float) as input/output but performs no serialization.""" - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": ["int | float", "numeric value"], "raw_output": ["int | float", "numeric value"], @@ -137,7 +137,7 @@ class NumberSerializable(Serializable): "serialized_output": ["int | float", "numeric value"], } - def example_inputs(self) -> Dict[str, Any]: + def 
example_inputs(self) -> dict[str, Any]: return { "raw": 5, "serialized": 5, @@ -147,7 +147,7 @@ class NumberSerializable(Serializable): class ImgSerializable(Serializable): """Expects a base64 string as input/output which is serialized to a filepath.""" - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": ["str", "base64 representation of image"], "raw_output": ["str", "base64 representation of image"], @@ -155,7 +155,7 @@ class ImgSerializable(Serializable): "serialized_output": ["str", "filepath or URL to image"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": media_data.BASE64_IMAGE, "serialized": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", @@ -204,7 +204,7 @@ class ImgSerializable(Serializable): class FileSerializable(Serializable): """Expects a dict with base64 representation of object as input/output which is serialized to a filepath.""" - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": [ "str | Dict", @@ -218,7 +218,7 @@ class FileSerializable(Serializable): "serialized_output": ["str", "filepath or URL to file"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": {"is_file": False, "data": media_data.BASE64_FILE}, "serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf", @@ -280,9 +280,9 @@ class FileSerializable(Serializable): def serialize( self, - x: str | FileData | None | List[str | FileData | None], + x: str | FileData | None | list[str | FileData | None], load_dir: str | Path = "", - ) -> FileData | None | List[FileData | None]: + ) -> FileData | None | list[FileData | None]: """ Convert from human-friendly version of a file (string filepath) to a seralized representation (base64) @@ -299,11 +299,11 @@ class FileSerializable(Serializable): def deserialize( self, - x: str | FileData | None | List[str | FileData | None], + x: str | FileData | None | list[str | FileData | None], save_dir: Path | str | None = None, root_url: str | None = None, hf_token: str | None = None, - ) -> str | None | List[str | None]: + ) -> str | None | list[str | None]: """ Convert from serialized representation of a file (base64) to a human-friendly version (string filepath). 
Optionally, save the file to the directory specified by `save_dir` @@ -331,7 +331,7 @@ class FileSerializable(Serializable): class VideoSerializable(FileSerializable): - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": [ "str | Dict", @@ -345,7 +345,7 @@ class VideoSerializable(FileSerializable): "serialized_output": ["str", "filepath or URL to file"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": {"is_file": False, "data": media_data.BASE64_VIDEO}, "serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/video_sample.mp4", @@ -353,16 +353,16 @@ class VideoSerializable(FileSerializable): def serialize( self, x: str | None, load_dir: str | Path = "" - ) -> Tuple[FileData | None, None]: + ) -> tuple[FileData | None, None]: return (super().serialize(x, load_dir), None) # type: ignore def deserialize( self, - x: Tuple[FileData | None, FileData | None] | None, + x: tuple[FileData | None, FileData | None] | None, save_dir: Path | str | None = None, root_url: str | None = None, hf_token: str | None = None, - ) -> str | Tuple[str | None, str | None] | None: + ) -> str | tuple[str | None, str | None] | None: """ Convert from serialized representation of a file (base64) to a human-friendly version (string filepath). Optionally, save the file to the directory specified by `save_dir` @@ -378,7 +378,7 @@ class VideoSerializable(FileSerializable): class JSONSerializable(Serializable): - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": ["str | Dict | List", "JSON-serializable object or a string"], "raw_output": ["Dict | List", "dictionary- or list-like object"], @@ -386,7 +386,7 @@ class JSONSerializable(Serializable): "serialized_output": ["str", "filepath to JSON file"], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": {"a": 1, "b": 2}, "serialized": None, @@ -396,7 +396,7 @@ class JSONSerializable(Serializable): self, x: str | None, load_dir: str | Path = "", - ) -> Dict | List | None: + ) -> dict | list | None: """ Convert from a a human-friendly version (string path to json file) to a serialized representation (json string) @@ -410,7 +410,7 @@ class JSONSerializable(Serializable): def deserialize( self, - x: str | Dict | List, + x: str | dict | list, save_dir: str | Path | None = None, root_url: str | None = None, hf_token: str | None = None, @@ -430,7 +430,7 @@ class JSONSerializable(Serializable): class GallerySerializable(Serializable): - def api_info(self) -> Dict[str, List[str]]: + def api_info(self) -> dict[str, list[str]]: return { "raw_input": [ "List[List[str | None]]", @@ -450,7 +450,7 @@ class GallerySerializable(Serializable): ], } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": [media_data.BASE64_IMAGE] * 2, "serialized": [ @@ -461,7 +461,7 @@ class GallerySerializable(Serializable): def serialize( self, x: str | None, load_dir: str | Path = "" - ) -> List[List[str | None]] | None: + ) -> list[list[str | None]] | None: if x is None or x == "": return None files = [] @@ -475,7 +475,7 @@ class GallerySerializable(Serializable): def deserialize( self, - x: List[List[str | None]] | None, + x: list[list[str | None]] | None, save_dir: str = "", root_url: str | None = None, hf_token: str | None = None, @@ -486,7 +486,7 @@ class GallerySerializable(Serializable): 
gallery_path.mkdir(exist_ok=True, parents=True) captions = {} for img_data in x: - if isinstance(img_data, list) or isinstance(img_data, tuple): + if isinstance(img_data, (list, tuple)): img_data, caption = img_data else: caption = None @@ -510,7 +510,7 @@ SERIALIZER_MAPPING["Serializable"] = SimpleSerializable SERIALIZER_MAPPING["File"] = FileSerializable SERIALIZER_MAPPING["UploadButton"] = FileSerializable -COMPONENT_MAPPING: Dict[str, type] = { +COMPONENT_MAPPING: dict[str, type] = { "textbox": StringSerializable, "number": NumberSerializable, "slider": NumberSerializable, diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index b5f298a916..7e4f4f62fe 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -14,7 +14,7 @@ from datetime import datetime from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional import fsspec.asyn import httpx @@ -84,7 +84,7 @@ class Status(Enum): CANCELLED = "CANCELLED" @staticmethod - def ordering(status: "Status") -> int: + def ordering(status: Status) -> int: """Order of messages. Helpful for testing.""" order = [ Status.STARTING, @@ -100,11 +100,11 @@ class Status(Enum): ] return order.index(status) - def __lt__(self, other: "Status"): + def __lt__(self, other: Status): return self.ordering(self) < self.ordering(other) @staticmethod - def msg_to_status(msg: str) -> "Status": + def msg_to_status(msg: str) -> Status: """Map the raw message from the backend to the status code presented to users.""" return { "send_hash": Status.JOINING_QUEUE, @@ -127,7 +127,7 @@ class ProgressUnit: desc: Optional[str] @classmethod - def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]: + def from_ws_msg(cls, data: list[dict]) -> list[ProgressUnit]: return [ cls( index=d.get("index"), @@ -150,7 +150,7 @@ class StatusUpdate: eta: float | None success: bool | None time: datetime | None - progress_data: List[ProgressUnit] | None + progress_data: list[ProgressUnit] | None def create_initial_status_update(): @@ -173,7 +173,7 @@ class JobStatus: """ latest_status: StatusUpdate = field(default_factory=create_initial_status_update) - outputs: List[Any] = field(default_factory=list) + outputs: list[Any] = field(default_factory=list) @dataclass @@ -182,7 +182,7 @@ class Communicator: lock: Lock job: JobStatus - deserialize: Callable[..., Tuple] + deserialize: Callable[..., tuple] reset_url: str should_cancel: bool = False @@ -208,7 +208,7 @@ async def get_pred_from_ws( data: str, hash_data: str, helper: Communicator | None = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: completed = False resp = {} while not completed: @@ -285,9 +285,10 @@ def download_tmp_copy_of_file( suffix=suffix, dir=dir, ) - with requests.get(url_path, headers=headers, stream=True) as r: - with open(file_obj.name, "wb") as f: - shutil.copyfileobj(r.raw, f) + with requests.get(url_path, headers=headers, stream=True) as r, open( + file_obj.name, "wb" + ) as f: + shutil.copyfileobj(r.raw, f) return file_obj @@ -360,7 +361,7 @@ def encode_url_or_file_to_base64(path: str | Path): return encode_file_to_base64(path) -def decode_base64_to_binary(encoding: str) -> Tuple[bytes, str | None]: +def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]: extension = get_extension(encoding) data = encoding.rsplit(",", 1)[-1] return base64.b64decode(data), extension @@ -421,7 +422,7 @@ def 
decode_base64_to_file( return file_obj -def dict_or_str_to_json_file(jsn: str | Dict | List, dir: str | Path | None = None): +def dict_or_str_to_json_file(jsn: str | dict | list, dir: str | Path | None = None): if dir is not None: os.makedirs(dir, exist_ok=True) @@ -435,7 +436,7 @@ def dict_or_str_to_json_file(jsn: str | Dict | List, dir: str | Path | None = No return file_obj -def file_to_json(file_path: str | Path) -> Dict | List: +def file_to_json(file_path: str | Path) -> dict | list: with open(file_path) as f: return json.load(f) diff --git a/client/python/test/requirements.txt b/client/python/test/requirements.txt index 0a0e97bcb9..8a864c0525 100644 --- a/client/python/test/requirements.txt +++ b/client/python/test/requirements.txt @@ -1,6 +1,6 @@ black==22.6.0 pytest-asyncio pytest==7.1.2 -ruff==0.0.263 +ruff==0.0.264 pyright==1.1.305 gradio diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index a24d2e1e7b..c4fa25c48c 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -48,8 +48,8 @@ class TestPredictionsFromSpaces: @pytest.mark.flaky def test_numerical_to_label_space(self): client = Client("gradio-tests/titanic-survival") - output = client.predict("male", 77, 10, api_name="/predict") - assert json.load(open(output))["label"] == "Perishes" + with open(client.predict("male", 77, 10, api_name="/predict")) as f: + assert json.load(f)["label"] == "Perishes" with pytest.raises( ValueError, match="This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`", @@ -258,7 +258,8 @@ class TestPredictionsFromSpaces: f.write("Hello from private space!") output = client.submit(1, "foo", f.name, api_name="/file_upload").result() - assert open(output).read() == "Hello from private space!" + with open(output) as f: + assert f.read() == "Hello from private space!" upload.assert_called_once() with patch.object( @@ -267,24 +268,28 @@ class TestPredictionsFromSpaces: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: f.write("Hello from private space!") - output = client.submit(f.name, api_name="/upload_btn").result() - assert open(output).read() == "Hello from private space!" + with open(client.submit(f.name, api_name="/upload_btn").result()) as f: + assert f.read() == "Hello from private space!" 
upload.assert_called_once() with patch.object( client.endpoints[2], "_upload", wraps=client.endpoints[0]._upload ) as upload: + # `delete=False` is required for Windows compat with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1: with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2: - f1.write("File1") f2.write("File2") - - output = client.submit( - 3, [f1.name, f2.name], "hello", api_name="/upload_multiple" + r1, r2 = client.submit( + 3, + [f1.name, f2.name], + "hello", + api_name="/upload_multiple", ).result() - assert open(output[0]).read() == "File1" - assert open(output[1]).read() == "File2" + with open(r1) as f: + assert f.read() == "File1" + with open(r2) as f: + assert f.read() == "File2" upload.assert_called_once() @pytest.mark.flaky @@ -588,10 +593,8 @@ class TestAPIInfo: def test_serializable_in_mapping(self, calculator_demo): with connect(calculator_demo) as client: assert all( - [ - isinstance(c, SimpleSerializable) - for c in client.endpoints[0].serializers - ] + isinstance(c, SimpleSerializable) + for c in client.endpoints[0].serializers ) @pytest.mark.flaky diff --git a/client/python/test/test_serializing.py b/client/python/test/test_serializing.py index a1cef1accf..5fe804e990 100644 --- a/client/python/test/test_serializing.py +++ b/client/python/test/test_serializing.py @@ -34,8 +34,10 @@ def test_file_serializing(): assert serializing.serialize(output) == output files = serializing.deserialize(output) - assert open(files[0]).read() == "Hello World!" - assert open(files[1]).read() == "Greetings!" + with open(files[0]) as f: + assert f.read() == "Hello World!" + with open(files[1]) as f: + assert f.read() == "Greetings!" finally: os.remove(f1.name) os.remove(f2.name) diff --git a/gradio/__init__.py b/gradio/__init__.py index 5928fce9e4..6b2d02a2c2 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -65,7 +65,7 @@ from gradio.flagging import ( SimpleCSVLogger, ) from gradio.helpers import EventData, Progress, make_waveform, skip, update -from gradio.helpers import create_examples as Examples +from gradio.helpers import create_examples as Examples # noqa: N812 from gradio.interface import Interface, TabbedInterface, close_all from gradio.ipython_ext import load_ipython_extension from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs diff --git a/gradio/blocks.py b/gradio/blocks.py index 335d5f2845..cf74b2c65c 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -12,7 +12,7 @@ import warnings import webbrowser from abc import abstractmethod from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type +from typing import TYPE_CHECKING, Any, Callable, Iterator import anyio import requests @@ -34,7 +34,7 @@ from gradio import ( ) from gradio.context import Context from gradio.deprecation import check_deprecated_parameters -from gradio.exceptions import DuplicateBlockError, InvalidApiName +from gradio.exceptions import DuplicateBlockError, InvalidApiNameError from gradio.helpers import EventData, create_tracker, skip, special_args from gradio.themes import Default as DefaultTheme from gradio.themes import ThemeClass as Theme @@ -56,7 +56,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime). 
from gradio.components import Component -BUILT_IN_THEMES: Dict[str, Theme] = { +BUILT_IN_THEMES: dict[str, Theme] = { t.name: t for t in [ themes.Base(), @@ -74,7 +74,7 @@ class Block: *, render: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, visible: bool = True, root_url: str | None = None, # URL that is prepended to all file paths _skip_init_processing: bool = False, # Used for loading from Spaces @@ -145,15 +145,15 @@ class Block: else self.__class__.__name__.lower() ) - def get_expected_parent(self) -> Type[BlockContext] | None: + def get_expected_parent(self) -> type[BlockContext] | None: return None def set_event_trigger( self, event_name: str, fn: Callable | None, - inputs: Component | List[Component] | Set[Component] | None, - outputs: Component | List[Component] | None, + inputs: Component | list[Component] | set[Component] | None, + outputs: Component | list[Component] | None, preprocess: bool = True, postprocess: bool = True, scroll_to_output: bool = False, @@ -164,12 +164,12 @@ class Block: queue: bool | None = None, batch: bool = False, max_batch_size: int = 4, - cancels: List[int] | None = None, + cancels: list[int] | None = None, every: float | None = None, collects_event_data: bool | None = None, trigger_after: int | None = None, trigger_only_on_success: bool = False, - ) -> Tuple[Dict[str, Any], int]: + ) -> tuple[dict[str, Any], int]: """ Adds an event to the component's dependencies. Parameters: @@ -251,7 +251,7 @@ class Block: api_name_ = utils.append_unique_suffix( api_name, [dep["api_name"] for dep in Context.root_block.dependencies] ) - if not (api_name == api_name_): + if api_name != api_name_: warnings.warn(f"api_name {api_name} already exists, using {api_name_}") api_name = api_name_ @@ -295,11 +295,11 @@ class Block: @staticmethod @abstractmethod - def update(**kwargs) -> Dict: + def update(**kwargs) -> dict: return {} @classmethod - def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict: + def get_specific_update(cls, generic_update: dict[str, Any]) -> dict: generic_update = generic_update.copy() del generic_update["__type__"] specific_update = cls.update(**generic_update) @@ -318,7 +318,7 @@ class BlockContext(Block): visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated). render: If False, this will not be included in the Blocks config file at all. """ - self.children: List[Block] = [] + self.children: list[Block] = [] Block.__init__(self, visible=visible, render=render, **kwargs) def __enter__(self): @@ -368,8 +368,8 @@ class BlockFunction: def __init__( self, fn: Callable | None, - inputs: List[Component], - outputs: List[Component], + inputs: list[Component], + outputs: list[Component], preprocess: bool, postprocess: bool, inputs_as_dict: bool, @@ -399,13 +399,13 @@ class BlockFunction: return str(self) -class class_or_instancemethod(classmethod): +class class_or_instancemethod(classmethod): # noqa: N801 def __get__(self, instance, type_): descr_get = super().__get__ if instance is None else self.__func__.__get__ return descr_get(instance, type_) -def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True): +def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True): """ Converts a dictionary of updates into a format that can be sent to the frontend. E.g. 
{"__type__": "generic_update", "value": "2", "interactive": False} @@ -433,15 +433,15 @@ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = def convert_component_dict_to_list( - outputs_ids: List[int], predictions: Dict -) -> List | Dict: + outputs_ids: list[int], predictions: dict +) -> list | dict: """ Converts a dictionary of component updates into a list of updates in the order of the outputs_ids and including every output component. Leaves other types of dictionaries unchanged. E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}} Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}] """ - keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()] + keys_are_blocks = [isinstance(key, Block) for key in predictions] if all(keys_are_blocks): reordered_predictions = [skip() for _ in outputs_ids] for component, value in predictions.items(): @@ -459,7 +459,7 @@ def convert_component_dict_to_list( return predictions -def get_api_info(config: Dict, serialize: bool = True): +def get_api_info(config: dict, serialize: bool = True): """ Gets the information needed to generate the API docs from a Blocks config. Parameters: @@ -662,8 +662,8 @@ class Blocks(BlockContext): if not self.analytics_enabled: os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True" super().__init__(render=False, **kwargs) - self.blocks: Dict[int, Block] = {} - self.fns: List[BlockFunction] = [] + self.blocks: dict[int, Block] = {} + self.fns: list[BlockFunction] = [] self.dependencies = [] self.mode = mode @@ -674,7 +674,7 @@ class Blocks(BlockContext): self.height = None self.api_open = True - self.is_space = True if os.getenv("SYSTEM") == "spaces" else False + self.is_space = os.getenv("SYSTEM") == "spaces" self.favicon_path = None self.auth = None self.dev_mode = True @@ -713,7 +713,7 @@ class Blocks(BlockContext): def from_config( cls, config: dict, - fns: List[Callable], + fns: list[Callable], root_url: str | None = None, ) -> Blocks: """ @@ -727,7 +727,7 @@ class Blocks(BlockContext): config = copy.deepcopy(config) components_config = config["components"] theme = config.get("theme", "default") - original_mapping: Dict[int, Block] = {} + original_mapping: dict[int, Block] = {} def get_block_instance(id: int) -> Block: for block_config in components_config: @@ -862,7 +862,7 @@ class Blocks(BlockContext): api_name, [dep["api_name"] for dep in Context.root_block.dependencies], ) - if not (api_name == api_name_): + if api_name != api_name_: warnings.warn( f"api_name {api_name} already exists, using {api_name_}" ) @@ -937,7 +937,9 @@ class Blocks(BlockContext): None, ) if inferred_fn_index is None: - raise InvalidApiName(f"Cannot find a function with api_name {api_name}") + raise InvalidApiNameError( + f"Cannot find a function with api_name {api_name}" + ) fn_index = inferred_fn_index if not (self.is_callable(fn_index)): raise ValueError( @@ -970,9 +972,9 @@ class Blocks(BlockContext): async def call_function( self, fn_index: int, - processed_input: List[Any], + processed_input: list[Any], iterator: Iterator[Any] | None = None, - requests: routes.Request | List[routes.Request] | None = None, + requests: routes.Request | list[routes.Request] | None = None, event_id: str | None = None, event_data: EventData | None = None, ): @@ -993,10 +995,7 @@ class Blocks(BlockContext): if block_fn.inputs_as_dict: processed_input = [dict(zip(block_fn.inputs, processed_input))] - if isinstance(requests, list): - request = requests[0] - 
else: - request = requests + request = requests[0] if isinstance(requests, list) else requests processed_input, progress_index, _ = special_args( block_fn.fn, processed_input, request, event_data ) @@ -1054,7 +1053,7 @@ class Blocks(BlockContext): "iterator": iterator, } - def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]: + def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]: dependency = self.dependencies[fn_index] processed_input = [] @@ -1068,7 +1067,7 @@ class Blocks(BlockContext): return processed_input - def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]: + def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]: dependency = self.dependencies[fn_index] predictions = [] @@ -1084,7 +1083,7 @@ class Blocks(BlockContext): return predictions - def validate_inputs(self, fn_index: int, inputs: List[Any]): + def validate_inputs(self, fn_index: int, inputs: list[Any]): block_fn = self.fns[fn_index] dependency = self.dependencies[fn_index] @@ -1106,10 +1105,7 @@ class Blocks(BlockContext): block = self.blocks[input_id] wanted_args.append(str(block)) for inp in inputs: - if isinstance(inp, str): - v = f'"{inp}"' - else: - v = str(inp) + v = f'"{inp}"' if isinstance(inp, str) else str(inp) received_args.append(v) wanted = ", ".join(wanted_args) @@ -1125,7 +1121,7 @@ Received inputs: [{received}]""" ) - def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]): + def preprocess_data(self, fn_index: int, inputs: list[Any], state: dict[int, Any]): block_fn = self.fns[fn_index] dependency = self.dependencies[fn_index] @@ -1146,7 +1142,7 @@ Received inputs: processed_input = inputs return processed_input - def validate_outputs(self, fn_index: int, predictions: Any | List[Any]): + def validate_outputs(self, fn_index: int, predictions: Any | list[Any]): block_fn = self.fns[fn_index] dependency = self.dependencies[fn_index] @@ -1168,10 +1164,7 @@ Received inputs: block = self.blocks[output_id] wanted_args.append(str(block)) for pred in predictions: - if isinstance(pred, str): - v = f'"{pred}"' - else: - v = str(pred) + v = f'"{pred}"' if isinstance(pred, str) else str(pred) received_args.append(v) wanted = ", ".join(wanted_args) @@ -1186,7 +1179,7 @@ Received outputs: ) def postprocess_data( - self, fn_index: int, predictions: List | Dict, state: Dict[int, Any] + self, fn_index: int, predictions: list | dict, state: dict[int, Any] ): block_fn = self.fns[fn_index] dependency = self.dependencies[fn_index] @@ -1241,13 +1234,13 @@ Received outputs: async def process_api( self, fn_index: int, - inputs: List[Any], - state: Dict[int, Any], - request: routes.Request | List[routes.Request] | None = None, - iterators: Dict[int, Any] | None = None, + inputs: list[Any], + state: dict[int, Any], + request: routes.Request | list[routes.Request] | None = None, + iterators: dict[int, Any] | None = None, event_id: str | None = None, event_data: EventData | None = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Processes API calls from the frontend. First preprocesses the data, then runs the relevant function, then postprocesses the output. 
@@ -1342,15 +1335,15 @@ Received outputs: "theme": self.theme.name, } - def getLayout(block): + def get_layout(block): if not isinstance(block, BlockContext): return {"id": block._id} children_layout = [] for child in block.children: - children_layout.append(getLayout(child)) + children_layout.append(get_layout(child)) return {"id": block._id, "children": children_layout} - config["layout"] = getLayout(self) + config["layout"] = get_layout(self) for _id, block in self.blocks.items(): props = block.get_config() if hasattr(block, "get_config") else {} @@ -1393,10 +1386,10 @@ Received outputs: @class_or_instancemethod def load( - self_or_cls, + self_or_cls, # noqa: N805 fn: Callable | None = None, - inputs: List[Component] | None = None, - outputs: List[Component] | None = None, + inputs: list[Component] | None = None, + outputs: list[Component] | None = None, api_name: str | None = None, scroll_to_output: bool = False, show_progress: bool = True, @@ -1413,7 +1406,7 @@ Received outputs: api_key: str | None = None, alias: str | None = None, **kwargs, - ) -> Blocks | Dict[str, Any] | None: + ) -> Blocks | dict[str, Any] | None: """ For reverse compatibility reasons, this is both a class method and an instance method, the two of which, confusingly, do two completely different things. @@ -1571,7 +1564,7 @@ Received outputs: debug: bool = False, enable_queue: bool | None = None, max_threads: int = 40, - auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None, + auth: Callable | tuple[str, str] | list[tuple[str, str]] | None = None, auth_message: str | None = None, prevent_thread_lock: bool = False, show_error: bool = False, @@ -1588,11 +1581,11 @@ Received outputs: ssl_verify: bool = True, quiet: bool = False, show_api: bool = True, - file_directories: List[str] | None = None, - allowed_paths: List[str] | None = None, - blocked_paths: List[str] | None = None, + file_directories: list[str] | None = None, + allowed_paths: list[str] | None = None, + blocked_paths: list[str] | None = None, _frontend: bool = True, - ) -> Tuple[FastAPI, str, str]: + ) -> tuple[FastAPI, str, str]: """ Launches a simple web server that serves the demo. Can also be used to create a public link used by anyone to access the demo from their browser by setting share=True. diff --git a/gradio/components.py b/gradio/components.py index b9eae05ac9..f557ff6e85 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -20,7 +20,7 @@ from copy import deepcopy from enum import Enum from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Type, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, cast import aiofiles import altair as alt @@ -76,8 +76,8 @@ if TYPE_CHECKING: from typing import TypedDict class DataframeData(TypedDict): - headers: List[str] - data: List[List[str | int | bool]] + headers: list[str] + data: list[list[str | int | bool]] set_documentation_group("component") @@ -141,18 +141,14 @@ class Component(Block, Serializable): warnings.warn( "'rounded' styling is no longer supported. To round adjacent components together, place them in a Column(variant='box')." ) - if isinstance(kwargs["rounded"], list) or isinstance( - kwargs["rounded"], tuple - ): + if isinstance(kwargs["rounded"], (list, tuple)): put_deprecated_params_in_box = True kwargs.pop("rounded") if "margin" in kwargs: warnings.warn( "'margin' styling is no longer supported. 
To place adjacent components together without margin, place them in a Column(variant='box')." ) - if isinstance(kwargs["margin"], list) or isinstance( - kwargs["margin"], tuple - ): + if isinstance(kwargs["margin"], (list, tuple)): put_deprecated_params_in_box = True kwargs.pop("margin") if "border" in kwargs: @@ -165,9 +161,12 @@ class Component(Block, Serializable): if len(kwargs): for key in kwargs: warnings.warn(f"Unknown style parameter: {key}") - if put_deprecated_params_in_box and isinstance(self.parent, (Row, Column)): - if self.parent.variant == "default": - self.parent.variant = "compact" + if ( + put_deprecated_params_in_box + and isinstance(self.parent, (Row, Column)) + and self.parent.variant == "default" + ): + self.parent.variant = "compact" return self @@ -186,12 +185,12 @@ class IOComponent(Component): interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, load_fn: Callable | None = None, every: float | None = None, **kwargs, ): - self.temp_files: Set[str] = set() + self.temp_files: set[str] = set() self.DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str( Path(tempfile.gettempdir()) / "gradio" ) @@ -206,7 +205,7 @@ class IOComponent(Component): self.interactive = interactive # load_event is set in the Blocks.attach_load_events method - self.load_event: None | Dict[str, Any] = None + self.load_event: None | dict[str, Any] = None self.load_event_to_attach = None load_fn, initial_value = self.get_load_fn_and_initial_value(value) self.value = ( @@ -298,9 +297,10 @@ class IOComponent(Component): full_temp_file_path = str(utils.abspath(temp_dir / f.name)) if not Path(full_temp_file_path).exists(): - with requests.get(url, stream=True) as r: - with open(full_temp_file_path, "wb") as f: - shutil.copyfileobj(r.raw, f) + with requests.get(url, stream=True) as r, open( + full_temp_file_path, "wb" + ) as f: + shutil.copyfileobj(r.raw, f) self.temp_files.add(full_temp_file_path) return full_temp_file_path @@ -365,7 +365,7 @@ class IOComponent(Component): class FormComponent: - def get_expected_parent(self) -> Type[Form]: + def get_expected_parent(self) -> type[Form]: return Form @@ -404,7 +404,7 @@ class Textbox( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, type: str = "text", **kwargs, ): @@ -525,7 +525,7 @@ class Textbox( self.interpretation_replacement = replacement return self - def tokenize(self, x: str) -> Tuple[List[str], List[str], None]: + def tokenize(self, x: str) -> tuple[list[str], list[str], None]: """ Tokenizes an input string by dividing into "words" delimited by self.interpretation_separator """ @@ -543,8 +543,8 @@ class Textbox( return tokens, leave_one_out_strings, None def get_masked_inputs( - self, tokens: List[str], binary_mask_matrix: List[List[int]] - ) -> List[str]: + self, tokens: list[str], binary_mask_matrix: list[list[int]] + ) -> list[str]: """ Constructs partially-masked sentences for SHAP interpretation """ @@ -555,8 +555,8 @@ class Textbox( return masked_inputs def get_interpretation_scores( - self, x, neighbors, scores: List[float], tokens: List[str], masks=None, **kwargs - ) -> List[Tuple[str, float]]: + self, x, neighbors, scores: list[float], tokens: list[str], masks=None, **kwargs + ) -> list[tuple[str, float]]: """ Returns: Each tuple set represents a set of characters and their 
corresponding interpretation score. @@ -616,7 +616,7 @@ class Number( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, precision: int | None = None, **kwargs, ): @@ -731,7 +731,7 @@ class Number( self.interpretation_delta_type = delta_type return self - def get_interpretation_neighbors(self, x: float | int) -> Tuple[List[float], Dict]: + def get_interpretation_neighbors(self, x: float | int) -> tuple[list[float], dict]: x = self._round_to_precision(x, self.precision) if self.interpretation_delta_type == "percent": delta = 1.0 * self.interpretation_delta * x / 100 @@ -755,8 +755,8 @@ class Number( return negatives + positives, {} def get_interpretation_scores( - self, x: float, neighbors: List[float], scores: List[float | None], **kwargs - ) -> List[Tuple[float, float | None]]: + self, x: float, neighbors: list[float], scores: list[float | None], **kwargs + ) -> list[tuple[float, float | None]]: """ Returns: Each tuple set represents a numeric value near the input and its corresponding interpretation score. @@ -799,7 +799,7 @@ class Slider( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, randomize: bool = False, **kwargs, ): @@ -845,7 +845,7 @@ class Slider( NeighborInterpretable.__init__(self) self.cleared_value = self.value - def api_info(self) -> Dict[str, Tuple[str, str]]: + def api_info(self) -> dict[str, tuple[str, str]]: description = f"numeric value between {self.minimum} and {self.maximum}" return { "raw_input": ("int | float", description), @@ -854,7 +854,7 @@ class Slider( "serialized_output": ("int | float", description), } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": self.minimum, "serialized": self.minimum, @@ -912,7 +912,7 @@ class Slider( """ return self.minimum if y is None else y - def set_interpret_parameters(self, steps: int = 8) -> "Slider": + def set_interpret_parameters(self, steps: int = 8) -> Slider: """ Calculates interpretation scores of numeric values ranging between the minimum and maximum values of the slider. 
Parameters: @@ -921,7 +921,7 @@ class Slider( self.interpretation_steps = steps return self - def get_interpretation_neighbors(self, x) -> Tuple[object, dict]: + def get_interpretation_neighbors(self, x) -> tuple[object, dict]: return ( np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}, @@ -973,7 +973,7 @@ class Checkbox( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -1065,9 +1065,9 @@ class CheckboxGroup( def __init__( self, - choices: List[str] | None = None, + choices: list[str] | None = None, *, - value: List[str] | str | Callable | None = None, + value: list[str] | str | Callable | None = None, type: str = "value", label: str | None = None, info: str | None = None, @@ -1076,7 +1076,7 @@ class CheckboxGroup( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -1129,7 +1129,7 @@ class CheckboxGroup( **IOComponent.get_config(self), } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": self.choices[0] if self.choices else None, "serialized": self.choices[0] if self.choices else None, @@ -1137,11 +1137,11 @@ class CheckboxGroup( @staticmethod def update( - value: List[str] + value: list[str] | str | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, - choices: List[str] | None = None, + choices: list[str] | None = None, label: str | None = None, show_label: bool | None = None, interactive: bool | None = None, @@ -1157,7 +1157,7 @@ class CheckboxGroup( "__type__": "update", } - def preprocess(self, x: List[str]) -> List[str] | List[int]: + def preprocess(self, x: list[str]) -> list[str] | list[int]: """ Parameters: x: list of selected choices @@ -1173,7 +1173,7 @@ class CheckboxGroup( f"Unknown type: {self.type}. Please choose from: 'value', 'index'." ) - def postprocess(self, y: List[str] | str | None) -> List[str]: + def postprocess(self, y: list[str] | str | None) -> list[str]: """ Any postprocessing needed to be performed on function output. 
Parameters: @@ -1205,10 +1205,7 @@ class CheckboxGroup( """ final_scores = [] for choice, score in zip(self.choices, scores): - if choice in x: - score_set = [score, None] - else: - score_set = [None, score] + score_set = [score, None] if choice in x else [None, score] final_scores.append(score_set) return final_scores @@ -1252,7 +1249,7 @@ class Radio( def __init__( self, - choices: List[str] | None = None, + choices: list[str] | None = None, *, value: str | Callable | None = None, type: str = "value", @@ -1263,7 +1260,7 @@ class Radio( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -1316,7 +1313,7 @@ class Radio( **IOComponent.get_config(self), } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": self.choices[0] if self.choices else None, "serialized": self.choices[0] if self.choices else None, @@ -1325,7 +1322,7 @@ class Radio( @staticmethod def update( value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, - choices: List[str] | None = None, + choices: list[str] | None = None, label: str | None = None, show_label: bool | None = None, interactive: bool | None = None, @@ -1366,8 +1363,8 @@ class Radio( return choices, {} def get_interpretation_scores( - self, x, neighbors, scores: List[float | None], **kwargs - ) -> List: + self, x, neighbors, scores: list[float | None], **kwargs + ) -> list: """ Returns: Each value represents the interpretation score corresponding to each choice. @@ -1409,9 +1406,9 @@ class Dropdown( def __init__( self, - choices: List[str] | None = None, + choices: list[str] | None = None, *, - value: str | List[str] | Callable | None = None, + value: str | list[str] | Callable | None = None, type: str = "value", multiselect: bool | None = None, max_choices: int | None = None, @@ -1422,7 +1419,7 @@ class Dropdown( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, allow_custom_value: bool = False, **kwargs, ): @@ -1451,9 +1448,8 @@ class Dropdown( ) self.type = type self.multiselect = multiselect - if multiselect: - if isinstance(value, str): - value = [value] + if multiselect and isinstance(value, str): + value = [value] if not multiselect and max_choices is not None: warnings.warn( "The `max_choices` parameter is ignored when `multiselect` is False." 
@@ -1487,7 +1483,7 @@ class Dropdown( self.cleared_value = self.value or ([] if multiselect else "") - def api_info(self) -> Dict[str, Tuple[str, str]]: + def api_info(self) -> dict[str, tuple[str, str]]: if self.multiselect: type = "List[str]" description = f"List of options from: {self.choices}" @@ -1501,7 +1497,7 @@ class Dropdown( "serialized_output": (type, description), } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: if self.multiselect: return { "raw": [self.choices[0]] if self.choices else [], @@ -1526,7 +1522,7 @@ class Dropdown( @staticmethod def update( value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, - choices: str | List[str] | None = None, + choices: str | list[str] | None = None, label: str | None = None, show_label: bool | None = None, interactive: bool | None = None, @@ -1545,8 +1541,8 @@ class Dropdown( } def preprocess( - self, x: str | List[str] - ) -> str | int | List[str] | List[int] | None: + self, x: str | list[str] + ) -> str | int | list[str] | list[int] | None: """ Parameters: x: selected choice(s) @@ -1580,8 +1576,8 @@ class Dropdown( return choices, {} def get_interpretation_scores( - self, x, neighbors, scores: List[float | None], **kwargs - ) -> List: + self, x, neighbors, scores: list[float | None], **kwargs + ) -> list: """ Returns: Each value represents the interpretation score corresponding to each choice. @@ -1624,7 +1620,7 @@ class Image( self, value: str | _Image.Image | np.ndarray | None = None, *, - shape: Tuple[int, int] | None = None, + shape: tuple[int, int] | None = None, image_mode: str = "RGB", invert_colors: bool = False, source: str = "upload", @@ -1637,7 +1633,7 @@ class Image( visible: bool = True, streaming: bool = False, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, mirror_webcam: bool = True, brush_radius: float | None = None, **kwargs, @@ -1766,8 +1762,8 @@ class Image( ) def preprocess( - self, x: str | Dict[str, str] - ) -> np.ndarray | _Image.Image | str | Dict | None: + self, x: str | dict[str, str] + ) -> np.ndarray | _Image.Image | str | dict | None: """ Parameters: x: base64 url data, or (if tool == "sketch") a dict of image and mask base64 url data @@ -1906,7 +1902,7 @@ class Image( def get_interpretation_scores( self, x, neighbors, scores, masks, tokens=None, **kwargs - ) -> List[List[float]]: + ) -> list[list[float]]: """ Returns: A 2D array representing the interpretation score of each pixel of the image. 
@@ -1977,7 +1973,7 @@ class Video( def __init__( self, - value: str | Tuple[str, str | None] | Callable | None = None, + value: str | tuple[str, str | None] | Callable | None = None, *, format: str | None = None, source: str = "upload", @@ -1987,7 +1983,7 @@ class Video( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, mirror_webcam: bool = True, include_audio: bool | None = None, **kwargs, @@ -2043,7 +2039,7 @@ class Video( @staticmethod def update( value: str - | Tuple[str, str | None] + | tuple[str, str | None] | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, source: str | None = None, @@ -2063,7 +2059,7 @@ class Video( } def preprocess( - self, x: Tuple[FileData, FileData | None] | FileData | None + self, x: tuple[FileData, FileData | None] | FileData | None ) -> str | None: """ Parameters: @@ -2123,8 +2119,8 @@ class Video( return str(file_name) def postprocess( - self, y: str | Tuple[str, str | None] | None - ) -> Tuple[FileData | None, FileData | None] | None: + self, y: str | tuple[str, str | None] | None + ) -> tuple[FileData | None, FileData | None] | None: """ Processes a video to ensure that it is in the correct format before returning it to the front end. @@ -2229,7 +2225,7 @@ class Video( def srt_to_vtt(srt_file_path, vtt_file_path): """Convert an SRT subtitle file to a VTT subtitle file""" - with open(srt_file_path, "r", encoding="utf-8") as srt_file, open( + with open(srt_file_path, encoding="utf-8") as srt_file, open( vtt_file_path, "w", encoding="utf-8" ) as vtt_file: vtt_file.write("WEBVTT\n\n") @@ -2301,7 +2297,7 @@ class Audio( def __init__( self, - value: str | Tuple[int, np.ndarray] | Callable | None = None, + value: str | tuple[int, np.ndarray] | Callable | None = None, *, source: str = "upload", type: str = "numpy", @@ -2312,7 +2308,7 @@ class Audio( visible: bool = True, streaming: bool = False, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -2368,7 +2364,7 @@ class Audio( **IOComponent.get_config(self), } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": {"is_file": False, "data": media_data.BASE64_AUDIO}, "serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav", @@ -2394,8 +2390,8 @@ class Audio( } def preprocess( - self, x: Dict[str, Any] | None - ) -> Tuple[int, np.ndarray] | str | None: + self, x: dict[str, Any] | None + ) -> tuple[int, np.ndarray] | str | None: """ Parameters: x: dictionary with keys "name", "data", "is_file", "crop_min", "crop_max". @@ -2521,7 +2517,7 @@ class Audio( masked_inputs.append(masked_data) return masked_inputs - def postprocess(self, y: Tuple[int, np.ndarray] | str | None) -> str | Dict | None: + def postprocess(self, y: tuple[int, np.ndarray] | str | None) -> str | dict | None: """ Parameters: y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None. 
@@ -2584,10 +2580,10 @@ class File( def __init__( self, - value: str | List[str] | Callable | None = None, + value: str | list[str] | Callable | None = None, *, file_count: str = "single", - file_types: List[str] | None = None, + file_types: list[str] | None = None, type: str = "file", label: str | None = None, every: float | None = None, @@ -2595,7 +2591,7 @@ class File( interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -2682,11 +2678,11 @@ class File( } def preprocess( - self, x: List[Dict[str, Any]] | None + self, x: list[dict[str, Any]] | None ) -> ( bytes | tempfile._TemporaryFileWrapper - | List[bytes | tempfile._TemporaryFileWrapper] + | list[bytes | tempfile._TemporaryFileWrapper] | None ): """ @@ -2741,8 +2737,8 @@ class File( return process_single_file(x) def postprocess( - self, y: str | List[str] | None - ) -> Dict[str, Any] | List[Dict[str, Any]] | None: + self, y: str | list[str] | None + ) -> dict[str, Any] | list[dict[str, Any]] | None: """ Parameters: y: file path @@ -2784,7 +2780,7 @@ class File( ) return self - def as_example(self, input_data: str | List | None) -> str: + def as_example(self, input_data: str | list | None) -> str: if input_data is None: return "" elif isinstance(input_data, list): @@ -2807,12 +2803,12 @@ class Dataframe(Changeable, Selectable, IOComponent, JSONSerializable): def __init__( self, - value: List[List[Any]] | Callable | None = None, + value: list[list[Any]] | Callable | None = None, *, - headers: List[str] | None = None, - row_count: int | Tuple[int, str] = (1, "dynamic"), - col_count: int | Tuple[int, str] | None = None, - datatype: str | List[str] = "str", + headers: list[str] | None = None, + row_count: int | tuple[int, str] = (1, "dynamic"), + col_count: int | tuple[int, str] | None = None, + datatype: str | list[str] = "str", type: str = "pandas", max_rows: int | None = 20, max_cols: int | None = None, @@ -2823,7 +2819,7 @@ class Dataframe(Changeable, Selectable, IOComponent, JSONSerializable): interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, wrap: bool = False, **kwargs, ): @@ -2965,8 +2961,8 @@ class Dataframe(Changeable, Selectable, IOComponent, JSONSerializable): ) def postprocess( - self, y: str | pd.DataFrame | np.ndarray | List[List[str | float]] | Dict - ) -> Dict: + self, y: str | pd.DataFrame | np.ndarray | list[list[str | float]] | dict + ) -> dict: """ Parameters: y: dataframe in given format @@ -3016,7 +3012,7 @@ class Dataframe(Changeable, Selectable, IOComponent, JSONSerializable): raise ValueError("Cannot process value as a Dataframe") @staticmethod - def __process_counts(count, default=3) -> Tuple[int, str]: + def __process_counts(count, default=3) -> tuple[int, str]: if count is None: return (default, "dynamic") if type(count) == int or type(count) == float: @@ -3025,7 +3021,7 @@ class Dataframe(Changeable, Selectable, IOComponent, JSONSerializable): return count @staticmethod - def __validate_headers(headers: List[str] | None, col_count: int): + def __validate_headers(headers: list[str] | None, col_count: int): if headers is not None and len(headers) != col_count: raise ValueError( f"The length of the headers list must be equal to the col_count int.\n" @@ -3034,7 +3030,7 @@ class Dataframe(Changeable, Selectable, IOComponent, JSONSerializable): ) 
@classmethod - def __process_markdown(cls, data: List[List[Any]], datatype: List[str]): + def __process_markdown(cls, data: list[list[Any]], datatype: list[str]): if "markdown" not in datatype: return data @@ -3086,15 +3082,15 @@ class Timeseries(Changeable, IOComponent, JSONSerializable): value: str | Callable | None = None, *, x: str | None = None, - y: str | List[str] | None = None, - colors: List[str] | None = None, + y: str | list[str] | None = None, + colors: list[str] | None = None, label: str | None = None, every: float | None = None, show_label: bool = True, interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -3141,7 +3137,7 @@ class Timeseries(Changeable, IOComponent, JSONSerializable): @staticmethod def update( value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, - colors: List[str] | None = None, + colors: list[str] | None = None, label: str | None = None, show_label: bool | None = None, interactive: bool | None = None, @@ -3157,7 +3153,7 @@ class Timeseries(Changeable, IOComponent, JSONSerializable): "__type__": "update", } - def preprocess(self, x: Dict | None) -> pd.DataFrame | None: + def preprocess(self, x: dict | None) -> pd.DataFrame | None: """ Parameters: x: Dict with keys 'data': 2D array of str, numeric, or bool data, 'headers': list of strings for header names, 'range': optional two element list designating start of end of subrange. @@ -3175,7 +3171,7 @@ class Timeseries(Changeable, IOComponent, JSONSerializable): dataframe = dataframe.loc[dataframe[self.x or 0] <= x["range"][1]] return dataframe - def postprocess(self, y: str | pd.DataFrame | None) -> Dict | None: + def postprocess(self, y: str | pd.DataFrame | None) -> dict | None: """ Parameters: y: csv or dataframe with timeseries data @@ -3266,7 +3262,7 @@ class Button(Clickable, IOComponent, StringSerializable): visible: bool = True, interactive: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -3350,14 +3346,14 @@ class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable): def __init__( self, label: str = "Upload a File", - value: str | List[str] | Callable | None = None, + value: str | list[str] | Callable | None = None, *, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, type: str = "file", file_count: str = "single", - file_types: List[str] | None = None, + file_types: list[str] | None = None, **kwargs, ): """ @@ -3416,11 +3412,11 @@ class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable): } def preprocess( - self, x: List[Dict[str, Any]] | None + self, x: list[dict[str, Any]] | None ) -> ( bytes | tempfile._TemporaryFileWrapper - | List[bytes | tempfile._TemporaryFileWrapper] + | list[bytes | tempfile._TemporaryFileWrapper] | None ): """ @@ -3515,7 +3511,7 @@ class ColorPicker(Changeable, Submittable, Blurrable, IOComponent, StringSeriali interactive: bool | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -3545,7 +3541,7 @@ class ColorPicker(Changeable, Submittable, Blurrable, IOComponent, StringSeriali **kwargs, ) - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> 
dict[str, Any]: return { "raw": "#000000", "serialized": "#000000", @@ -3621,7 +3617,7 @@ class Label(Changeable, Selectable, IOComponent, JSONSerializable): def __init__( self, - value: Dict[str, float] | str | float | Callable | None = None, + value: dict[str, float] | str | float | Callable | None = None, *, num_top_classes: int | None = None, label: str | None = None, @@ -3629,7 +3625,7 @@ class Label(Changeable, Selectable, IOComponent, JSONSerializable): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, color: str | None = None, **kwargs, ): @@ -3674,7 +3670,7 @@ class Label(Changeable, Selectable, IOComponent, JSONSerializable): **IOComponent.get_config(self), } - def postprocess(self, y: Dict[str, float] | str | float | None) -> Dict | None: + def postprocess(self, y: dict[str, float] | str | float | None) -> dict | None: """ Parameters: y: a dictionary mapping labels to confidence value, or just a string/numerical label by itself @@ -3708,7 +3704,7 @@ class Label(Changeable, Selectable, IOComponent, JSONSerializable): @staticmethod def update( - value: Dict[str, float] + value: dict[str, float] | str | float | Literal[_Keywords.NO_VALUE] @@ -3764,9 +3760,9 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable): def __init__( self, - value: List[Tuple[str, str | float | None]] | Dict | Callable | None = None, + value: list[tuple[str, str | float | None]] | dict | Callable | None = None, *, - color_map: Dict[str, str] + color_map: dict[str, str] | None = None, # Parameter moved to HighlightedText.style() show_legend: bool = False, combine_adjacent: bool = False, @@ -3776,7 +3772,7 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -3829,11 +3825,11 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable): @staticmethod def update( - value: List[Tuple[str, str | float | None]] - | Dict + value: list[tuple[str, str | float | None]] + | dict | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, - color_map: Dict[str, str] | None = None, + color_map: dict[str, str] | None = None, show_legend: bool | None = None, label: str | None = None, show_label: bool | None = None, @@ -3851,8 +3847,8 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable): return updated_config def postprocess( - self, y: List[Tuple[str, str | float | None]] | Dict | None - ) -> List[Tuple[str, str | float | None]] | None: + self, y: list[tuple[str, str | float | None]] | dict | None + ) -> list[tuple[str, str | float | None]] | None: """ Parameters: y: List of (word, category) tuples @@ -3910,7 +3906,7 @@ class HighlightedText(Changeable, Selectable, IOComponent, JSONSerializable): def style( self, *, - color_map: Dict[str, str] | None = None, + color_map: dict[str, str] | None = None, container: bool | None = None, **kwargs, ): @@ -3939,9 +3935,9 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable): def __init__( self, - value: Tuple[ + value: tuple[ np.ndarray | _Image.Image | str, - List[Tuple[np.ndarray | Tuple[int, int, int, int], str]], + list[tuple[np.ndarray | tuple[int, int, int, int], str]], ] | None = None, *, @@ -3951,7 +3947,7 @@ class AnnotatedImage(Selectable, IOComponent, 
JSONSerializable): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -3994,9 +3990,9 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable): @staticmethod def update( - value: Tuple[ + value: tuple[ np.ndarray | _Image.Image | str, - List[Tuple[np.ndarray | Tuple[int, int, int, int], str]], + list[tuple[np.ndarray | tuple[int, int, int, int], str]], ] | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, show_legend: bool | None = None, @@ -4016,11 +4012,11 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable): def postprocess( self, - y: Tuple[ + y: tuple[ np.ndarray | _Image.Image | str, - List[Tuple[np.ndarray | Tuple[int, int, int, int], str]], + list[tuple[np.ndarray | tuple[int, int, int, int], str]], ], - ) -> Tuple[dict, List[Tuple[dict, str]]] | None: + ) -> tuple[dict, list[tuple[dict, str]]] | None: """ Parameters: y: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label. @@ -4060,12 +4056,12 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable): mask_array = mask else: x1, y1, x2, y2 = mask - BORDER_WIDTH = 3 + border_width = 3 mask_array[y1:y2, x1:x2] = 0.5 - mask_array[y1:y2, x1 : x1 + BORDER_WIDTH] = 1 - mask_array[y1:y2, x2 - BORDER_WIDTH : x2] = 1 - mask_array[y1 : y1 + BORDER_WIDTH, x1:x2] = 1 - mask_array[y2 - BORDER_WIDTH : y2, x1:x2] = 1 + mask_array[y1:y2, x1 : x1 + border_width] = 1 + mask_array[y1:y2, x2 - border_width : x2] = 1 + mask_array[y1 : y1 + border_width, x1:x2] = 1 + mask_array[y2 - border_width : y2, x1:x2] = 1 if label in color_map: rgb_color = hex_to_rgb(color_map[label]) @@ -4097,7 +4093,7 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable): *, height: int | None = None, width: int | None = None, - color_map: Dict[str, str] | None = None, + color_map: dict[str, str] | None = None, **kwargs, ): """ @@ -4129,14 +4125,14 @@ class JSON(Changeable, IOComponent, JSONSerializable): def __init__( self, - value: str | Dict | List | Callable | None = None, + value: str | dict | list | Callable | None = None, *, label: str | None = None, every: float | None = None, show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -4183,7 +4179,7 @@ class JSON(Changeable, IOComponent, JSONSerializable): } return updated_config - def postprocess(self, y: Dict | List | str | None) -> Dict | List | None: + def postprocess(self, y: dict | list | str | None) -> dict | list | None: """ Parameters: y: either a string filepath to a JSON file, or a Python list or dict that can be converted to JSON @@ -4227,7 +4223,7 @@ class HTML(Changeable, IOComponent, StringSerializable): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -4290,14 +4286,14 @@ class Gallery(IOComponent, GallerySerializable, Selectable): def __init__( self, - value: List[np.ndarray | _Image.Image | str | Tuple] | Callable | None = None, + value: list[np.ndarray | _Image.Image | str | tuple] | Callable | None = None, *, label: str | None = None, every: float | None = None, show_label: bool = True, visible: bool = True, 
elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -4352,10 +4348,10 @@ class Gallery(IOComponent, GallerySerializable, Selectable): def postprocess( self, - y: List[np.ndarray | _Image.Image | str] - | List[Tuple[np.ndarray | _Image.Image | str, str]] + y: list[np.ndarray | _Image.Image | str] + | list[tuple[np.ndarray | _Image.Image | str, str]] | None, - ) -> List[str]: + ) -> list[str]: """ Parameters: y: list of images, or list of (image, caption) tuples @@ -4367,7 +4363,7 @@ class Gallery(IOComponent, GallerySerializable, Selectable): output = [] for img in y: caption = None - if isinstance(img, tuple) or isinstance(img, list): + if isinstance(img, (tuple, list)): img, caption = img if isinstance(img, np.ndarray): file = processing_utils.save_array_to_file(img) @@ -4397,9 +4393,9 @@ class Gallery(IOComponent, GallerySerializable, Selectable): def style( self, *, - grid: int | Tuple | None = None, - columns: int | Tuple | None = None, - rows: int | Tuple | None = None, + grid: int | tuple | None = None, + columns: int | tuple | None = None, + rows: int | tuple | None = None, height: str | None = None, container: bool | None = None, preview: bool | None = None, @@ -4465,17 +4461,17 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): def __init__( self, - value: List[List[str | Tuple[str] | Tuple[str, str] | None]] + value: list[list[str | tuple[str] | tuple[str, str] | None]] | Callable | None = None, - color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style() + color_map: dict[str, str] | None = None, # Parameter moved to Chatbot.style() *, label: str | None = None, every: float | None = None, show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -4521,7 +4517,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): @staticmethod def update( - value: List[List[str | Tuple[str] | Tuple[str, str] | None]] + value: list[list[str | tuple[str] | tuple[str, str] | None]] | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, label: str | None = None, @@ -4538,8 +4534,8 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): return updated_config def _preprocess_chat_messages( - self, chat_message: str | Dict | None - ) -> str | Tuple[str] | Tuple[str, str] | None: + self, chat_message: str | dict | None + ) -> str | tuple[str] | tuple[str, str] | None: if chat_message is None: return None elif isinstance(chat_message, dict): @@ -4552,8 +4548,8 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): def preprocess( self, - y: List[List[str | Dict | None] | Tuple[str | Dict | None, str | Dict | None]], - ) -> List[List[str | Tuple[str] | Tuple[str, str] | None]]: + y: list[list[str | dict | None] | tuple[str | dict | None, str | dict | None]], + ) -> list[list[str | tuple[str] | tuple[str, str] | None]]: if y is None: return y processed_messages = [] @@ -4573,8 +4569,8 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): return processed_messages def _postprocess_chat_messages( - self, chat_message: str | Tuple | List | None - ) -> str | Dict | None: + self, chat_message: str | tuple | list | None + ) -> str | dict | None: if chat_message is None: return None elif isinstance(chat_message, (tuple, list)): @@ -4603,8 +4599,8 @@ class Chatbot(Changeable, 
Selectable, IOComponent, JSONSerializable): def postprocess( self, - y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple], - ) -> List[List[str | Dict | None]]: + y: list[list[str | tuple[str] | tuple[str, str] | None] | tuple], + ) -> list[list[str | dict | None]]: """ Parameters: y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. @@ -4660,13 +4656,13 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable): self, value: str | Callable | None = None, *, - clear_color: List[float] | None = None, + clear_color: list[float] | None = None, label: str | None = None, every: float | None = None, show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -4700,7 +4696,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable): **IOComponent.get_config(self), } - def example_inputs(self) -> Dict[str, Any]: + def example_inputs(self) -> dict[str, Any]: return { "raw": {"is_file": False, "data": media_data.BASE64_MODEL3D}, "serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/Box.gltf", @@ -4722,7 +4718,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable): } return updated_config - def preprocess(self, x: Dict[str, str] | None) -> str | None: + def preprocess(self, x: dict[str, str] | None) -> str | None: """ Parameters: x: JSON object with filename as 'name' property and base64 data as 'data' property @@ -4743,7 +4739,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable): return temp_file_path - def postprocess(self, y: str | None) -> Dict[str, str] | None: + def postprocess(self, y: str | None) -> dict[str, str] | None: """ Parameters: y: path to the model @@ -4793,7 +4789,7 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -4847,7 +4843,7 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable): } return updated_config - def postprocess(self, y) -> Dict[str, str] | None: + def postprocess(self, y) -> dict[str, str] | None: """ Parameters: y: plot data @@ -4868,10 +4864,7 @@ class Plot(Changeable, Clearable, IOComponent, JSONSerializable): out_y = json.dumps(json_item(y)) else: is_altair = "altair" in y.__module__ - if is_altair: - dtype = "altair" - else: - dtype = "plotly" + dtype = "altair" if is_altair else "plotly" out_y = y.to_json() return {"type": dtype, "plot": out_y} @@ -4921,7 +4914,7 @@ class ScatterPlot(Plot): size: str | None = None, shape: str | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -4932,8 +4925,8 @@ class ScatterPlot(Plot): shape_legend_position: str | None = None, height: int | None = None, width: int | None = None, - x_lim: List[int | float] | None = None, - y_lim: List[int | float] 
| None = None, + x_lim: list[int | float] | None = None, + y_lim: list[int | float] | None = None, caption: str | None = None, interactive: bool | None = True, label: str | None = None, @@ -4941,7 +4934,7 @@ class ScatterPlot(Plot): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, ): """ Parameters: @@ -5015,14 +5008,14 @@ class ScatterPlot(Plot): @staticmethod def update( - value: DataFrame | Dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, + value: DataFrame | dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, x: str | None = None, y: str | None = None, color: str | None = None, size: str | None = None, shape: str | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5033,8 +5026,8 @@ class ScatterPlot(Plot): shape_legend_position: str | None = None, height: int | None = None, width: int | None = None, - x_lim: List[int | float] | None = None, - y_lim: List[int | float] | None = None, + x_lim: list[int | float] | None = None, + y_lim: list[int | float] | None = None, interactive: bool | None = None, caption: str | None = None, label: str | None = None, @@ -5129,7 +5122,7 @@ class ScatterPlot(Plot): size: str | None = None, shape: str | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5140,8 +5133,8 @@ class ScatterPlot(Plot): shape_legend_position: str | None = None, height: int | None = None, width: int | None = None, - x_lim: List[int | float] | None = None, - y_lim: List[int | float] | None = None, + x_lim: list[int | float] | None = None, + y_lim: list[int | float] | None = None, interactive: bool | None = True, ): """Helper for creating the scatter plot.""" @@ -5212,7 +5205,7 @@ class ScatterPlot(Plot): return chart - def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: + def postprocess(self, y: pd.DataFrame | dict | None) -> dict[str, str] | None: # if None or update if y is None or isinstance(y, Dict): return y @@ -5266,7 +5259,7 @@ class LinePlot(Plot): stroke_dash: str | None = None, overlay_point: bool | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5275,8 +5268,8 @@ class LinePlot(Plot): stroke_dash_legend_position: str | None = None, height: int | None = None, width: int | None = None, - x_lim: List[int] | None = None, - y_lim: List[int] | None = None, + x_lim: list[int] | None = None, + y_lim: list[int] | None = None, caption: str | None = None, interactive: bool | None = True, label: str | None = None, @@ -5284,7 +5277,7 @@ class LinePlot(Plot): every: float | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, ): """ Parameters: @@ -5354,14 +5347,14 @@ class LinePlot(Plot): @staticmethod def update( - value: pd.DataFrame | Dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, + value: pd.DataFrame | dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, x: str | None = None, y: str | None = None, 
color: str | None = None, stroke_dash: str | None = None, overlay_point: bool | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5370,8 +5363,8 @@ class LinePlot(Plot): stroke_dash_legend_position: str | None = None, height: int | None = None, width: int | None = None, - x_lim: List[int] | None = None, - y_lim: List[int] | None = None, + x_lim: list[int] | None = None, + y_lim: list[int] | None = None, interactive: bool | None = None, caption: str | None = None, label: str | None = None, @@ -5462,7 +5455,7 @@ class LinePlot(Plot): stroke_dash: str | None = None, overlay_point: bool | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5471,8 +5464,8 @@ class LinePlot(Plot): stroke_dash_legend_position: str | None = None, height: int | None = None, width: int | None = None, - x_lim: List[int] | None = None, - y_lim: List[int] | None = None, + x_lim: list[int] | None = None, + y_lim: list[int] | None = None, interactive: bool | None = None, ): """Helper for creating the scatter plot.""" @@ -5552,7 +5545,7 @@ class LinePlot(Plot): return chart - def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: + def postprocess(self, y: pd.DataFrame | dict | None) -> dict[str, str] | None: # if None or update if y is None or isinstance(y, Dict): return y @@ -5604,7 +5597,7 @@ class BarPlot(Plot): vertical: bool = True, group: str | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5612,7 +5605,7 @@ class BarPlot(Plot): color_legend_position: str | None = None, height: int | None = None, width: int | None = None, - y_lim: List[int] | None = None, + y_lim: list[int] | None = None, caption: str | None = None, interactive: bool | None = True, label: str | None = None, @@ -5620,7 +5613,7 @@ class BarPlot(Plot): every: float | None = None, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, ): """ Parameters: @@ -5687,14 +5680,14 @@ class BarPlot(Plot): @staticmethod def update( - value: pd.DataFrame | Dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, + value: pd.DataFrame | dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, x: str | None = None, y: str | None = None, color: str | None = None, vertical: bool = True, group: str | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = None, color_legend_title: str | None = None, @@ -5702,7 +5695,7 @@ class BarPlot(Plot): color_legend_position: str | None = None, height: int | None = None, width: int | None = None, - y_lim: List[int] | None = None, + y_lim: list[int] | None = None, caption: str | None = None, interactive: bool | None = None, label: str | None = None, @@ -5789,7 +5782,7 @@ class BarPlot(Plot): vertical: bool = True, group: str | None = None, title: str | None = None, - tooltip: List[str] | str | None = None, + tooltip: list[str] | str | None = None, x_title: str | None = None, y_title: str | None = 
None, color_legend_title: str | None = None, @@ -5797,7 +5790,7 @@ class BarPlot(Plot): color_legend_position: str | None = None, height: int | None = None, width: int | None = None, - y_lim: List[int] | None = None, + y_lim: list[int] | None = None, interactive: bool | None = True, ): """Helper for creating the bar plot.""" @@ -5871,7 +5864,7 @@ class BarPlot(Plot): return chart - def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None: + def postprocess(self, y: pd.DataFrame | dict | None) -> dict[str, str] | None: # if None or update if y is None or isinstance(y, Dict): return y @@ -5917,7 +5910,7 @@ class Markdown(IOComponent, Changeable, StringSerializable): *, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -6000,7 +5993,7 @@ class Code(Changeable, IOComponent, StringSerializable): def __init__( self, - value: str | Tuple[str] | None = None, + value: str | tuple[str] | None = None, language: str | None = None, *, lines: int = 5, @@ -6009,7 +6002,7 @@ class Code(Changeable, IOComponent, StringSerializable): show_label: bool = True, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -6058,7 +6051,7 @@ class Code(Changeable, IOComponent, StringSerializable): @staticmethod def update( value: str - | Tuple[str] + | tuple[str] | None | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, label: str | None = None, @@ -6099,14 +6092,14 @@ class Dataset(Clickable, Selectable, Component, StringSerializable): self, *, label: str | None = None, - components: List[IOComponent] | List[str], - samples: List[List[Any]] | None = None, - headers: List[str] | None = None, + components: list[IOComponent] | list[str], + samples: list[list[Any]] | None = None, + headers: list[str] | None = None, type: str = "values", samples_per_page: int = 10, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ @@ -6180,7 +6173,7 @@ class Dataset(Clickable, Selectable, Component, StringSerializable): elif self.type == "values": return self.samples[x] - def postprocess(self, samples: List[List[Any]]) -> Dict: + def postprocess(self, samples: list[list[Any]]) -> dict: return { "samples": samples, "__type__": "update", @@ -6210,7 +6203,7 @@ class Interpretation(Component, SimpleSerializable): *, visible: bool = True, elem_id: str | None = None, - elem_classes: List[str] | str | None = None, + elem_classes: list[str] | str | None = None, **kwargs, ): """ diff --git a/gradio/events.py b/gradio/events.py index d52db6fe78..dfd117ae4d 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -4,7 +4,7 @@ of the on-page-load event, which is defined in gr.Blocks().load().""" from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable from gradio_client.documentation import document, set_documentation_group @@ -19,7 +19,7 @@ set_documentation_group("events") def set_cancel_events( - block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]] + block: Block, event_name: str, cancels: None | dict[str, Any] | list[dict[str, Any]] ): if cancels: if not isinstance(cancels, list): @@ -91,8 +91,8 @@ class EventListenerMethod: def __call__( self, fn: 
Callable | None, - inputs: Component | List[Component] | Set[Component] | None = None, - outputs: Component | List[Component] | None = None, + inputs: Component | list[Component] | set[Component] | None = None, + outputs: Component | list[Component] | None = None, api_name: str | None = None, status_tracker: StatusTracker | None = None, scroll_to_output: bool = False, @@ -102,7 +102,7 @@ class EventListenerMethod: max_batch_size: int = 4, preprocess: bool = True, postprocess: bool = True, - cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None, + cancels: dict[str, Any] | list[dict[str, Any]] | None = None, every: float | None = None, _js: str | None = None, ) -> Dependency: @@ -290,7 +290,7 @@ class Selectable(EventListener): class SelectData(EventData): def __init__(self, target: Block | None, data: Any): super().__init__(target, data) - self.index: int | Tuple[int, int] = data["index"] + self.index: int | tuple[int, int] = data["index"] """ The index of the selected item. Is a tuple if the component is two dimensional or selection is a range. """ diff --git a/gradio/exceptions.py b/gradio/exceptions.py index 48ede6c3fc..493fcbf826 100644 --- a/gradio/exceptions.py +++ b/gradio/exceptions.py @@ -15,10 +15,13 @@ class TooManyRequestsError(Exception): pass -class InvalidApiName(ValueError): +class InvalidApiNameError(ValueError): pass +InvalidApiName = InvalidApiNameError # backwards compatibility + + @document() class Error(Exception): """ diff --git a/gradio/external.py b/gradio/external.py index 569c24a6bd..6080d5405a 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -6,7 +6,7 @@ from __future__ import annotations import json import re import warnings -from typing import TYPE_CHECKING, Callable, Dict +from typing import TYPE_CHECKING, Callable import requests from gradio_client import Client @@ -87,7 +87,7 @@ def load_blocks_from_repo( src = tokens[0] name = "/".join(tokens[1:]) - factory_methods: Dict[str, Callable] = { + factory_methods: dict[str, Callable] = { # for each repo type, we have a method that returns the Interface given the model name & optionally an api_key "huggingface": from_model, "models": from_model, @@ -393,7 +393,7 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs data.update({"options": {"wait_for_model": True}}) data = json.dumps(data) response = requests.request("POST", api_url, headers=headers, data=data) - if not (response.status_code == 200): + if response.status_code != 200: errors_json = response.json() errors, warns = "", "" if errors_json.get("error"): @@ -494,7 +494,7 @@ def from_spaces_blocks(space: str, api_key: str | None) -> Blocks: def from_spaces_interface( model_name: str, - config: Dict, + config: dict, alias: str | None, api_key: str | None, iframe_url: str, diff --git a/gradio/flagging.py b/gradio/flagging.py index 89b65f888d..11deab6ab4 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -10,7 +10,7 @@ import warnings from abc import ABC, abstractmethod from distutils.version import StrictVersion from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any import filelock import huggingface_hub @@ -33,7 +33,7 @@ class FlaggingCallback(ABC): """ @abstractmethod - def setup(self, components: List[IOComponent], flagging_dir: str): + def setup(self, components: list[IOComponent], flagging_dir: str): """ This method should be overridden and ensure that everything is set up correctly for flag(). 
This method gets called once at the beginning of the Interface.launch() method. @@ -46,7 +46,7 @@ class FlaggingCallback(ABC): @abstractmethod def flag( self, - flag_data: List[Any], + flag_data: list[Any], flag_option: str = "", username: str | None = None, ) -> int: @@ -81,14 +81,14 @@ class SimpleCSVLogger(FlaggingCallback): def __init__(self): pass - def setup(self, components: List[IOComponent], flagging_dir: str | Path): + def setup(self, components: list[IOComponent], flagging_dir: str | Path): self.components = components self.flagging_dir = flagging_dir os.makedirs(flagging_dir, exist_ok=True) def flag( self, - flag_data: List[Any], + flag_data: list[Any], flag_option: str = "", username: str | None = None, ) -> int: @@ -112,7 +112,7 @@ class SimpleCSVLogger(FlaggingCallback): writer = csv.writer(csvfile) writer.writerow(utils.sanitize_list_for_csv(csv_data)) - with open(log_filepath, "r") as csvfile: + with open(log_filepath) as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 return line_count @@ -136,7 +136,7 @@ class CSVLogger(FlaggingCallback): def setup( self, - components: List[IOComponent], + components: list[IOComponent], flagging_dir: str | Path, ): self.components = components @@ -145,7 +145,7 @@ class CSVLogger(FlaggingCallback): def flag( self, - flag_data: List[Any], + flag_data: list[Any], flag_option: str = "", username: str | None = None, ) -> int: @@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback): writer.writerow(utils.sanitize_list_for_csv(headers)) writer.writerow(utils.sanitize_list_for_csv(csv_data)) - with open(log_filepath, "r", encoding="utf-8") as csvfile: + with open(log_filepath, encoding="utf-8") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 return line_count @@ -235,7 +235,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): self.info_filename = info_filename self.separate_dirs = separate_dirs - def setup(self, components: List[IOComponent], flagging_dir: str): + def setup(self, components: list[IOComponent], flagging_dir: str): """ Params: flagging_dir (str): local directory where the dataset is cloned, @@ -286,7 +286,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): except huggingface_hub.utils.EntryNotFoundError: pass - def flag(self, flag_data: List[Any], flag_option: str = "") -> int: + def flag(self, flag_data: list[Any], flag_option: str = "") -> int: if self.separate_dirs: # JSONL files to support dataset preview on the Hub unique_id = str(uuid.uuid4()) @@ -312,7 +312,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): data_file: Path, components_dir: Path, path_in_repo: str | None, - flag_data: List[Any], + flag_data: list[Any], flag_option: str = "", ) -> int: # Deserialize components (write images/audio to files) @@ -368,7 +368,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): return sample_nb @staticmethod - def _save_as_csv(data_file: Path, headers: List[str], row: List[Any]) -> int: + def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int: """Save data as CSV and return the sample name (row number).""" is_new = not data_file.exists() @@ -386,7 +386,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): return sum(1 for _ in csv.reader(csvfile)) - 1 @staticmethod - def _save_as_jsonl(data_file: Path, headers: List[str], row: List[Any]) -> str: + def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str: """Save data as JSONL and return the sample name (uuid).""" Path.mkdir(data_file.parent, parents=True, exist_ok=True) with 
open(data_file, "w") as f: @@ -394,15 +394,15 @@ class HuggingFaceDatasetSaver(FlaggingCallback): return data_file.parent.name def _deserialize_components( - self, data_dir: Path, flag_data: List[Any], flag_option: str = "" - ) -> Tuple[Dict[Any, Any], List[Any]]: + self, data_dir: Path, flag_data: list[Any], flag_option: str = "" + ) -> tuple[dict[Any, Any], list[Any]]: """Deserialize components and return the corresponding row for the flagged sample. Images/audio are saved to disk as individual files. """ # Components that can have a preview on dataset repos # NOTE: not at root level to avoid circular imports - FILE_PREVIEW_TYPES = {gr.Audio: "Audio", gr.Image: "Image"} + file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"} # Generate the row corresponding to the flagged sample features = {} @@ -418,8 +418,8 @@ class HuggingFaceDatasetSaver(FlaggingCallback): row.append(Path(deserialized).name) # If component is eligible for a preview, add the URL of the file - if isinstance(component, tuple(FILE_PREVIEW_TYPES)): # type: ignore - for _component, _type in FILE_PREVIEW_TYPES.items(): + if isinstance(component, tuple(file_preview_types)): # type: ignore + for _component, _type in file_preview_types.items(): if isinstance(component, _component): features[label + " file"] = {"_type": _type} break diff --git a/gradio/helpers.py b/gradio/helpers.py index 32b8125284..9ac56aef70 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -12,7 +12,7 @@ import tempfile import threading import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, Iterable import matplotlib.pyplot as plt import numpy as np @@ -36,9 +36,9 @@ set_documentation_group("helpers") def create_examples( - examples: List[Any] | List[List[Any]] | str, - inputs: IOComponent | List[IOComponent], - outputs: IOComponent | List[IOComponent] | None = None, + examples: list[Any] | list[list[Any]] | str, + inputs: IOComponent | list[IOComponent], + outputs: IOComponent | list[IOComponent] | None = None, fn: Callable | None = None, cache_examples: bool = False, examples_per_page: int = 10, @@ -85,9 +85,9 @@ class Examples: def __init__( self, - examples: List[Any] | List[List[Any]] | str, - inputs: IOComponent | List[IOComponent], - outputs: IOComponent | List[IOComponent] | None = None, + examples: list[Any] | list[list[Any]] | str, + inputs: IOComponent | list[IOComponent], + outputs: IOComponent | list[IOComponent] | None = None, fn: Callable | None = None, cache_examples: bool = False, examples_per_page: int = 10, @@ -323,7 +323,7 @@ class Examples: Context.root_block.dependencies.remove(dependency) Context.root_block.fns.pop(fn_index) - async def load_from_cache(self, example_id: int) -> List[Any]: + async def load_from_cache(self, example_id: int) -> list[Any]: """Loads a particular cached example for the interface. Parameters: example_id: The id of the example to process (zero-indexed). 
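The `_deserialize_components` hunk above lowercases the function-local `FILE_PREVIEW_TYPES` to satisfy the newly enabled naming rules: uppercase names are reserved for module-level constants, while locals get snake_case. A sketch of the same rename pattern, using an invented helper:

def count_previewable(component_type_names: list[str]) -> int:
    # Local binding is lowercase even though it behaves like a constant;
    # the pep8-naming checks flag ALL_CAPS names assigned inside functions.
    file_preview_types = {"Audio", "Image"}
    return sum(1 for name in component_type_names if name in file_preview_types)


print(count_previewable(["Audio", "Textbox"]))  # 1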
@@ -396,7 +396,7 @@ class Progress(Iterable): self.track_tqdm = track_tqdm self._callback = _callback self._event_id = _event_id - self.iterables: List[TrackedIterable] = [] + self.iterables: list[TrackedIterable] = [] def __len__(self): return self.iterables[-1].length @@ -431,7 +431,7 @@ class Progress(Iterable): def __call__( self, - progress: float | Tuple[int, int | None] | None, + progress: float | tuple[int, int | None] | None, desc: str | None = None, total: int | None = None, unit: str = "steps", @@ -541,7 +541,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm): if self._progress is not None: self._progress.event_id = event_id self._progress.tqdm(iterable, desc, _tqdm=self) - kwargs["file"] = open(os.devnull, "w") + kwargs["file"] = open(os.devnull, "w") # noqa: SIM115 self.__init__orig__(iterable, desc, *args, **kwargs) def iter_tqdm(self): @@ -595,7 +595,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm): def special_args( fn: Callable, - inputs: List[Any] | None = None, + inputs: list[Any] | None = None, request: routes.Request | None = None, event_data: EventData | None = None, ): @@ -632,9 +632,10 @@ def special_args( event_data_index = i if inputs is not None and event_data is not None: inputs.insert(i, param.annotation(event_data.target, event_data._data)) - elif param.default is not param.empty: - if inputs is not None and len(inputs) <= i: - inputs.insert(i, param.default) + elif ( + param.default is not param.empty and inputs is not None and len(inputs) <= i + ): + inputs.insert(i, param.default) if inputs is not None: while len(inputs) < len(positional_args): i = len(inputs) @@ -696,12 +697,12 @@ def skip() -> dict: @document() def make_waveform( - audio: str | Tuple[int, np.ndarray], + audio: str | tuple[int, np.ndarray], *, bg_color: str = "#f3f4f6", bg_image: str | None = None, fg_alpha: float = 0.75, - bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"), + bars_color: str | tuple[str, str] = ("#fbbf24", "#ea580c"), bar_count: int = 50, bar_width: float = 0.6, ): @@ -728,13 +729,13 @@ def make_waveform( duration = round(len(audio[1]) / audio[0], 4) # Helper methods to create waveform - def hex_to_RGB(hex_str): + def hex_to_rgb(hex_str): return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)] def get_color_gradient(c1, c2, n): assert n > 1 - c1_rgb = np.array(hex_to_RGB(c1)) / 255 - c2_rgb = np.array(hex_to_RGB(c2)) / 255 + c1_rgb = np.array(hex_to_rgb(c1)) / 255 + c2_rgb = np.array(hex_to_rgb(c2)) / 255 mix_pcts = [x / (n - 1) for x in range(n)] rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts] return [ @@ -770,7 +771,7 @@ def make_waveform( plt.axis("off") plt.margins(x=0) tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"} + savefig_kwargs: dict[str, Any] = {"bbox_inches": "tight"} if bg_image is not None: savefig_kwargs["transparent"] = True else: diff --git a/gradio/inputs.py b/gradio/inputs.py index d2991db9fb..db83e59260 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -8,7 +8,7 @@ automatically added to a registry, which allows them to be easily referenced in from __future__ import annotations import warnings -from typing import Any, List, Optional, Tuple +from typing import Any, Optional from gradio import components @@ -132,8 +132,8 @@ class CheckboxGroup(components.CheckboxGroup): def __init__( self, - choices: List[str], - default: List[str] | None = None, + choices: list[str], + default: list[str] | None = None, 
type: str = "value", label: Optional[str] = None, optional: bool = False, @@ -168,7 +168,7 @@ class Radio(components.Radio): def __init__( self, - choices: List[str], + choices: list[str], type: str = "value", default: Optional[str] = None, label: Optional[str] = None, @@ -202,7 +202,7 @@ class Dropdown(components.Dropdown): def __init__( self, - choices: List[str], + choices: list[str], type: str = "value", default: Optional[str] = None, label: Optional[str] = None, @@ -236,7 +236,7 @@ class Image(components.Image): def __init__( self, - shape: Tuple[int, int] = None, + shape: tuple[int, int] = None, image_mode: str = "RGB", invert_colors: bool = False, source: str = "upload", @@ -366,12 +366,12 @@ class Dataframe(components.Dataframe): def __init__( self, - headers: Optional[List[str]] = None, + headers: Optional[list[str]] = None, row_count: int = 3, col_count: Optional[int] = 3, - datatype: str | List[str] = "str", - col_width: int | List[int] = None, - default: Optional[List[List[Any]]] = None, + datatype: str | list[str] = "str", + col_width: int | list[int] = None, + default: Optional[list[list[Any]]] = None, type: str = "pandas", label: Optional[str] = None, optional: bool = False, @@ -413,7 +413,7 @@ class Timeseries(components.Timeseries): def __init__( self, x: Optional[str] = None, - y: str | List[str] = None, + y: str | list[str] = None, label: Optional[str] = None, optional: bool = False, ): diff --git a/gradio/interface.py b/gradio/interface.py index c55ac16233..f32afa0cc4 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -8,10 +8,9 @@ from __future__ import annotations import inspect import json import os -import re import warnings import weakref -from typing import TYPE_CHECKING, Any, Callable, List, Tuple +from typing import TYPE_CHECKING, Any, Callable from gradio_client.documentation import document, set_documentation_group @@ -63,7 +62,7 @@ class Interface(Blocks): instances: weakref.WeakSet = weakref.WeakSet() @classmethod - def get_instances(cls) -> List[Interface]: + def get_instances(cls) -> list[Interface]: """ :return: list of all current instances. 
""" @@ -119,9 +118,9 @@ class Interface(Blocks): def __init__( self, fn: Callable, - inputs: str | IOComponent | List[str | IOComponent] | None, - outputs: str | IOComponent | List[str | IOComponent] | None, - examples: List[Any] | List[List[Any]] | str | None = None, + inputs: str | IOComponent | list[str | IOComponent] | None, + outputs: str | IOComponent | list[str | IOComponent] | None, + examples: list[Any] | list[list[Any]] | str | None = None, cache_examples: bool | None = None, examples_per_page: int = 10, live: bool = False, @@ -134,7 +133,7 @@ class Interface(Blocks): theme: Theme | str | None = None, css: str | None = None, allow_flagging: str | None = None, - flagging_options: List[str] | List[Tuple[str, str]] | None = None, + flagging_options: list[str] | list[tuple[str, str]] | None = None, flagging_dir: str = "flagged", flagging_callback: FlaggingCallback = CSVLogger(), analytics_enabled: bool | None = None, @@ -287,17 +286,11 @@ class Interface(Blocks): self.live = live self.title = title - CLEANER = re.compile("<.*?>") - - def clean_html(raw_html): - cleantext = re.sub(CLEANER, "", raw_html) - return cleantext - md = utils.get_markdown_parser() - simple_description = None + simple_description: str | None = None if description is not None: description = md.render(description) - simple_description = clean_html(description) + simple_description = utils.remove_html_tags(description) self.simple_description = simple_description self.description = description if article is not None: @@ -466,19 +459,19 @@ class Interface(Blocks): if self.description: Markdown(self.description) - def render_flag_btns(self) -> List[Button]: + def render_flag_btns(self) -> list[Button]: return [Button(label) for label, _ in self.flagging_options] def render_input_column( self, - ) -> Tuple[ + ) -> tuple[ Button | None, Button | None, Button | None, - List[Button] | None, + list[Button] | None, Column, Column | None, - List[Interpretation] | None, + list[Interpretation] | None, ]: submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None interpret_component_column, interpretation_set = None, None @@ -531,7 +524,7 @@ class Interface(Blocks): def render_output_column( self, submit_btn_in: Button | None, - ) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]: + ) -> tuple[Button | None, Button | None, Button | None, list | None, Button | None]: submit_btn = submit_btn_in interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None @@ -699,7 +692,7 @@ class Interface(Blocks): def attach_interpretation_events( self, interpretation_btn: Button | None, - interpretation_set: List[Interpretation] | None, + interpretation_set: list[Interpretation] | None, input_component_column: Column | None, interpret_component_column: Column | None, ): @@ -711,53 +704,58 @@ class Interface(Blocks): preprocess=False, ) - def attach_flagging_events(self, flag_btns: List[Button] | None, clear_btn: Button): - if flag_btns: - if self.interface_type in [ + def attach_flagging_events(self, flag_btns: list[Button] | None, clear_btn: Button): + if not ( + flag_btns + and self.interface_type + in ( InterfaceTypes.STANDARD, InterfaceTypes.OUTPUT_ONLY, InterfaceTypes.UNIFIED, - ]: - if self.allow_flagging == "auto": - flag_method = FlagMethod( - self.flagging_callback, "", "", visual_feedback=False - ) - flag_btns[0].click( # flag_btns[0] is just the "Submit" button - flag_method, - inputs=self.input_components, - outputs=None, - preprocess=False, - queue=False, - ) - return + ) + 
): + return - if self.interface_type == InterfaceTypes.UNIFIED: - flag_components = self.input_components - else: - flag_components = self.input_components + self.output_components + if self.allow_flagging == "auto": + flag_method = FlagMethod( + self.flagging_callback, "", "", visual_feedback=False + ) + flag_btns[0].click( # flag_btns[0] is just the "Submit" button + flag_method, + inputs=self.input_components, + outputs=None, + preprocess=False, + queue=False, + ) + return - for flag_btn, (label, value) in zip(flag_btns, self.flagging_options): - assert isinstance(value, str) - flag_method = FlagMethod(self.flagging_callback, label, value) - flag_btn.click( - lambda: Button.update(value="Saving...", interactive=False), - None, - flag_btn, - queue=False, - ) - flag_btn.click( - flag_method, - inputs=flag_components, - outputs=flag_btn, - preprocess=False, - queue=False, - ) - clear_btn.click( - flag_method.reset, - None, - flag_btn, - queue=False, - ) + if self.interface_type == InterfaceTypes.UNIFIED: + flag_components = self.input_components + else: + flag_components = self.input_components + self.output_components + + for flag_btn, (label, value) in zip(flag_btns, self.flagging_options): + assert isinstance(value, str) + flag_method = FlagMethod(self.flagging_callback, label, value) + flag_btn.click( + lambda: Button.update(value="Saving...", interactive=False), + None, + flag_btn, + queue=False, + ) + flag_btn.click( + flag_method, + inputs=flag_components, + outputs=flag_btn, + preprocess=False, + queue=False, + ) + clear_btn.click( + flag_method.reset, + None, + flag_btn, + queue=False, + ) def render_examples(self): if self.examples: @@ -798,7 +796,7 @@ class Interface(Blocks): Column.update(visible=True), ] - async def interpret(self, raw_input: List[Any]) -> List[Any]: + async def interpret(self, raw_input: list[Any]) -> list[Any]: return [ {"original": raw_value, "interpretation": interpretation} for interpretation, raw_value in zip( @@ -823,8 +821,8 @@ class TabbedInterface(Blocks): def __init__( self, - interface_list: List[Interface], - tab_names: List[str] | None = None, + interface_list: list[Interface], + tab_names: list[str] | None = None, title: str | None = None, theme: Theme | None = None, analytics_enabled: bool | None = None, diff --git a/gradio/interpretation.py b/gradio/interpretation.py index 77c71237da..295aeabf83 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -5,7 +5,7 @@ from __future__ import annotations import copy import math from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any import numpy as np from gradio_client import utils as client_utils @@ -28,8 +28,8 @@ class Interpretable(ABC): # noqa: B024 pass def get_interpretation_scores( - self, x: Any, neighbors: List[Any] | None, scores: List[float], **kwargs - ) -> List: + self, x: Any, neighbors: list[Any] | None, scores: list[float], **kwargs + ) -> list: """ Arrange the output values from the neighbors into interpretation scores for the interface to render. Parameters: @@ -44,7 +44,7 @@ class Interpretable(ABC): # noqa: B024 class TokenInterpretable(Interpretable, ABC): @abstractmethod - def tokenize(self, x: Any) -> Tuple[List, List, None]: + def tokenize(self, x: Any) -> tuple[list, list, None]: """ Interprets an input data point x by splitting it into a list of tokens (e.g a string into words or an image into super-pixels). 
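The `attach_flagging_events` rewrite above flattens the nested `if` blocks that the newly enabled SIM checks flag into a single inverted guard clause that returns early. A compact sketch of that shape, with invented names:

def notify_reviewers(reviewers: list[str] | None, message: str) -> list[str]:
    # Guard clause: bail out early instead of indenting the whole body
    # under `if reviewers and message:`.
    if not (reviewers and message):
        return []
    return [f"{reviewer}: {message}" for reviewer in reviewers]


print(notify_reviewers(["akx"], "LGTM"))  # ['akx: LGTM']
print(notify_reviewers(None, "ping"))     # []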
@@ -52,13 +52,13 @@ class TokenInterpretable(Interpretable, ABC): return [], [], None @abstractmethod - def get_masked_inputs(self, tokens: List, binary_mask_matrix: List[List]) -> List: + def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list: return [] class NeighborInterpretable(Interpretable, ABC): @abstractmethod - def get_interpretation_neighbors(self, x: Any) -> Tuple[List, Dict]: + def get_interpretation_neighbors(self, x: Any) -> tuple[list, dict]: """ Generates values similar to input to be used to interpret the significance of the input in the final output. Parameters: @@ -70,7 +70,7 @@ class NeighborInterpretable(Interpretable, ABC): return [], {} -async def run_interpret(interface: Interface, raw_input: List): +async def run_interpret(interface: Interface, raw_input: list): """ Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box interpretation for a certain set of UI component types, as well as the custom interpretation case. @@ -265,12 +265,12 @@ def diff(original: Any, perturbed: Any) -> int | float: try: # try computing numerical difference score = float(original) - float(perturbed) except ValueError: # otherwise, look at strict difference in label - score = int(not (original == perturbed)) + score = int(original != perturbed) return score def quantify_difference_in_label( - interface: Interface, original_output: List, perturbed_output: List + interface: Interface, original_output: list, perturbed_output: list ) -> int | float: output_component = interface.output_components[0] post_original_output = output_component.postprocess(original_output[0]) @@ -300,7 +300,7 @@ def quantify_difference_in_label( def get_regression_or_classification_value( - interface: Interface, original_output: List, perturbed_output: List + interface: Interface, original_output: list, perturbed_output: list ) -> int | float: """Used to combine regression/classification for Shap interpretation method.""" output_component = interface.output_components[0] diff --git a/gradio/layouts.py b/gradio/layouts.py index 47cda5bb59..1401e01dba 100644 --- a/gradio/layouts.py +++ b/gradio/layouts.py @@ -1,7 +1,6 @@ from __future__ import annotations import warnings -from typing import Type from gradio_client.documentation import document, set_documentation_group @@ -216,7 +215,7 @@ class Tab(BlockContext, Selectable): **super(BlockContext, self).get_config(), } - def get_expected_parent(self) -> Type[Tabs]: + def get_expected_parent(self) -> type[Tabs]: return Tabs def get_block_name(self): diff --git a/gradio/networking.py b/gradio/networking.py index 7f6af29fe9..b9512e653a 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -9,7 +9,7 @@ import socket import threading import time import warnings -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING import requests import uvicorn @@ -89,7 +89,7 @@ def start_server( ssl_keyfile: str | None = None, ssl_certfile: str | None = None, ssl_keyfile_password: str | None = None, -) -> Tuple[str, int, str, App, Server]: +) -> tuple[str, int, str, App, Server]: """Launches a local server running the provided Interface Parameters: blocks: The Blocks object to run on the server diff --git a/gradio/outputs.py b/gradio/outputs.py index 2995fdcac8..66f170c46f 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -8,7 +8,7 @@ automatically added to a registry, which allows them to be easily referenced in from __future__ import annotations import warnings -from typing 
import Dict, List, Optional +from typing import Optional from gradio import components @@ -109,7 +109,7 @@ class Dataframe(components.Dataframe): def __init__( self, - headers: Optional[List[str]] = None, + headers: Optional[list[str]] = None, max_rows: Optional[int] = 20, max_cols: Optional[int] = None, overflow_row_behaviour: str = "paginate", @@ -145,7 +145,7 @@ class Timeseries(components.Timeseries): """ def __init__( - self, x: str = None, y: str | List[str] = None, label: Optional[str] = None + self, x: str = None, y: str | list[str] = None, label: Optional[str] = None ): """ Parameters: @@ -227,7 +227,7 @@ class HighlightedText(components.HighlightedText): def __init__( self, - color_map: Dict[str, str] = None, + color_map: dict[str, str] = None, label: Optional[str] = None, show_legend: bool = False, ): @@ -281,7 +281,7 @@ class Carousel(components.Carousel): def __init__( self, - components: components.Component | List[components.Component], + components: components.Component | list[components.Component], label: Optional[str] = None, ): """ diff --git a/gradio/pipelines.py b/gradio/pipelines.py index 73bb15e608..144f1f7ecd 100644 --- a/gradio/pipelines.py +++ b/gradio/pipelines.py @@ -3,7 +3,7 @@ please use the `gr.Interface.from_pipeline()` function.""" from __future__ import annotations -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from gradio import components @@ -11,7 +11,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime). from transformers import pipelines -def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict: +def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict: """ Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline. pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index a5772658b6..ff541694b5 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -8,7 +8,6 @@ import tempfile import warnings from io import BytesIO from pathlib import Path -from typing import Dict import numpy as np from ffmpy import FFmpeg, FFprobe, FFRuntimeError @@ -25,7 +24,7 @@ with warnings.catch_warnings(): ######################### -def to_binary(x: str | Dict) -> bytes: +def to_binary(x: str | dict) -> bytes: """Converts a base64 string or dictionary to a binary string that can be sent in a POST.""" if isinstance(x, dict): if x.get("data"): @@ -362,10 +361,7 @@ def _convert(image, dtype, force_copy=False, uniform=False): image = np.asarray(image) dtypeobj_in = image.dtype - if dtype is np.floating: - dtypeobj_out = np.dtype("float64") - else: - dtypeobj_out = np.dtype(dtype) + dtypeobj_out = np.dtype("float64") if dtype is np.floating else np.dtype(dtype) dtype_in = dtypeobj_in.type dtype_out = dtypeobj_out.type kind_in = dtypeobj_in.kind diff --git a/gradio/queueing.py b/gradio/queueing.py index 43ea1c9694..525551b1a7 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -6,7 +6,7 @@ import sys import time from asyncio import TimeoutError as AsyncTimeOutError from collections import deque -from typing import Any, Deque, Dict, List, Tuple +from typing import Any, Deque import fastapi import httpx @@ -44,14 +44,14 @@ class Queue: concurrency_count: int, update_intervals: float, max_size: int | None, - blocks_dependencies: List, + blocks_dependencies: list, ): self.event_queue: Deque[Event] = deque() self.events_pending_reconnection = [] self.stopped 
= False self.max_thread_count = concurrency_count self.update_intervals = update_intervals - self.active_jobs: List[None | List[Event]] = [None] * concurrency_count + self.active_jobs: list[None | list[Event]] = [None] * concurrency_count self.delete_lock = asyncio.Lock() self.server_path = None self.duration_history_total = 0 @@ -96,7 +96,7 @@ class Queue: count += 1 return count - def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]: + def get_events_in_batch(self) -> tuple[list[Event] | None, bool]: if not (self.event_queue): return None, False @@ -158,7 +158,7 @@ class Queue: def set_progress( self, event_id: str, - iterables: List[TrackedIterable] | None, + iterables: list[TrackedIterable] | None, ): if iterables is None: return @@ -167,7 +167,7 @@ class Queue: continue for evt in job: if evt._id == event_id: - progress_data: List[ProgressUnit] = [] + progress_data: list[ProgressUnit] = [] for iterable in iterables: progress_unit = ProgressUnit( index=iterable.index, @@ -303,7 +303,7 @@ class Queue: queue_eta=self.queue_duration, ) - def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]: + def get_request_params(self, websocket: fastapi.WebSocket) -> dict[str, Any]: return { "url": str(websocket.url), "headers": dict(websocket.headers), @@ -312,7 +312,7 @@ class Queue: "client": {"host": websocket.client.host, "port": websocket.client.port}, # type: ignore } - async def call_prediction(self, events: List[Event], batch: bool): + async def call_prediction(self, events: list[Event], batch: bool): data = events[0].data assert data is not None, "No event data" token = events[0].token @@ -340,8 +340,8 @@ class Queue: ) return response - async def process_events(self, events: List[Event], batch: bool) -> None: - awake_events: List[Event] = [] + async def process_events(self, events: list[Event], batch: bool) -> None: + awake_events: list[Event] = [] try: for event in events: client_awake = await self.gather_event_data(event) @@ -438,7 +438,7 @@ class Queue: # to start "from scratch" await self.reset_iterators(event.session_hash, event.fn_index) - async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool: + async def send_message(self, event, data: dict, timeout: float | int = 1) -> bool: try: await asyncio.wait_for( event.websocket.send_json(data=data), timeout=timeout @@ -448,7 +448,7 @@ class Queue: await self.clean_event(event) return False - async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]: + async def get_message(self, event, timeout=5) -> tuple[PredictBody | None, bool]: try: data = await asyncio.wait_for( event.websocket.receive_json(), timeout=timeout diff --git a/gradio/ranged_response.py b/gradio/ranged_response.py index eea65227da..88eb696184 100644 --- a/gradio/ranged_response.py +++ b/gradio/ranged_response.py @@ -5,7 +5,7 @@ from __future__ import annotations import os import re import stat -from typing import Dict, NamedTuple +from typing import NamedTuple from urllib.parse import quote import aiofiles @@ -36,7 +36,7 @@ class OpenRange(NamedTuple): def clamp(self, start: int, end: int) -> ClosedRange: begin = max(self.start, start) - end = min((x for x in (self.end, end) if x)) + end = min(x for x in (self.end, end) if x) begin = min(begin, end) end = max(begin, end) @@ -51,7 +51,7 @@ class RangedFileResponse(Response): self, path: str | os.PathLike, range: OpenRange, - headers: Dict[str, str] | None = None, + headers: dict[str, str] | None = None, media_type: str | None = None, 
filename: str | None = None, stat_result: os.stat_result | None = None, diff --git a/gradio/reload.py b/gradio/reload.py index 72276e1eae..1ae802ff32 100644 --- a/gradio/reload.py +++ b/gradio/reload.py @@ -18,10 +18,7 @@ def run_in_reload_mode(): args = sys.argv[1:] if len(args) == 0: raise ValueError("No file specified.") - if len(args) == 1: - demo_name = "demo" - else: - demo_name = args[1] + demo_name = "demo" if len(args) == 1 else args[1] original_path = args[0] abs_original_path = utils.abspath(original_path) diff --git a/gradio/routes.py b/gradio/routes.py index 0b2961a270..e275d5ae2c 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -309,10 +309,8 @@ class App(FastAPI): ) abs_path = utils.abspath(path_or_url) in_blocklist = any( - ( - utils.is_in_or_equal(abs_path, blocked_path) - for blocked_path in blocks.blocked_paths - ) + utils.is_in_or_equal(abs_path, blocked_path) + for blocked_path in blocks.blocked_paths ) if in_blocklist: raise HTTPException(403, f"File not allowed: {path_or_url}.") @@ -320,10 +318,8 @@ class App(FastAPI): in_app_dir = utils.abspath(app.cwd) in abs_path.parents created_by_app = str(abs_path) in set().union(*blocks.temp_file_sets) in_file_dir = any( - ( - utils.is_in_or_equal(abs_path, allowed_path) - for allowed_path in blocks.allowed_paths - ) + utils.is_in_or_equal(abs_path, allowed_path) + for allowed_path in blocks.allowed_paths ) was_uploaded = utils.abspath(app.uploaded_file_dir) in abs_path.parents @@ -464,14 +460,15 @@ class App(FastAPI): ) else: fn_index_inferred = body.fn_index - if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn( - fn_index_inferred + if ( + not app.get_blocks().api_open + and app.get_blocks().queue_enabled_for_fn(fn_index_inferred) + and f"Bearer {app.queue_token}" != request.headers.get("Authorization") ): - if f"Bearer {app.queue_token}" != request.headers.get("Authorization"): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authorized to skip the queue", - ) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authorized to skip the queue", + ) # If this fn_index cancels jobs, then the only input we need is the # current session hash diff --git a/gradio/templates.py b/gradio/templates.py index dc27f6e118..5e970be1a8 100644 --- a/gradio/templates.py +++ b/gradio/templates.py @@ -1,7 +1,6 @@ from __future__ import annotations -import typing -from typing import Any, Callable, Tuple +from typing import Any, Callable import numpy as np from PIL.Image import Image @@ -55,7 +54,7 @@ class Webcam(components.Image): self, value: str | Image | np.ndarray | None = None, *, - shape: Tuple[int, int] | None = None, + shape: tuple[int, int] | None = None, image_mode: str = "RGB", invert_colors: bool = False, source: str = "webcam", @@ -102,7 +101,7 @@ class Sketchpad(components.Image): self, value: str | Image | np.ndarray | None = None, *, - shape: Tuple[int, int] = (28, 28), + shape: tuple[int, int] = (28, 28), image_mode: str = "L", invert_colors: bool = True, source: str = "canvas", @@ -149,7 +148,7 @@ class Paint(components.Image): self, value: str | Image | np.ndarray | None = None, *, - shape: Tuple[int, int] | None = None, + shape: tuple[int, int] | None = None, image_mode: str = "RGB", invert_colors: bool = False, source: str = "canvas", @@ -196,7 +195,7 @@ class ImageMask(components.Image): self, value: str | Image | np.ndarray | None = None, *, - shape: Tuple[int, int] | None = None, + shape: tuple[int, int] | None = None, image_mode: 
str = "RGB", invert_colors: bool = False, source: str = "upload", @@ -243,7 +242,7 @@ class ImagePaint(components.Image): self, value: str | Image | np.ndarray | None = None, *, - shape: Tuple[int, int] | None = None, + shape: tuple[int, int] | None = None, image_mode: str = "RGB", invert_colors: bool = False, source: str = "upload", @@ -290,7 +289,7 @@ class Pil(components.Image): self, value: str | Image | np.ndarray | None = None, *, - shape: Tuple[int, int] | None = None, + shape: tuple[int, int] | None = None, image_mode: str = "RGB", invert_colors: bool = False, source: str = "upload", @@ -372,7 +371,7 @@ class Microphone(components.Audio): def __init__( self, - value: str | Tuple[int, np.ndarray] | Callable | None = None, + value: str | tuple[int, np.ndarray] | Callable | None = None, *, source: str = "microphone", type: str = "numpy", @@ -407,7 +406,7 @@ class Files(components.File): def __init__( self, - value: str | typing.List[str] | Callable | None = None, + value: str | list[str] | Callable | None = None, *, file_count: str = "multiple", type: str = "file", @@ -440,12 +439,12 @@ class Numpy(components.Dataframe): def __init__( self, - value: typing.List[typing.List[Any]] | Callable | None = None, + value: list[list[Any]] | Callable | None = None, *, - headers: typing.List[str] | None = None, - row_count: int | Tuple[int, str] = (1, "dynamic"), - col_count: int | Tuple[int, str] | None = None, - datatype: str | typing.List[str] = "str", + headers: list[str] | None = None, + row_count: int | tuple[int, str] = (1, "dynamic"), + col_count: int | tuple[int, str] | None = None, + datatype: str | list[str] = "str", type: str = "numpy", max_rows: int | None = 20, max_cols: int | None = None, @@ -487,12 +486,12 @@ class Matrix(components.Dataframe): def __init__( self, - value: typing.List[typing.List[Any]] | Callable | None = None, + value: list[list[Any]] | Callable | None = None, *, - headers: typing.List[str] | None = None, - row_count: int | Tuple[int, str] = (1, "dynamic"), - col_count: int | Tuple[int, str] | None = None, - datatype: str | typing.List[str] = "str", + headers: list[str] | None = None, + row_count: int | tuple[int, str] = (1, "dynamic"), + col_count: int | tuple[int, str] | None = None, + datatype: str | list[str] = "str", type: str = "array", max_rows: int | None = 20, max_cols: int | None = None, @@ -534,12 +533,12 @@ class List(components.Dataframe): def __init__( self, - value: typing.List[typing.List[Any]] | Callable | None = None, + value: list[list[Any]] | Callable | None = None, *, - headers: typing.List[str] | None = None, - row_count: int | Tuple[int, str] = (1, "dynamic"), - col_count: int | Tuple[int, str] = 1, - datatype: str | typing.List[str] = "str", + headers: list[str] | None = None, + row_count: int | tuple[int, str] = (1, "dynamic"), + col_count: int | tuple[int, str] = 1, + datatype: str | list[str] = "str", type: str = "array", max_rows: int | None = 20, max_cols: int | None = None, diff --git a/gradio/themes/base.py b/gradio/themes/base.py index 85dba4f44b..2c33b8079a 100644 --- a/gradio/themes/base.py +++ b/gradio/themes/base.py @@ -5,7 +5,7 @@ import re import tempfile import textwrap from pathlib import Path -from typing import Dict, Iterable +from typing import Iterable import huggingface_hub import requests @@ -108,17 +108,17 @@ class ThemeClass: return schema @classmethod - def load(cls, path: str) -> "ThemeClass": + def load(cls, path: str) -> ThemeClass: """Load a theme from a json file. Parameters: path: The filepath to read. 
""" - theme = json.load(open(path), object_hook=fonts.as_font) - return cls.from_dict(theme) + with open(path) as fp: + return cls.from_dict(json.load(fp, object_hook=fonts.as_font)) @classmethod - def from_dict(cls, theme: Dict[str, Dict[str, str]]) -> "ThemeClass": + def from_dict(cls, theme: dict[str, dict[str, str]]) -> ThemeClass: """Create a theme instance from a dictionary representation. Parameters: @@ -142,8 +142,7 @@ class ThemeClass: Parameters: filename: The path to write the theme too """ - as_dict = self.to_dict() - json.dump(as_dict, open(Path(filename), "w"), cls=fonts.FontEncoder) + Path(filename).write_text(json.dumps(self.to_dict(), cls=fonts.FontEncoder)) @classmethod def from_hub(cls, repo_name: str, hf_token: str | None = None): @@ -248,10 +247,7 @@ class ThemeClass: # If no version, set the version to next patch release if not version: - if space_exists: - version = self._get_next_version(space_info) - else: - version = "0.0.1" + version = self._get_next_version(space_info) if space_exists else "0.0.1" else: _ = semver.Version(version) @@ -279,7 +275,7 @@ class ThemeClass: ) readme_file.write(textwrap.dedent(readme_content)) with tempfile.NamedTemporaryFile(mode="w", delete=False) as app_file: - contents = open(str(Path(__file__).parent / "app.py")).read() + contents = (Path(__file__).parent / "app.py").read_text() contents = re.sub( r"theme=gr.themes.Default\(\)", f"theme='{space_id}'", diff --git a/gradio/themes/builder.py b/gradio/themes/builder.py index 250aaef00f..f3cc870bbf 100644 --- a/gradio/themes/builder.py +++ b/gradio/themes/builder.py @@ -76,7 +76,11 @@ css = """ } """ -with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as demo: +with gr.Blocks( # noqa: SIM117 + theme=gr.themes.Base(), + css=css, + title="Gradio Theme Builder", +) as demo: with gr.Row(): with gr.Column(scale=1, elem_id="controls", min_width=400): with gr.Row(): diff --git a/gradio/themes/utils/semver_match.py b/gradio/themes/utils/semver_match.py index 6174360688..1b12d3c3e3 100644 --- a/gradio/themes/utils/semver_match.py +++ b/gradio/themes/utils/semver_match.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import List import huggingface_hub import semantic_version @@ -17,7 +16,7 @@ class ThemeAsset: self.version = semver.Version(self.filename.split("@")[1].replace(".json", "")) -def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> List[ThemeAsset]: +def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> list[ThemeAsset]: if "gradio-theme" not in getattr(space_info, "tags", []): raise ValueError(f"{space_info.id} is not a valid gradio-theme space!") @@ -29,7 +28,7 @@ def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> List[Theme def get_matching_version( - assets: List[ThemeAsset], expression: str | None + assets: list[ThemeAsset], expression: str | None ) -> ThemeAsset | None: expression = expression or "*" diff --git a/gradio/utils.py b/gradio/utils.py index 8cbf6592d0..9a3a9a3fe9 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -27,11 +27,7 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Generator, - List, - Tuple, - Type, TypeVar, Union, ) @@ -104,10 +100,10 @@ def get_local_ip_address() -> str: return ip_address -def initiated_analytics(data: Dict[str, Any]) -> None: +def initiated_analytics(data: dict[str, Any]) -> None: data.update({"ip_address": get_local_ip_address()}) - def initiated_analytics_thread(data: Dict[str, Any]) 
-> None: + def initiated_analytics_thread(data: dict[str, Any]) -> None: try: requests.post( f"{analytics_url}gradio-initiated-analytics/", data=data, timeout=5 @@ -118,10 +114,10 @@ def initiated_analytics(data: Dict[str, Any]) -> None: threading.Thread(target=initiated_analytics_thread, args=(data,)).start() -def launch_analytics(data: Dict[str, Any]) -> None: +def launch_analytics(data: dict[str, Any]) -> None: data.update({"ip_address": get_local_ip_address()}) - def launch_analytics_thread(data: Dict[str, Any]) -> None: + def launch_analytics_thread(data: dict[str, Any]) -> None: try: requests.post( f"{analytics_url}gradio-launched-analytics/", data=data, timeout=5 @@ -132,7 +128,7 @@ def launch_analytics(data: Dict[str, Any]) -> None: threading.Thread(target=launch_analytics_thread, args=(data,)).start() -def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None: +def launched_telemetry(blocks: gradio.Blocks, data: dict[str, Any]) -> None: blocks_telemetry, inputs_telemetry, outputs_telemetry, targets_telemetry = ( [], [], @@ -180,7 +176,7 @@ def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None: data.update(additional_data) data.update({"ip_address": get_local_ip_address()}) - def launched_telemtry_thread(data: Dict[str, Any]) -> None: + def launched_telemtry_thread(data: dict[str, Any]) -> None: try: requests.post( f"{analytics_url}gradio-launched-telemetry/", data=data, timeout=5 @@ -191,10 +187,10 @@ def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None: threading.Thread(target=launched_telemtry_thread, args=(data,)).start() -def integration_analytics(data: Dict[str, Any]) -> None: +def integration_analytics(data: dict[str, Any]) -> None: data.update({"ip_address": get_local_ip_address()}) - def integration_analytics_thread(data: Dict[str, Any]) -> None: + def integration_analytics_thread(data: dict[str, Any]) -> None: try: requests.post( f"{analytics_url}gradio-integration-analytics/", data=data, timeout=5 @@ -213,7 +209,7 @@ def error_analytics(message: str) -> None: """ data = {"ip_address": get_local_ip_address(), "error": message} - def error_analytics_thread(data: Dict[str, Any]) -> None: + def error_analytics_thread(data: dict[str, Any]) -> None: try: requests.post( f"{analytics_url}gradio-error-analytics/", data=data, timeout=5 @@ -320,7 +316,7 @@ def launch_counter() -> None: pass -def get_default_args(func: Callable) -> List[Any]: +def get_default_args(func: Callable) -> list[Any]: signature = inspect.signature(func) return [ v.default if v.default is not inspect.Parameter.empty else None @@ -329,7 +325,7 @@ def get_default_args(func: Callable) -> List[Any]: def assert_configs_are_equivalent_besides_ids( - config1: Dict, config2: Dict, root_keys: Tuple = ("mode",) + config1: dict, config2: dict, root_keys: tuple = ("mode",) ): """Allows you to test if two different Blocks configs produce the same demo. 
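# Illustrative sketch, not part of the patch: the annotation style that
# Ruff's UP (pyupgrade) rules prefer and that the surrounding gradio/utils.py
# hunks adopt -- PEP 585 built-in generics (`dict[str, Any]`, `list[...]`)
# instead of `typing.Dict` / `typing.List`.  `summarize` is a made-up example
# function, not gradio code.
from __future__ import annotations  # keeps the new syntax usable as annotations on Python 3.7/3.8

from typing import Any


def summarize(records: list[dict[str, Any]], sep: str | None = None) -> str:
    """Join the "name" field of each record with `sep` (defaults to ", ")."""
    sep = sep or ", "
    return sep.join(str(r.get("name", "")) for r in records)


assert summarize([{"name": "a"}, {"name": "b"}]) == "a, b"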
@@ -382,7 +378,7 @@ def assert_configs_are_equivalent_besides_ids( return True -def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]): +def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]): if len(ner_groups) == 0: return [(input_string, None)] @@ -400,7 +396,7 @@ def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]): return output -def delete_none(_dict: Dict, skip_value: bool = False) -> Dict: +def delete_none(_dict: dict, skip_value: bool = False) -> dict: """ Delete keys whose values are None from a dictionary """ @@ -412,14 +408,14 @@ def delete_none(_dict: Dict, skip_value: bool = False) -> Dict: return _dict -def resolve_singleton(_list: List[Any] | Any) -> Any: +def resolve_singleton(_list: list[Any] | Any) -> Any: if len(_list) == 1: return _list[0] else: return _list -def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockContext]: +def component_or_layout_class(cls_name: str) -> type[Component] | type[BlockContext]: """ Returns the component, template, or layout class with the given class name, or raises a ValueError if not found. @@ -536,9 +532,9 @@ class AsyncRequest: method: Method, url: str, *, - validation_model: Type[BaseModel] | None = None, + validation_model: type[BaseModel] | None = None, validation_function: Union[Callable, None] = None, - exception_type: Type[Exception] = Exception, + exception_type: type[Exception] = Exception, raise_for_status: bool = False, client: httpx.AsyncClient | None = None, **kwargs, @@ -565,7 +561,7 @@ class AsyncRequest: self._request = self._create_request(method, url, **kwargs) self.client_ = client or self.client - def __await__(self) -> Generator[None, Any, "AsyncRequest"]: + def __await__(self) -> Generator[None, Any, AsyncRequest]: """ Wrap Request's __await__ magic function to create request calls which are executed in one line. """ @@ -740,7 +736,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number: return value -def sanitize_list_for_csv(values: List[Any]) -> List[Any]: +def sanitize_list_for_csv(values: list[Any]) -> list[Any]: """ Sanitizes a list of values (or a list of list of values) that is being written to a CSV file to prevent CSV injection attacks. 
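# Illustrative sketch, not part of the patch: with `from __future__ import
# annotations` (or on Python 3.10+), class objects can be annotated as
# `type[X]` rather than `typing.Type[X]`, and a method can name its own class
# unquoted -- the pattern the `component_or_layout_class` and `AsyncRequest`
# hunks above switch to.  `Registry` / `register` are made-up names.
from __future__ import annotations


class Registry:
    def __init__(self) -> None:
        self._classes: dict[str, type[Exception]] = {}

    def register(self, exc_cls: type[Exception]) -> Registry:  # unquoted forward reference
        self._classes[exc_cls.__name__] = exc_cls
        return self


reg = Registry().register(ValueError).register(KeyError)
assert "ValueError" in reg._classes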
@@ -756,7 +752,7 @@ def sanitize_list_for_csv(values: List[Any]) -> List[Any]: return sanitized_values -def append_unique_suffix(name: str, list_of_names: List[str]): +def append_unique_suffix(name: str, list_of_names: list[str]): """Appends a numerical suffix to `name` so that it does not appear in `list_of_names`.""" set_of_names: set[str] = set(list_of_names) # for O(1) lookup if name not in set_of_names: @@ -815,8 +811,8 @@ def set_task_name(task, session_hash: str, fn_index: int, batch: bool): def get_cancel_function( - dependencies: List[Dict[str, Any]] -) -> Tuple[Callable, List[int]]: + dependencies: list[dict[str, Any]] +) -> tuple[Callable, list[int]]: fn_to_comp = {} for dep in dependencies: if Context.root_block: @@ -858,7 +854,7 @@ def is_special_typed_parameter(name, parameter_types): return is_request or is_event_data -def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool): +def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool): """ Checks if the input component set matches the function Returns: None if valid, a string error message if mismatch @@ -878,9 +874,8 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool max_args += 1 elif param.kind == param.VAR_POSITIONAL: max_args = infinity - elif param.kind == param.KEYWORD_ONLY: - if not has_default: - return f"Keyword-only args must have default values for function {fn}" + elif param.kind == param.KEYWORD_ONLY and not has_default: + return f"Keyword-only args must have default values for function {fn}" arg_count = 1 if inputs_as_dict else len(inputs) if min_args == max_args and max_args != arg_count: warnings.warn( @@ -918,15 +913,15 @@ def tex2svg(formula, *args): with MatplotlibBackendMananger(): import matplotlib.pyplot as plt - FONTSIZE = 20 - DPI = 300 + fontsize = 20 + dpi = 300 plt.rc("mathtext", fontset="cm") fig = plt.figure(figsize=(0.01, 0.01)) - fig.text(0, 0, rf"${formula}$", fontsize=FONTSIZE) + fig.text(0, 0, rf"${formula}$", fontsize=fontsize) output = BytesIO() fig.savefig( output, - dpi=DPI, + dpi=dpi, transparent=True, format="svg", bbox_inches="tight", @@ -942,7 +937,7 @@ def tex2svg(formula, *args): height_match = re.search(r'height="([\d.]+)pt"', svg_code) if height_match: height = float(height_match.group(1)) - new_height = height / FONTSIZE # conversion from pt to em + new_height = height / fontsize # conversion from pt to em svg_code = re.sub( r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code ) @@ -1016,7 +1011,7 @@ def get_serializer_name(block: Block) -> str | None: def highlight_code(code, name, attrs): try: lexer = get_lexer_by_name(name) - except: + except Exception: lexer = get_lexer_by_name("text") formatter = HtmlFormatter() @@ -1047,3 +1042,10 @@ def get_markdown_parser() -> MarkdownIt: md.add_render_rule("link_open", render_blank_link) return md + + +HTML_TAG_RE = re.compile("<.*?>") + + +def remove_html_tags(raw_html: str | None) -> str: + return re.sub(HTML_TAG_RE, "", raw_html or "") diff --git a/pyproject.toml b/pyproject.toml index 23f18e523d..b79a15b7bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,10 +67,9 @@ extend-select = [ "B", "C", "I", - # Formatting-related UP rules - "UP030", - "UP031", - "UP032", + "N", + "SIM", + "UP", ] ignore = [ "C901", # function is too complex (TODO: un-ignore this) @@ -79,6 +78,9 @@ ignore = [ "B017", # pytest.raises considered evil "B028", # explicit stacklevel for warnings "E501", # from scripts/lint_backend.sh + "SIM105", # contextlib.suppress 
(has a performance cost) + "SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc) + "UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them) ] [tool.ruff.per-file-ignores] @@ -91,3 +93,6 @@ ignore = [ "gradio/__init__.py" = [ "F401", # "Imported but unused" (TODO: it would be better to be explicit and use __all__) ] +"gradio/routes.py" = [ + "UP006", # Pydantic on Python 3.7 requires old-style type annotations (TODO: drop when Python 3.7 is dropped) +] diff --git a/test/requirements-37.txt b/test/requirements-37.txt index 0967672ed3..8f7683d20d 100644 --- a/test/requirements-37.txt +++ b/test/requirements-37.txt @@ -197,7 +197,7 @@ respx==0.19.2 # via -r requirements.in rfc3986[idna2008]==1.5.0 # via httpx -ruff==0.0.263 +ruff==0.0.264 # via -r requirements.in s3transfer==0.6.0 # via boto3 diff --git a/test/requirements.txt b/test/requirements.txt index 2ddc10dcd1..9196e2ae50 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -185,7 +185,7 @@ requests==2.28.1 # transformers respx==0.19.2 # via -r requirements.in -ruff==0.0.263 +ruff==0.0.264 # via -r requirements.in rfc3986[idna2008]==1.5.0 # via httpx diff --git a/test/test_blocks.py b/test/test_blocks.py index dc1b1f21f4..baf76037a5 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -279,8 +279,7 @@ class TestBlocksMethods: return 42 def generator_function(): - for index in range(10): - yield index + yield from range(10) with gr.Blocks() as demo: @@ -670,8 +669,7 @@ class TestCallFunction: @pytest.mark.asyncio async def test_call_generator(self): def generator(x): - for i in range(x): - yield i + yield from range(x) with gr.Blocks() as demo: inp = gr.Number() @@ -1368,11 +1366,10 @@ class TestProgressBar: await ws.send(json.dumps({"data": [0], "fn_index": 0})) if msg["msg"] == "send_hash": await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"})) - if msg["msg"] == "progress": - if msg[ - "progress_data" - ]: # Ignore empty lists which sometimes appear on Windows - progress_updates.append(msg["progress_data"]) + if ( + msg["msg"] == "progress" and msg["progress_data"] + ): # Ignore empty lists which sometimes appear on Windows + progress_updates.append(msg["progress_data"]) if msg["msg"] == "process_completed": completed = True break diff --git a/test/test_components.py b/test/test_components.py index a245ed3ee6..a1b472e482 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1489,7 +1489,8 @@ class TestLabel: label_output = gr.Label() label = label_output.postprocess(y) assert label == {"label": "happy"} - assert json.load(open(label_output.deserialize(label))) == label + with open(label_output.deserialize(label)) as f: + assert json.load(f) == label y = {3: 0.7, 1: 0.2, 0: 0.1} label = label_output.postprocess(y) diff --git a/test/test_external.py b/test/test_external.py index 6bb1352438..9c20ee27ba 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -11,7 +11,7 @@ from gradio_client import media_data import gradio as gr from gradio.context import Context -from gradio.exceptions import InvalidApiName +from gradio.exceptions import InvalidApiNameError from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples """ @@ -190,16 +190,16 @@ class TestLoadInterface: def test_sentiment_model(self): io = gr.load("models/distilbert-base-uncased-finetuned-sst-2-english") try: - output = io("I am happy, I love you") - assert json.load(open(output))["label"] == "POSITIVE" + with open(io("I am 
happy, I love you")) as f: + assert json.load(f)["label"] == "POSITIVE" except TooManyRequestsError: pass def test_image_classification_model(self): io = gr.Blocks.load(name="models/google/vit-base-patch16-224") try: - output = io("gradio/test_data/lion.jpg") - assert json.load(open(output))["label"] == "lion" + with open(io("gradio/test_data/lion.jpg")) as f: + assert json.load(f)["label"] == "lion" except TooManyRequestsError: pass @@ -214,9 +214,9 @@ class TestLoadInterface: def test_numerical_to_label_space(self): io = gr.load("spaces/abidlabs/titanic-survival") try: - output = io("male", 77, 10) - assert json.load(open(output))["label"] == "Perishes" assert io.theme.name == "soft" + with open(io("male", 77, 10)) as f: + assert json.load(f)["label"] == "Perishes" except TooManyRequestsError: pass @@ -472,7 +472,7 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme): def test_raise_value_error_when_api_name_invalid(): - with pytest.raises(InvalidApiName): + with pytest.raises(InvalidApiNameError): demo = gr.Blocks.load(name="spaces/gradio/hello_world") demo("freddy", api_name="route does not exist") diff --git a/test/test_interpretation.py b/test/test_interpretation.py index 03b97cf1d0..95cab4167b 100644 --- a/test/test_interpretation.py +++ b/test/test_interpretation.py @@ -26,7 +26,7 @@ class TestDefault: "interpretation" ] assert interpretation[0][1] > 0 # Checks to see if the first word has >0 score. - assert 0 == interpretation[-1][1] # Checks to see if the last word has 0 score. + assert interpretation[-1][1] == 0 # Checks to see if the last word has 0 score. class TestShapley: @@ -92,9 +92,9 @@ class TestHelperMethods: def test_quantify_difference_with_label(self): iface = Interface(lambda text: len(text), ["textbox"], ["label"]) diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"]) - assert -7 == diff + assert diff == -7 diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"]) - assert -100 == diff + assert diff == -100 def test_quantify_difference_with_confidences(self): iface = Interface(lambda text: len(text), ["textbox"], ["label"]) @@ -104,11 +104,11 @@ class TestHelperMethods: diff = gradio.interpretation.quantify_difference_in_label( iface, [output_1], [output_2] ) - assert 0.3 == pytest.approx(diff) + assert pytest.approx(diff) == 0.3 diff = gradio.interpretation.quantify_difference_in_label( iface, [output_1], [output_3] ) - assert 0.8 == pytest.approx(diff) + assert pytest.approx(diff) == 0.8 def test_get_regression_value(self): iface = Interface(lambda text: text, ["textbox"], ["label"]) @@ -118,19 +118,19 @@ class TestHelperMethods: diff = gradio.interpretation.get_regression_or_classification_value( iface, [output_1], [output_2] ) - assert 0 == diff + assert diff == 0 diff = gradio.interpretation.get_regression_or_classification_value( iface, [output_1], [output_3] ) - assert 0.1 == pytest.approx(diff) + assert pytest.approx(diff) == 0.1 def test_get_classification_value(self): iface = Interface(lambda text: text, ["textbox"], ["label"]) diff = gradio.interpretation.get_regression_or_classification_value( iface, ["cat"], ["test"] ) - assert 1 == diff + assert diff == 1 diff = gradio.interpretation.get_regression_or_classification_value( iface, ["test"], ["test"] ) - assert 0 == diff + assert diff == 0 diff --git a/test/test_mix.py b/test/test_mix.py index c8fe5dda2d..a8bcaeb9d7 100644 --- a/test/test_mix.py +++ b/test/test_mix.py @@ -28,8 +28,8 @@ class TestSeries: io2 = 
gr.load("spaces/abidlabs/image-classifier") series = mix.Series(io1, io2) try: - output = series("gradio/test_data/lion.jpg") - assert json.load(open(output))["label"] == "lion" + with open(series("gradio/test_data/lion.jpg")) as f: + assert json.load(f)["label"] == "lion" except TooManyRequestsError: pass diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index fca2fdaa1b..edc1065cdf 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -187,15 +187,16 @@ class TestVideoProcessing: def test_video_has_playable_codecs_catches_exceptions( self, exception_to_raise, test_file_dir ): - with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise): - with tempfile.NamedTemporaryFile( - suffix="out.avi", delete=False - ) as tmp_not_playable_vid: - shutil.copy( - str(test_file_dir / "bad_video_sample.mp4"), - tmp_not_playable_vid.name, - ) - assert processing_utils.video_is_playable(tmp_not_playable_vid.name) + with patch( + "ffmpy.FFprobe.run", side_effect=exception_to_raise + ), tempfile.NamedTemporaryFile( + suffix="out.avi", delete=False + ) as tmp_not_playable_vid: + shutil.copy( + str(test_file_dir / "bad_video_sample.mp4"), + tmp_not_playable_vid.name, + ) + assert processing_utils.video_is_playable(tmp_not_playable_vid.name) def test_convert_video_to_playable_mp4(self, test_file_dir): with tempfile.NamedTemporaryFile( diff --git a/test/test_routes.py b/test/test_routes.py index daef41cf38..bd1a4e4fea 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -56,9 +56,8 @@ class TestRoutes: assert response.status_code == 200 def test_upload_path(self, test_client): - response = test_client.post( - "/upload", files={"files": open("test/test_files/alphabet.txt", "r")} - ) + with open("test/test_files/alphabet.txt") as f: + response = test_client.post("/upload", files={"files": f}) assert response.status_code == 200 file = response.json()[0] assert "alphabet" in file @@ -72,9 +71,8 @@ class TestRoutes: app, _, _ = io.launch(prevent_thread_lock=True) test_client = TestClient(app) try: - response = test_client.post( - "/upload", files={"files": open("test/test_files/alphabet.txt", "r")} - ) + with open("test/test_files/alphabet.txt") as f: + response = test_client.post("/upload", files={"files": f}) assert response.status_code == 200 file = response.json()[0] assert "alphabet" in file @@ -399,8 +397,7 @@ class TestRoutes: class TestGeneratorRoutes: def test_generator(self): def generator(string): - for char in string: - yield char + yield from string io = Interface(generator, "text", "text") app, _, _ = io.queue().launch(prevent_thread_lock=True) diff --git a/test/test_utils.py b/test/test_utils.py index 0cd909f6ed..316f91f3d9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -350,7 +350,7 @@ class TestRequest: name: str job: str id: str - createdAt: str + createdAt: str # noqa: N815 client_response: AsyncRequest = await AsyncRequest( method=AsyncRequest.Method.POST, @@ -434,7 +434,7 @@ async def test_validate_with_model(respx_mock): name: str job: str id: str - createdAt: str + createdAt: str # noqa: N815 client_response: AsyncRequest = await AsyncRequest( method=AsyncRequest.Method.POST, @@ -469,7 +469,7 @@ async def test_validate_and_fail_with_model(respx_mock): @mock.patch("gradio.utils.AsyncRequest._validate_response_data") @pytest.mark.asyncio async def test_exception_type(validate_response_data, respx_mock): - class ResponseValidationException(Exception): + class ResponseValidationError(Exception): message = "Response 
object is not valid." validate_response_data.side_effect = Exception() @@ -479,9 +479,9 @@ async def test_exception_type(validate_response_data, respx_mock): client_response: AsyncRequest = await AsyncRequest( method=AsyncRequest.Method.GET, url=MOCK_REQUEST_URL, - exception_type=ResponseValidationException, + exception_type=ResponseValidationError, ) - assert isinstance(client_response.exception, ResponseValidationException) + assert isinstance(client_response.exception, ResponseValidationError) @pytest.mark.asyncio @@ -511,9 +511,8 @@ async def test_validate_with_function(respx_mock): @pytest.mark.asyncio async def test_validate_and_fail_with_function(respx_mock): def has_name(response): - if response["name"] is not None: - if response["name"] == "Alex": - return response + if response["name"] is not None and response["name"] == "Alex": + return response raise Exception respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response({"name": "morpheus"}))
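# Illustrative sketch, not part of the patch: the two lint patterns the
# test_utils.py hunks above follow -- the pep8-naming convention (N818) that
# custom exception names end in "Error", and flake8-simplify's SIM102, which
# collapses a nested `if` into a single condition.  `validate_name` and
# `NameValidationError` are made-up names for the example.
class NameValidationError(Exception):
    """Raised when a response payload has no acceptable name."""


def validate_name(response: dict) -> dict:
    # One combined condition instead of two nested `if` statements.
    if response.get("name") is not None and response["name"] == "Alex":
        return response
    raise NameValidationError(response)


assert validate_name({"name": "Alex"}) == {"name": "Alex"}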