From 5310782ed9c6de8058f2db3e5a97b6412b1e450a Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 29 Dec 2022 15:18:19 -0500 Subject: [PATCH] Added typing to even more files (#2900) * started pathlib * blocks.py * more changes * fixes * typing * formatting * typing * renaming files * changelog * script * changelog * lint * routes * renamed * state * formatting * state * type check script * remove strictness * switched to pyright * switched to pyright * fixed flaky tests * fixed test xray * fixed load test * fixed blocks tests * formatting * fixed components test * uncomment tests * fixed interpretation tests * formatting * last tests hopefully * argh lint * component * fixed based on review * refactor * components.py t yping * components.py * formatting * lint script * merge * merge * lint * pathlib * lint * events too * lint script * fixing tests * lint * examples * serializing * more files * formatting * flagging.py * added to lint script * fixed tab * attempt fix * serialize fix * formatting * all demos queue * addressed review comments * formatting --- demo/all_demos/run.py | 2 +- gradio/components.py | 12 ++- gradio/examples.py | 47 ++++++---- gradio/external.py | 10 +- gradio/external_utils.py | 20 ++-- gradio/flagging.py | 172 +++++++++++++++++----------------- gradio/inputs.py | 1 + gradio/layouts.py | 80 ++++++++-------- gradio/outputs.py | 1 + gradio/processing_utils.py | 6 +- gradio/serializing.py | 53 +++++++---- gradio/utils.py | 6 +- scripts/type_check_backend.sh | 2 +- 13 files changed, 223 insertions(+), 189 deletions(-) diff --git a/demo/all_demos/run.py b/demo/all_demos/run.py index 11f326d9fa..9bf641bd6a 100644 --- a/demo/all_demos/run.py +++ b/demo/all_demos/run.py @@ -32,4 +32,4 @@ with gr.Blocks() as mega_demo: with gr.Tab(demo_name): demo.render() -mega_demo.launch() +mega_demo.queue().launch() diff --git a/gradio/components.py b/gradio/components.py index 825051f513..88a9bbbd14 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -2350,9 +2350,11 @@ class File( } def serialize( - self, x: str, load_dir: str = "", encryption_key: bytes | None = None - ) -> Dict: + self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None + ) -> Dict | None: serialized = FileSerializable.serialize(self, x, load_dir, encryption_key) + if serialized is None: + return None serialized["size"] = Path(serialized["name"]).stat().st_size return serialized @@ -3019,9 +3021,11 @@ class UploadButton( return deepcopy(media_data.BASE64_FILE) def serialize( - self, x: str, load_dir: str = "", encryption_key: bytes | None = None - ) -> Dict: + self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None + ) -> Dict | None: serialized = FileSerializable.serialize(self, x, load_dir, encryption_key) + if serialized is None: + return None serialized["size"] = Path(serialized["name"]).stat().st_size return serialized diff --git a/gradio/examples.py b/gradio/examples.py index 103956da23..a40ae25e90 100644 --- a/gradio/examples.py +++ b/gradio/examples.py @@ -8,7 +8,7 @@ import csv import os import warnings from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, List from gradio import utils from gradio.components import Dataset @@ -77,13 +77,13 @@ class Examples: self, examples: List[Any] | List[List[Any]] | str, inputs: IOComponent | List[IOComponent], - outputs: Optional[IOComponent | List[IOComponent]] = None, - fn: Optional[Callable] = None, + outputs: IOComponent | List[IOComponent] | None = None, + fn: Callable | None = None, cache_examples: bool = False, examples_per_page: int = 10, _api_mode: bool = False, - label: str = "Examples", - elem_id: Optional[str] = None, + label: str | None = "Examples", + elem_id: str | None = None, run_on_click: bool = False, preprocess: bool = True, postprocess: bool = True, @@ -115,7 +115,7 @@ class Examples: if not isinstance(inputs, list): inputs = [inputs] - if not isinstance(outputs, list): + if outputs and not isinstance(outputs, list): outputs = [outputs] working_directory = Path().absolute() @@ -131,12 +131,12 @@ class Examples: ): # If there is only one input component, examples can be provided as a regular list instead of a list of lists examples = [[e] for e in examples] elif isinstance(examples, str): - if not os.path.exists(examples): + if not Path(examples).exists(): raise FileNotFoundError( "Could not find examples directory: " + examples ) working_directory = examples - if not os.path.exists(os.path.join(examples, LOG_FILE)): + if not (Path(examples) / LOG_FILE).exists(): if len(inputs) == 1: examples = [[e] for e in os.listdir(examples)] else: @@ -145,7 +145,7 @@ class Examples: + LOG_FILE ) else: - with open(os.path.join(examples, LOG_FILE)) as logs: + with open(Path(examples) / LOG_FILE) as logs: examples = list(csv.reader(logs)) examples = [ examples[i][: len(inputs)] for i in range(1, len(examples)) @@ -221,8 +221,8 @@ class Examples: elem_id=elem_id, ) - self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id)) - self.cached_file = os.path.join(self.cached_folder, "log.csv") + self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id) + self.cached_file = Path(self.cached_folder) / "log.csv" self.cache_examples = cache_examples self.run_on_click = run_on_click @@ -240,19 +240,24 @@ class Examples: return utils.resolve_singleton(processed_example) if Context.root_block: + if self.cache_examples and self.outputs: + targets = self.inputs_with_examples + else: + targets = self.inputs self.dataset.click( load_example, inputs=[self.dataset], - outputs=self.inputs_with_examples - + (self.outputs if self.cache_examples else []), + outputs=targets, # type: ignore postprocess=False, queue=False, ) if self.run_on_click and not self.cache_examples: + if self.fn is None: + raise ValueError("Cannot run_on_click if no function is provided") self.dataset.click( self.fn, - inputs=self.inputs, - outputs=self.outputs, + inputs=self.inputs, # type: ignore + outputs=self.outputs, # type: ignore ) if self.cache_examples: @@ -262,29 +267,30 @@ class Examples: """ Caches all of the examples so that their predictions can be shown immediately. """ - if os.path.exists(self.cached_file): + if Path(self.cached_file).exists(): print( - f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache." + f"Using cache from '{Path(self.cached_folder).resolve()}' directory. If method or examples have changed since last caching, delete this folder to clear cache." ) else: if Context.root_block is None: raise ValueError("Cannot cache examples if not in a Blocks context") - print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'") + print(f"Caching examples at: '{Path(self.cached_file).resolve()}'") cache_logger = CSVLogger() # create a fake dependency to process the examples and get the predictions dependency = Context.root_block.set_event_trigger( event_name="fake_event", fn=self.fn, - inputs=self.inputs_with_examples, - outputs=self.outputs, + inputs=self.inputs_with_examples, # type: ignore + outputs=self.outputs, # type: ignore preprocess=self.preprocess and not self._api_mode, postprocess=self.postprocess and not self._api_mode, batch=self.batch, ) fn_index = Context.root_block.dependencies.index(dependency) + assert self.outputs is not None cache_logger.setup(self.outputs, self.cached_folder) for example_id, _ in enumerate(self.examples): processed_input = self.processed_examples[example_id] @@ -310,6 +316,7 @@ class Examples: examples = list(csv.reader(cache)) example = examples[example_id + 1] # +1 to adjust for header output = [] + assert self.outputs is not None for component, value in zip(self.outputs, example): try: value_as_dict = ast.literal_eval(value) diff --git a/gradio/external.py b/gradio/external.py index fde935f278..4a13656233 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -63,7 +63,7 @@ def load_blocks_from_repo( return blocks -def from_model(model_name: str, api_key: str | None, alias: str, **kwargs): +def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs): model_url = "https://huggingface.co/{}".format(model_name) api_url = "https://api-inference.huggingface.co/models/{}".format(model_name) print("Fetching model from: {}".format(model_url)) @@ -316,7 +316,9 @@ def from_model(model_name: str, api_key: str | None, alias: str, **kwargs): return interface -def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks: +def from_spaces( + space_name: str, api_key: str | None, alias: str | None, **kwargs +) -> Blocks: space_url = "https://huggingface.co/spaces/{}".format(space_name) print("Fetching Space from: {}".format(space_url)) @@ -344,7 +346,7 @@ def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> B r"window.gradio_config = (.*?);[\s]*", r.text ) # some basic regex to extract the config try: - config = json.loads(result.group(1)) + config = json.loads(result.group(1)) # type: ignore except AttributeError: raise ValueError("Could not load the Space: {}".format(space_name)) if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces @@ -416,7 +418,7 @@ def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Bl def from_spaces_interface( model_name: str, config: Dict, - alias: str, + alias: str | None, api_key: str | None, iframe_url: str, **kwargs, diff --git a/gradio/external_utils.py b/gradio/external_utils.py index db496bac33..e00b2f4fdd 100644 --- a/gradio/external_utils.py +++ b/gradio/external_utils.py @@ -3,7 +3,6 @@ import base64 import json import math -import numbers import operator import re import warnings @@ -13,6 +12,7 @@ import requests import websockets import yaml from packaging import version +from websockets.legacy.protocol import WebSocketCommonProtocol from gradio import components, exceptions @@ -30,8 +30,13 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]: yaml_regex = re.search( "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text ) - example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])) - example_data = example_yaml.get("widget", {}).get("structuredData", {}) + if yaml_regex is None: + example_data = {} + else: + example_yaml = next( + yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]) + ) + example_data = example_yaml.get("widget", {}).get("structuredData", {}) if not example_data: raise ValueError( f"No example data found in README.md of {model_name} - Cannot build gradio demo. " @@ -41,7 +46,7 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]: # replace nan with string NaN for inference API for data in example_data.values(): for i, val in enumerate(data): - if isinstance(val, numbers.Number) and math.isnan(val): + if isinstance(val, float) and math.isnan(val): data[i] = "NaN" return example_data @@ -76,7 +81,7 @@ def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str] ################## -def postprocess_label(scores): +def postprocess_label(scores: Dict) -> Dict: sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True) return { "label": sorted_pred[0][0], @@ -117,9 +122,10 @@ def encode_to_base64(r: requests.Response) -> str: async def get_pred_from_ws( - websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str + websocket: WebSocketCommonProtocol, data: str, hash_data: str ) -> Dict[str, Any]: completed = False + resp = {} while not completed: msg = await websocket.recv() resp = json.loads(msg) @@ -135,7 +141,7 @@ async def get_pred_from_ws( def get_ws_fn(ws_url, headers): async def ws_fn(data, hash_data): - async with websockets.connect( + async with websockets.connect( # type: ignore ws_url, open_timeout=10, extra_headers=headers ) as websocket: return await get_pred_from_ws(websocket, data, hash_data) diff --git a/gradio/flagging.py b/gradio/flagging.py index 19753d6df2..2eb00cacbc 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -7,7 +7,8 @@ import json import os import uuid from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, List, Optional +from pathlib import Path +from typing import TYPE_CHECKING, Any, List import gradio as gr from gradio import encryptor, utils @@ -51,9 +52,9 @@ def _get_dataset_features_info(is_new, components): headers.append(component.label + " file") for _component, _type in file_preview_types.items(): if isinstance(component, _component): - infos["flagged"]["features"][component.label + " file"] = { - "_type": _type - } + infos["flagged"]["features"][ + (component.label or "") + " file" + ] = {"_type": _type} break headers.append("flag") @@ -85,9 +86,9 @@ class FlaggingCallback(ABC): def flag( self, flag_data: List[Any], - flag_option: Optional[str] = None, - flag_index: Optional[int] = None, - username: Optional[str] = None, + flag_option: str | None = None, + flag_index: int | None = None, + username: str | None = None, ) -> int: """ This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments. @@ -121,7 +122,7 @@ class SimpleCSVLogger(FlaggingCallback): def __init__(self): pass - def setup(self, components: List[IOComponent], flagging_dir: str): + def setup(self, components: List[IOComponent], flagging_dir: str | Path): self.components = components self.flagging_dir = flagging_dir os.makedirs(flagging_dir, exist_ok=True) @@ -129,17 +130,17 @@ class SimpleCSVLogger(FlaggingCallback): def flag( self, flag_data: List[Any], - flag_option: Optional[str] = None, - flag_index: Optional[int] = None, - username: Optional[str] = None, + flag_option: str | None = None, + flag_index: int | None = None, + username: str | None = None, ) -> int: flagging_dir = self.flagging_dir - log_filepath = os.path.join(flagging_dir, "log.csv") + log_filepath = Path(flagging_dir) / "log.csv" csv_data = [] for component, sample in zip(self.components, flag_data): - save_dir = os.path.join( - flagging_dir, utils.strip_invalid_filename_characters(component.label) + save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters( + component.label or "" ) csv_data.append( component.deserialize( @@ -178,8 +179,8 @@ class CSVLogger(FlaggingCallback): def setup( self, components: List[IOComponent], - flagging_dir: str, - encryption_key: Optional[str] = None, + flagging_dir: str | Path, + encryption_key: bytes | None = None, ): self.components = components self.flagging_dir = flagging_dir @@ -189,54 +190,49 @@ class CSVLogger(FlaggingCallback): def flag( self, flag_data: List[Any], - flag_option: Optional[str] = None, - flag_index: Optional[int] = None, - username: Optional[str] = None, + flag_option: str | None = None, + flag_index: int | None = None, + username: str | None = None, ) -> int: flagging_dir = self.flagging_dir - log_filepath = os.path.join(flagging_dir, "log.csv") - is_new = not os.path.exists(log_filepath) + log_filepath = Path(flagging_dir) / "log.csv" + is_new = not Path(log_filepath).exists() + headers = [ + component.label or f"component {idx}" + for idx, component in enumerate(self.components) + ] + [ + "flag", + "username", + "timestamp", + ] - if flag_index is None: - csv_data = [] - for idx, (component, sample) in enumerate(zip(self.components, flag_data)): - save_dir = os.path.join( - flagging_dir, - utils.strip_invalid_filename_characters( - component.label or f"component {idx}" - ), - ) - if utils.is_update(sample): - csv_data.append(str(sample)) - else: - csv_data.append( - component.deserialize( - sample, - save_dir=save_dir, - encryption_key=self.encryption_key, - ) - if sample is not None - else "" + csv_data = [] + for idx, (component, sample) in enumerate(zip(self.components, flag_data)): + save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters( + component.label or f"component {idx}" + ) + if utils.is_update(sample): + csv_data.append(str(sample)) + else: + csv_data.append( + component.deserialize( + sample, + save_dir=save_dir, + encryption_key=self.encryption_key, ) - csv_data.append(flag_option if flag_option is not None else "") - csv_data.append(username if username is not None else "") - csv_data.append(str(datetime.datetime.now())) - if is_new: - headers = [ - component.label or f"component {idx}" - for idx, component in enumerate(self.components) - ] + [ - "flag", - "username", - "timestamp", - ] + if sample is not None + else "" + ) + csv_data.append(flag_option if flag_option is not None else "") + csv_data.append(username if username is not None else "") + csv_data.append(str(datetime.datetime.now())) - def replace_flag_at_index(file_content): - file_content = io.StringIO(file_content) - content = list(csv.reader(file_content)) + def replace_flag_at_index(file_content: str, flag_index: int): + file_content_ = io.StringIO(file_content) + content = list(csv.reader(file_content_)) header = content[0] flag_col_index = header.index("flag") - content[flag_index][flag_col_index] = flag_option + content[flag_index][flag_col_index] = flag_option # type: ignore output = io.StringIO() writer = csv.writer(output) writer.writerows(utils.sanitize_list_for_csv(content)) @@ -252,7 +248,7 @@ class CSVLogger(FlaggingCallback): ) file_content = decrypted_csv.decode() if flag_index is not None: - file_content = replace_flag_at_index(file_content) + file_content = replace_flag_at_index(file_content, flag_index) output.write(file_content) writer = csv.writer(output) if flag_index is None: @@ -273,7 +269,7 @@ class CSVLogger(FlaggingCallback): else: with open(log_filepath, encoding="utf-8") as csvfile: file_content = csvfile.read() - file_content = replace_flag_at_index(file_content) + file_content = replace_flag_at_index(file_content, flag_index) with open( log_filepath, "w", newline="", encoding="utf-8" ) as csvfile: # newline parameter needed for Windows @@ -302,7 +298,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): self, hf_token: str, dataset_name: str, - organization: Optional[str] = None, + organization: str | None = None, private: bool = False, ): """ @@ -340,28 +336,28 @@ class HuggingFaceDatasetSaver(FlaggingCallback): self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10" self.components = components self.flagging_dir = flagging_dir - self.dataset_dir = os.path.join(flagging_dir, self.dataset_name) + self.dataset_dir = Path(flagging_dir) / self.dataset_name self.repo = huggingface_hub.Repository( - local_dir=self.dataset_dir, + local_dir=str(self.dataset_dir), clone_from=path_to_dataset_repo, use_auth_token=self.hf_token, ) self.repo.git_pull(lfs=True) # Should filename be user-specified? - self.log_file = os.path.join(self.dataset_dir, "data.csv") - self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json") + self.log_file = Path(self.dataset_dir) / "data.csv" + self.infos_file = Path(self.dataset_dir) / "dataset_infos.json" def flag( self, flag_data: List[Any], - flag_option: Optional[str] = None, - flag_index: Optional[int] = None, - username: Optional[str] = None, + flag_option: str | None = None, + flag_index: int | None = None, + username: str | None = None, ) -> int: self.repo.git_pull(lfs=True) - is_new = not os.path.exists(self.log_file) + is_new = not Path(self.log_file).exists() with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile: writer = csv.writer(csvfile) @@ -378,10 +374,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback): # Generate the row corresponding to the flagged sample csv_data = [] for component, sample in zip(self.components, flag_data): - save_dir = os.path.join( - self.dataset_dir, - utils.strip_invalid_filename_characters(component.label), - ) + save_dir = Path( + self.dataset_dir + ) / utils.strip_invalid_filename_characters(component.label or "") filepath = component.deserialize(sample, save_dir, None) csv_data.append(filepath) if isinstance(component, tuple(file_preview_types)): @@ -416,7 +411,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): self, hf_foken: str, dataset_name: str, - organization: Optional[str] = None, + organization: str | None = None, private: bool = False, verbose: bool = True, ): @@ -463,34 +458,34 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10" self.components = components self.flagging_dir = flagging_dir - self.dataset_dir = os.path.join(flagging_dir, self.dataset_name) + self.dataset_dir = Path(flagging_dir) / self.dataset_name self.repo = huggingface_hub.Repository( - local_dir=self.dataset_dir, + local_dir=str(self.dataset_dir), clone_from=path_to_dataset_repo, use_auth_token=self.hf_foken, ) self.repo.git_pull(lfs=True) - self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json") + self.infos_file = Path(self.dataset_dir) / "dataset_infos.json" def flag( self, flag_data: List[Any], - flag_option: Optional[str] = None, - flag_index: Optional[int] = None, - username: Optional[str] = None, - ) -> int: + flag_option: str | None = None, + flag_index: int | None = None, + username: str | None = None, + ) -> str: self.repo.git_pull(lfs=True) # Generate unique folder for the flagged sample unique_name = self.get_unique_name() # unique name for folder - folder_name = os.path.join( - self.dataset_dir, unique_name + folder_name = ( + Path(self.dataset_dir) / unique_name ) # unique folder for specific example os.makedirs(folder_name) # Now uses the existence of `dataset_infos.json` to determine if new - is_new = not os.path.exists(self.infos_file) + is_new = not Path(self.infos_file).exists() # File previews for certain input and output types infos, file_preview_types, _ = _get_dataset_features_info( @@ -505,9 +500,10 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): headers.append(component.label) try: - filepath = component.save_flagged( - folder_name, component.label, sample, None + save_dir = Path(folder_name) / utils.strip_invalid_filename_characters( + component.label or "" ) + filepath = component.deserialize(sample, save_dir, None) except Exception: # Could not parse 'sample' (mostly) because it was None and `component.save_flagged` # does not handle None cases. @@ -515,7 +511,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): filepath = None if isinstance(component, tuple(file_preview_types)): - headers.append(component.label + " file") + headers.append(component.label or "" + " file") csv_data.append( "{}/resolve/main/{}/{}".format( @@ -533,7 +529,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): metadata_dict = { header: _csv_data for header, _csv_data in zip(headers, csv_data) } - self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl")) + self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl") if is_new: json.dump(infos, open(self.infos_file, "w")) @@ -545,7 +541,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): id = uuid.uuid4() return str(id) - def dump_json(self, thing: dict, file_path: str) -> None: + def dump_json(self, thing: dict, file_path: str | Path) -> None: with open(file_path, "w+", encoding="utf8") as f: json.dump(thing, f) diff --git a/gradio/inputs.py b/gradio/inputs.py index 184c965593..ae7c6c25db 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -1,3 +1,4 @@ +# type: ignore """ This module defines various classes that can serve as the `input` to an interface. Each class must inherit from `InputComponent`, and each class must define a path to its template. All of the subclasses of `InputComponent` are diff --git a/gradio/layouts.py b/gradio/layouts.py index 3861d1db2c..cab5aabffe 100644 --- a/gradio/layouts.py +++ b/gradio/layouts.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Callable, List, Optional, Type +from typing import TYPE_CHECKING, Callable, List, Type from gradio.blocks import BlockContext from gradio.documentation import document, set_documentation_group @@ -30,7 +30,7 @@ class Row(BlockContext): *, variant: str = "default", visible: bool = True, - elem_id: Optional[str] = None, + elem_id: str | None = None, **kwargs, ): """ @@ -49,7 +49,7 @@ class Row(BlockContext): @staticmethod def update( - visible: Optional[bool] = None, + visible: bool | None = None, ): return { "visible": visible, @@ -59,8 +59,8 @@ class Row(BlockContext): def style( self, *, - equal_height: Optional[bool] = None, - mobile_collapse: Optional[bool] = None, + equal_height: bool | None = None, + mobile_collapse: bool | None = None, **kwargs, ): """ @@ -100,7 +100,7 @@ class Column(BlockContext): min_width: int = 320, variant: str = "default", visible: bool = True, - elem_id: Optional[str] = None, + elem_id: str | None = None, **kwargs, ): """ @@ -129,8 +129,8 @@ class Column(BlockContext): @staticmethod def update( - variant: Optional[str] = None, - visible: Optional[bool] = None, + variant: str | None = None, + visible: bool | None = None, ): return { "variant": variant, @@ -147,9 +147,9 @@ class Tabs(BlockContext): def __init__( self, *, - selected: Optional[int | str] = None, + selected: int | str | None = None, visible: bool = True, - elem_id: Optional[str] = None, + elem_id: str | None = None, **kwargs, ): """ @@ -164,8 +164,9 @@ class Tabs(BlockContext): def get_config(self): return {"selected": self.selected, **super().get_config()} + @staticmethod def update( - selected: Optional[int | str] = None, + selected: int | str | None = None, ): return { "selected": selected, @@ -183,13 +184,27 @@ class Tabs(BlockContext): self.set_event_trigger("change", fn, inputs, outputs) -class TabItem(BlockContext): +@document() +class Tab(BlockContext): + """ + Tab (or its alias TabItem) is a layout element. Components defined within the Tab will be visible when this tab is selected tab. + Example: + with gradio.Blocks() as demo: + with gradio.Tab("Lion"): + gr.Image("lion.jpg") + gr.Button("New Lion") + with gradio.Tab("Tiger"): + gr.Image("tiger.jpg") + gr.Button("New Tiger") + Guides: controlling_layout + """ + def __init__( self, label: str, *, - id: Optional[int | str] = None, - elem_id: Optional[str] = None, + id: int | str | None = None, + elem_id: str | None = None, **kwargs, ): """ @@ -222,26 +237,11 @@ class TabItem(BlockContext): def get_expected_parent(self) -> Type[Tabs]: return Tabs - -@document() -class Tab(TabItem): - """ - Tab is a layout element. Components defined within the Tab will be visible when this tab is selected tab. - Example: - with gradio.Blocks() as demo: - with gradio.Tab("Lion"): - gr.Image("lion.jpg") - gr.Button("New Lion") - with gradio.Tab("Tiger"): - gr.Image("tiger.jpg") - gr.Button("New Tiger") - Guides: controlling_layout - """ - - pass + def get_block_name(self): + return "tabitem" -Tab = TabItem # noqa: F811 +TabItem = Tab class Group(BlockContext): @@ -258,7 +258,7 @@ class Group(BlockContext): self, *, visible: bool = True, - elem_id: Optional[str] = None, + elem_id: str | None = None, **kwargs, ): """ @@ -273,7 +273,7 @@ class Group(BlockContext): @staticmethod def update( - visible: Optional[bool] = None, + visible: bool | None = None, ): return { "visible": visible, @@ -296,7 +296,7 @@ class Box(BlockContext): self, *, visible: bool = True, - elem_id: Optional[str] = None, + elem_id: str | None = None, **kwargs, ): """ @@ -311,7 +311,7 @@ class Box(BlockContext): @staticmethod def update( - visible: Optional[bool] = None, + visible: bool | None = None, ): return { "visible": visible, @@ -342,7 +342,7 @@ class Accordion(BlockContext): *, open: bool = True, visible: bool = True, - elem_id: Optional[str] = None, + elem_id: str | None = None, **kwargs, ): """ @@ -365,9 +365,9 @@ class Accordion(BlockContext): @staticmethod def update( - open: Optional[bool] = None, - label: Optional[str] = None, - visible: Optional[bool] = None, + open: bool | None = None, + label: str | None = None, + visible: bool | None = None, ): return { "visible": visible, diff --git a/gradio/outputs.py b/gradio/outputs.py index f366155b34..2995fdcac8 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -1,3 +1,4 @@ +# type: ignore """ This module defines various classes that can serve as the `output` to an interface. Each class must inherit from `OutputComponent`, and each class must define a path to its template. All of the subclasses of `OutputComponent` are diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 56e7be128c..f408b37b3c 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -12,6 +12,7 @@ import tempfile import urllib.request import warnings from io import BytesIO +from pathlib import Path from typing import Dict, Tuple import numpy as np @@ -304,8 +305,9 @@ def dict_or_str_to_json_file(jsn, dir=None): return file_obj -def file_to_json(file_path: str) -> Dict: - return json.load(open(file_path)) +def file_to_json(file_path: str | Path) -> Dict: + with open(file_path) as f: + return json.load(f) class TempFileManager: diff --git a/gradio/serializing.py b/gradio/serializing.py index 23f27e7f67..aa5a571ab7 100644 --- a/gradio/serializing.py +++ b/gradio/serializing.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict @@ -11,7 +10,7 @@ from gradio import processing_utils, utils class Serializable(ABC): @abstractmethod def serialize( - self, x: Any, load_dir: str = "", encryption_key: bytes | None = None + self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None ): """ Convert data from human-readable format to serialized format for a browser. @@ -20,7 +19,10 @@ class Serializable(ABC): @abstractmethod def deserialize( - x: Any, save_dir: str | None = None, encryption_key: bytes | None = None + self, + x: Any, + save_dir: str | Path | None = None, + encryption_key: bytes | None = None, ): """ Convert data from serialized format for a browser to human-readable format. @@ -30,7 +32,7 @@ class Serializable(ABC): class SimpleSerializable(Serializable): def serialize( - self, x: Any, load_dir: str = "", encryption_key: bytes | None = None + self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None ) -> Any: """ Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op. @@ -42,7 +44,10 @@ class SimpleSerializable(Serializable): return x def deserialize( - self, x: Any, save_dir: str | None = None, encryption_key: bytes | None = None + self, + x: Any, + save_dir: str | Path | None = None, + encryption_key: bytes | None = None, ): """ Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op. @@ -56,8 +61,11 @@ class SimpleSerializable(Serializable): class ImgSerializable(Serializable): def serialize( - self, x: str, load_dir: str = "", encryption_key: bytes | None = None - ) -> str: + self, + x: str | None, + load_dir: str | Path = "", + encryption_key: bytes | None = None, + ) -> str | None: """ Convert from human-friendly version of a file (string filepath) to a seralized representation (base64). @@ -69,12 +77,15 @@ class ImgSerializable(Serializable): if x is None or x == "": return None return processing_utils.encode_url_or_file_to_base64( - os.path.join(load_dir, x), encryption_key=encryption_key + Path(load_dir) / x, encryption_key=encryption_key ) def deserialize( - self, x: str, save_dir: str | None = None, encryption_key: bytes | None = None - ) -> str: + self, + x: str | None, + save_dir: str | Path | None = None, + encryption_key: bytes | None = None, + ) -> str | None: """ Convert from serialized representation of a file (base64) to a human-friendly version (string filepath). Optionally, save the file to the directory specified by save_dir @@ -93,8 +104,11 @@ class ImgSerializable(Serializable): class FileSerializable(Serializable): def serialize( - self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None - ) -> Any: + self, + x: str | None, + load_dir: str | Path = "", + encryption_key: bytes | None = None, + ) -> Dict | None: """ Convert from human-friendly version of a file (string filepath) to a seralized representation (base64) @@ -105,7 +119,7 @@ class FileSerializable(Serializable): """ if x is None or x == "": return None - filename = os.path.join(load_dir, x) + filename = Path(load_dir) / x return { "name": filename, "data": processing_utils.encode_url_or_file_to_base64( @@ -120,7 +134,7 @@ class FileSerializable(Serializable): x: str | Dict | None, save_dir: Path | str | None = None, encryption_key: bytes | None = None, - ): + ) -> str | None: """ Convert from serialized representation of a file (base64) to a human-friendly version (string filepath). Optionally, save the file to the directory specified by `save_dir` @@ -158,7 +172,10 @@ class FileSerializable(Serializable): class JSONSerializable(Serializable): def serialize( - self, x: str, load_dir: str = "", encryption_key: bytes | None = None + self, + x: str | None, + load_dir: str | Path = "", + encryption_key: bytes | None = None, ) -> Dict | None: """ Convert from a a human-friendly version (string path to json file) to a @@ -170,14 +187,14 @@ class JSONSerializable(Serializable): """ if x is None or x == "": return None - return processing_utils.file_to_json(os.path.join(load_dir, x)) + return processing_utils.file_to_json(Path(load_dir) / x) def deserialize( self, x: str | Dict, - save_dir: str | None = None, + save_dir: str | Path | None = None, encryption_key: bytes | None = None, - ) -> str: + ) -> str | None: """ Convert from serialized representation (json string) to a human-friendly version (string path to json file). Optionally, save the file to the directory specified by `save_dir` diff --git a/gradio/utils.py b/gradio/utils.py index d3700507e3..3c06288e06 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -634,7 +634,7 @@ class AsyncRequest: @contextmanager -def set_directory(path: Path): +def set_directory(path: Path | str): """Context manager that sets the working directory to the given path.""" origin = Path().absolute() try: @@ -673,9 +673,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number: return value -def sanitize_list_for_csv( - values: List[str | Number] | List[List[str | Number]], -) -> List[str | Number] | List[List[str | Number]]: +def sanitize_list_for_csv(values: T) -> T: """ Sanitizes a list of values (or a list of list of values) that is being written to a CSV file to prevent CSV injection attacks. diff --git a/scripts/type_check_backend.sh b/scripts/type_check_backend.sh index 33dc367d99..5422d5d0eb 100644 --- a/scripts/type_check_backend.sh +++ b/scripts/type_check_backend.sh @@ -6,4 +6,4 @@ pip_required pip install --upgrade pip pip install pyright cd gradio -pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py +pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py