Added typing to even more files (#2900)

* started pathlib

* blocks.py

* more changes

* fixes

* typing

* formatting

* typing

* renaming files

* changelog

* script

* changelog

* lint

* routes

* renamed

* state

* formatting

* state

* type check script

* remove strictness

* switched to pyright

* switched to pyright

* fixed flaky tests

* fixed test xray

* fixed load test

* fixed blocks tests

* formatting

* fixed components test

* uncomment tests

* fixed interpretation tests

* formatting

* last tests hopefully

* argh lint

* component

* fixed based on review

* refactor

* components.py typing

* components.py

* formatting

* lint script

* merge

* merge

* lint

* pathlib

* lint

* events too

* lint script

* fixing tests

* lint

* examples

* serializing

* more files

* formatting

* flagging.py

* added to lint script

* fixed tab

* attempt fix

* serialize fix

* formatting

* all demos queue

* addressed review comments

* formatting
Abubakar Abid 2022-12-29 15:18:19 -05:00 committed by GitHub
parent 09ebf00332
commit 5310782ed9
13 changed files with 223 additions and 189 deletions

View File

@ -32,4 +32,4 @@ with gr.Blocks() as mega_demo:
with gr.Tab(demo_name):
demo.render()
mega_demo.launch()
mega_demo.queue().launch()
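
The change above appends .queue() so the combined demo runs behind Gradio's queue. A minimal sketch (not part of this commit) of the same pattern, relying on Blocks.queue() returning the Blocks instance so it chains into launch():

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("Hello")

# queue() returns the Blocks instance, so the call chains directly into launch()
demo.queue().launch()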

View File

@ -2350,9 +2350,11 @@ class File(
}
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
) -> Dict:
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
) -> Dict | None:
serialized = FileSerializable.serialize(self, x, load_dir, encryption_key)
if serialized is None:
return None
serialized["size"] = Path(serialized["name"]).stat().st_size
return serialized
@ -3019,9 +3021,11 @@ class UploadButton(
return deepcopy(media_data.BASE64_FILE)
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
) -> Dict:
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
) -> Dict | None:
serialized = FileSerializable.serialize(self, x, load_dir, encryption_key)
if serialized is None:
return None
serialized["size"] = Path(serialized["name"]).stat().st_size
return serialized
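
Both serialize overrides above follow the same None-propagating pattern required once the signature became x: str | None -> Dict | None. A minimal standalone sketch (not from this diff; base_serialize is a hypothetical stand-in for FileSerializable.serialize):

from __future__ import annotations

from pathlib import Path
from typing import Dict

def base_serialize(x: str | None) -> Dict | None:
    # stand-in: returns None for empty input, else a file payload
    if not x:
        return None
    return {"name": x, "data": "...", "is_file": False}

def serialize_with_size(x: str | None) -> Dict | None:
    serialized = base_serialize(x)
    if serialized is None:  # propagate None instead of indexing into it
        return None
    serialized["size"] = Path(serialized["name"]).stat().st_size
    return serialized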

View File

@ -8,7 +8,7 @@ import csv
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, List
from gradio import utils
from gradio.components import Dataset
@ -77,13 +77,13 @@ class Examples:
self,
examples: List[Any] | List[List[Any]] | str,
inputs: IOComponent | List[IOComponent],
outputs: Optional[IOComponent | List[IOComponent]] = None,
fn: Optional[Callable] = None,
outputs: IOComponent | List[IOComponent] | None = None,
fn: Callable | None = None,
cache_examples: bool = False,
examples_per_page: int = 10,
_api_mode: bool = False,
label: str = "Examples",
elem_id: Optional[str] = None,
label: str | None = "Examples",
elem_id: str | None = None,
run_on_click: bool = False,
preprocess: bool = True,
postprocess: bool = True,
@ -115,7 +115,7 @@ class Examples:
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
if outputs and not isinstance(outputs, list):
outputs = [outputs]
working_directory = Path().absolute()
@ -131,12 +131,12 @@ class Examples:
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
examples = [[e] for e in examples]
elif isinstance(examples, str):
if not os.path.exists(examples):
if not Path(examples).exists():
raise FileNotFoundError(
"Could not find examples directory: " + examples
)
working_directory = examples
if not os.path.exists(os.path.join(examples, LOG_FILE)):
if not (Path(examples) / LOG_FILE).exists():
if len(inputs) == 1:
examples = [[e] for e in os.listdir(examples)]
else:
@ -145,7 +145,7 @@ class Examples:
+ LOG_FILE
)
else:
with open(os.path.join(examples, LOG_FILE)) as logs:
with open(Path(examples) / LOG_FILE) as logs:
examples = list(csv.reader(logs))
examples = [
examples[i][: len(inputs)] for i in range(1, len(examples))
@ -221,8 +221,8 @@ class Examples:
elem_id=elem_id,
)
self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id))
self.cached_file = os.path.join(self.cached_folder, "log.csv")
self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cache_examples = cache_examples
self.run_on_click = run_on_click
@ -240,19 +240,24 @@ class Examples:
return utils.resolve_singleton(processed_example)
if Context.root_block:
if self.cache_examples and self.outputs:
targets = self.inputs_with_examples
else:
targets = self.inputs
self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples
+ (self.outputs if self.cache_examples else []),
outputs=targets, # type: ignore
postprocess=False,
queue=False,
)
if self.run_on_click and not self.cache_examples:
if self.fn is None:
raise ValueError("Cannot run_on_click if no function is provided")
self.dataset.click(
self.fn,
inputs=self.inputs,
outputs=self.outputs,
inputs=self.inputs, # type: ignore
outputs=self.outputs, # type: ignore
)
if self.cache_examples:
@ -262,29 +267,30 @@ class Examples:
"""
Caches all of the examples so that their predictions can be shown immediately.
"""
if os.path.exists(self.cached_file):
if Path(self.cached_file).exists():
print(
f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
f"Using cache from '{Path(self.cached_folder).resolve()}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
)
else:
if Context.root_block is None:
raise ValueError("Cannot cache examples if not in a Blocks context")
print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
print(f"Caching examples at: '{Path(self.cached_file).resolve()}'")
cache_logger = CSVLogger()
# create a fake dependency to process the examples and get the predictions
dependency = Context.root_block.set_event_trigger(
event_name="fake_event",
fn=self.fn,
inputs=self.inputs_with_examples,
outputs=self.outputs,
inputs=self.inputs_with_examples, # type: ignore
outputs=self.outputs, # type: ignore
preprocess=self.preprocess and not self._api_mode,
postprocess=self.postprocess and not self._api_mode,
batch=self.batch,
)
fn_index = Context.root_block.dependencies.index(dependency)
assert self.outputs is not None
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
processed_input = self.processed_examples[example_id]
@ -310,6 +316,7 @@ class Examples:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
output = []
assert self.outputs is not None
for component, value in zip(self.outputs, example):
try:
value_as_dict = ast.literal_eval(value)
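
The examples.py hunks above swap os.path calls for pathlib. A minimal sketch (not from this diff; names are illustrative) of the equivalences this commit uses throughout:

from pathlib import Path

flagging_dir, log_name = "flagged", "log.csv"  # illustrative names

log_path = Path(flagging_dir) / log_name     # replaces os.path.join(flagging_dir, log_name)
log_exists = log_path.exists()               # replaces os.path.exists(...)
absolute_dir = Path(flagging_dir).resolve()  # replaces os.path.abspath(...); also resolves symlinks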

View File

@ -63,7 +63,7 @@ def load_blocks_from_repo(
return blocks
def from_model(model_name: str, api_key: str | None, alias: str, **kwargs):
def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
print("Fetching model from: {}".format(model_url))
@ -316,7 +316,9 @@ def from_model(model_name: str, api_key: str | None, alias: str, **kwargs):
return interface
def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
def from_spaces(
space_name: str, api_key: str | None, alias: str | None, **kwargs
) -> Blocks:
space_url = "https://huggingface.co/spaces/{}".format(space_name)
print("Fetching Space from: {}".format(space_url))
@ -344,7 +346,7 @@ def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> B
r"window.gradio_config = (.*?);[\s]*</script>", r.text
) # some basic regex to extract the config
try:
config = json.loads(result.group(1))
config = json.loads(result.group(1)) # type: ignore
except AttributeError:
raise ValueError("Could not load the Space: {}".format(space_name))
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
@ -416,7 +418,7 @@ def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Bl
def from_spaces_interface(
model_name: str,
config: Dict,
alias: str,
alias: str | None,
api_key: str | None,
iframe_url: str,
**kwargs,

View File

@ -3,7 +3,6 @@
import base64
import json
import math
import numbers
import operator
import re
import warnings
@ -13,6 +12,7 @@ import requests
import websockets
import yaml
from packaging import version
from websockets.legacy.protocol import WebSocketCommonProtocol
from gradio import components, exceptions
@ -30,8 +30,13 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
yaml_regex = re.search(
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
)
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
example_data = example_yaml.get("widget", {}).get("structuredData", {})
if yaml_regex is None:
example_data = {}
else:
example_yaml = next(
yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
)
example_data = example_yaml.get("widget", {}).get("structuredData", {})
if not example_data:
raise ValueError(
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
@ -41,7 +46,7 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
# replace nan with string NaN for inference API
for data in example_data.values():
for i, val in enumerate(data):
if isinstance(val, numbers.Number) and math.isnan(val):
if isinstance(val, float) and math.isnan(val):
data[i] = "NaN"
return example_data
@ -76,7 +81,7 @@ def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]
##################
def postprocess_label(scores):
def postprocess_label(scores: Dict) -> Dict:
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
return {
"label": sorted_pred[0][0],
@ -117,9 +122,10 @@ def encode_to_base64(r: requests.Response) -> str:
async def get_pred_from_ws(
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
websocket: WebSocketCommonProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
completed = False
resp = {}
while not completed:
msg = await websocket.recv()
resp = json.loads(msg)
@ -135,7 +141,7 @@ async def get_pred_from_ws(
def get_ws_fn(ws_url, headers):
async def ws_fn(data, hash_data):
async with websockets.connect(
async with websockets.connect( # type: ignore
ws_url, open_timeout=10, extra_headers=headers
) as websocket:
return await get_pred_from_ws(websocket, data, hash_data)
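
The two fixes above are typical pyright-driven changes: re.search returns an Optional match, so .span() cannot be called unconditionally, and resp is pre-initialized so the while loop cannot leave it unbound. A minimal sketch of the first pattern (the README text is made up):

import re

readme_text = "---\nwidget: {}\n---\nbody"  # made-up stand-in for readme.text

yaml_regex = re.search(
    r"(?:^|[\r\n])---[\n\r]+([\S\s]*?)[\n\r]+---([\n\r]|$)", readme_text
)
if yaml_regex is None:
    front_matter = ""  # fall back instead of calling .span() on None
else:
    front_matter = readme_text[: yaml_regex.span()[-1]]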

View File

@ -7,7 +7,8 @@ import json
import os
import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Any, List
import gradio as gr
from gradio import encryptor, utils
@ -51,9 +52,9 @@ def _get_dataset_features_info(is_new, components):
headers.append(component.label + " file")
for _component, _type in file_preview_types.items():
if isinstance(component, _component):
infos["flagged"]["features"][component.label + " file"] = {
"_type": _type
}
infos["flagged"]["features"][
(component.label or "") + " file"
] = {"_type": _type}
break
headers.append("flag")
@ -85,9 +86,9 @@ class FlaggingCallback(ABC):
def flag(
self,
flag_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
flag_option: str | None = None,
flag_index: int | None = None,
username: str | None = None,
) -> int:
"""
This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
@ -121,7 +122,7 @@ class SimpleCSVLogger(FlaggingCallback):
def __init__(self):
pass
def setup(self, components: List[IOComponent], flagging_dir: str):
def setup(self, components: List[IOComponent], flagging_dir: str | Path):
self.components = components
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)
@ -129,17 +130,17 @@ class SimpleCSVLogger(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
flag_option: str | None = None,
flag_index: int | None = None,
username: str | None = None,
) -> int:
flagging_dir = self.flagging_dir
log_filepath = os.path.join(flagging_dir, "log.csv")
log_filepath = Path(flagging_dir) / "log.csv"
csv_data = []
for component, sample in zip(self.components, flag_data):
save_dir = os.path.join(
flagging_dir, utils.strip_invalid_filename_characters(component.label)
save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
component.label or ""
)
csv_data.append(
component.deserialize(
@ -178,8 +179,8 @@ class CSVLogger(FlaggingCallback):
def setup(
self,
components: List[IOComponent],
flagging_dir: str,
encryption_key: Optional[str] = None,
flagging_dir: str | Path,
encryption_key: bytes | None = None,
):
self.components = components
self.flagging_dir = flagging_dir
@ -189,54 +190,49 @@ class CSVLogger(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
flag_option: str | None = None,
flag_index: int | None = None,
username: str | None = None,
) -> int:
flagging_dir = self.flagging_dir
log_filepath = os.path.join(flagging_dir, "log.csv")
is_new = not os.path.exists(log_filepath)
log_filepath = Path(flagging_dir) / "log.csv"
is_new = not Path(log_filepath).exists()
headers = [
component.label or f"component {idx}"
for idx, component in enumerate(self.components)
] + [
"flag",
"username",
"timestamp",
]
if flag_index is None:
csv_data = []
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
save_dir = os.path.join(
flagging_dir,
utils.strip_invalid_filename_characters(
component.label or f"component {idx}"
),
)
if utils.is_update(sample):
csv_data.append(str(sample))
else:
csv_data.append(
component.deserialize(
sample,
save_dir=save_dir,
encryption_key=self.encryption_key,
)
if sample is not None
else ""
csv_data = []
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
component.label or f"component {idx}"
)
if utils.is_update(sample):
csv_data.append(str(sample))
else:
csv_data.append(
component.deserialize(
sample,
save_dir=save_dir,
encryption_key=self.encryption_key,
)
csv_data.append(flag_option if flag_option is not None else "")
csv_data.append(username if username is not None else "")
csv_data.append(str(datetime.datetime.now()))
if is_new:
headers = [
component.label or f"component {idx}"
for idx, component in enumerate(self.components)
] + [
"flag",
"username",
"timestamp",
]
if sample is not None
else ""
)
csv_data.append(flag_option if flag_option is not None else "")
csv_data.append(username if username is not None else "")
csv_data.append(str(datetime.datetime.now()))
def replace_flag_at_index(file_content):
file_content = io.StringIO(file_content)
content = list(csv.reader(file_content))
def replace_flag_at_index(file_content: str, flag_index: int):
file_content_ = io.StringIO(file_content)
content = list(csv.reader(file_content_))
header = content[0]
flag_col_index = header.index("flag")
content[flag_index][flag_col_index] = flag_option
content[flag_index][flag_col_index] = flag_option # type: ignore
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(utils.sanitize_list_for_csv(content))
@ -252,7 +248,7 @@ class CSVLogger(FlaggingCallback):
)
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
file_content = replace_flag_at_index(file_content, flag_index)
output.write(file_content)
writer = csv.writer(output)
if flag_index is None:
@ -273,7 +269,7 @@ class CSVLogger(FlaggingCallback):
else:
with open(log_filepath, encoding="utf-8") as csvfile:
file_content = csvfile.read()
file_content = replace_flag_at_index(file_content)
file_content = replace_flag_at_index(file_content, flag_index)
with open(
log_filepath, "w", newline="", encoding="utf-8"
) as csvfile: # newline parameter needed for Windows
@ -302,7 +298,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
self,
hf_token: str,
dataset_name: str,
organization: Optional[str] = None,
organization: str | None = None,
private: bool = False,
):
"""
@ -340,28 +336,28 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
self.components = components
self.flagging_dir = flagging_dir
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
self.dataset_dir = Path(flagging_dir) / self.dataset_name
self.repo = huggingface_hub.Repository(
local_dir=self.dataset_dir,
local_dir=str(self.dataset_dir),
clone_from=path_to_dataset_repo,
use_auth_token=self.hf_token,
)
self.repo.git_pull(lfs=True)
# Should filename be user-specified?
self.log_file = os.path.join(self.dataset_dir, "data.csv")
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
self.log_file = Path(self.dataset_dir) / "data.csv"
self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
def flag(
self,
flag_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
flag_option: str | None = None,
flag_index: int | None = None,
username: str | None = None,
) -> int:
self.repo.git_pull(lfs=True)
is_new = not os.path.exists(self.log_file)
is_new = not Path(self.log_file).exists()
with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
@ -378,10 +374,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
# Generate the row corresponding to the flagged sample
csv_data = []
for component, sample in zip(self.components, flag_data):
save_dir = os.path.join(
self.dataset_dir,
utils.strip_invalid_filename_characters(component.label),
)
save_dir = Path(
self.dataset_dir
) / utils.strip_invalid_filename_characters(component.label or "")
filepath = component.deserialize(sample, save_dir, None)
csv_data.append(filepath)
if isinstance(component, tuple(file_preview_types)):
@ -416,7 +411,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
self,
hf_foken: str,
dataset_name: str,
organization: Optional[str] = None,
organization: str | None = None,
private: bool = False,
verbose: bool = True,
):
@ -463,34 +458,34 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
self.components = components
self.flagging_dir = flagging_dir
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
self.dataset_dir = Path(flagging_dir) / self.dataset_name
self.repo = huggingface_hub.Repository(
local_dir=self.dataset_dir,
local_dir=str(self.dataset_dir),
clone_from=path_to_dataset_repo,
use_auth_token=self.hf_foken,
)
self.repo.git_pull(lfs=True)
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
def flag(
self,
flag_data: List[Any],
flag_option: Optional[str] = None,
flag_index: Optional[int] = None,
username: Optional[str] = None,
) -> int:
flag_option: str | None = None,
flag_index: int | None = None,
username: str | None = None,
) -> str:
self.repo.git_pull(lfs=True)
# Generate unique folder for the flagged sample
unique_name = self.get_unique_name() # unique name for folder
folder_name = os.path.join(
self.dataset_dir, unique_name
folder_name = (
Path(self.dataset_dir) / unique_name
) # unique folder for specific example
os.makedirs(folder_name)
# Now uses the existence of `dataset_infos.json` to determine if new
is_new = not os.path.exists(self.infos_file)
is_new = not Path(self.infos_file).exists()
# File previews for certain input and output types
infos, file_preview_types, _ = _get_dataset_features_info(
@ -505,9 +500,10 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
headers.append(component.label)
try:
filepath = component.save_flagged(
folder_name, component.label, sample, None
save_dir = Path(folder_name) / utils.strip_invalid_filename_characters(
component.label or ""
)
filepath = component.deserialize(sample, save_dir, None)
except Exception:
# Could not parse 'sample' (mostly) because it was None and `component.save_flagged`
# does not handle None cases.
@ -515,7 +511,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
filepath = None
if isinstance(component, tuple(file_preview_types)):
headers.append(component.label + " file")
headers.append(component.label or "" + " file")
csv_data.append(
"{}/resolve/main/{}/{}".format(
@ -533,7 +529,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
metadata_dict = {
header: _csv_data for header, _csv_data in zip(headers, csv_data)
}
self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl"))
self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")
if is_new:
json.dump(infos, open(self.infos_file, "w"))
@ -545,7 +541,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
id = uuid.uuid4()
return str(id)
def dump_json(self, thing: dict, file_path: str) -> None:
def dump_json(self, thing: dict, file_path: str | Path) -> None:
with open(file_path, "w+", encoding="utf8") as f:
json.dump(thing, f)

View File

@ -1,3 +1,4 @@
# type: ignore
"""
This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
`InputComponent`, and each class must define a path to its template. All of the subclasses of `InputComponent` are

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Callable, List, Optional, Type
from typing import TYPE_CHECKING, Callable, List, Type
from gradio.blocks import BlockContext
from gradio.documentation import document, set_documentation_group
@ -30,7 +30,7 @@ class Row(BlockContext):
*,
variant: str = "default",
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -49,7 +49,7 @@ class Row(BlockContext):
@staticmethod
def update(
visible: Optional[bool] = None,
visible: bool | None = None,
):
return {
"visible": visible,
@ -59,8 +59,8 @@ class Row(BlockContext):
def style(
self,
*,
equal_height: Optional[bool] = None,
mobile_collapse: Optional[bool] = None,
equal_height: bool | None = None,
mobile_collapse: bool | None = None,
**kwargs,
):
"""
@ -100,7 +100,7 @@ class Column(BlockContext):
min_width: int = 320,
variant: str = "default",
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -129,8 +129,8 @@ class Column(BlockContext):
@staticmethod
def update(
variant: Optional[str] = None,
visible: Optional[bool] = None,
variant: str | None = None,
visible: bool | None = None,
):
return {
"variant": variant,
@ -147,9 +147,9 @@ class Tabs(BlockContext):
def __init__(
self,
*,
selected: Optional[int | str] = None,
selected: int | str | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -164,8 +164,9 @@ class Tabs(BlockContext):
def get_config(self):
return {"selected": self.selected, **super().get_config()}
@staticmethod
def update(
selected: Optional[int | str] = None,
selected: int | str | None = None,
):
return {
"selected": selected,
@ -183,13 +184,27 @@ class Tabs(BlockContext):
self.set_event_trigger("change", fn, inputs, outputs)
class TabItem(BlockContext):
@document()
class Tab(BlockContext):
"""
Tab (or its alias TabItem) is a layout element. Components defined within the Tab will be visible when this tab is selected tab.
Example:
with gradio.Blocks() as demo:
with gradio.Tab("Lion"):
gr.Image("lion.jpg")
gr.Button("New Lion")
with gradio.Tab("Tiger"):
gr.Image("tiger.jpg")
gr.Button("New Tiger")
Guides: controlling_layout
"""
def __init__(
self,
label: str,
*,
id: Optional[int | str] = None,
elem_id: Optional[str] = None,
id: int | str | None = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -222,26 +237,11 @@ class TabItem(BlockContext):
def get_expected_parent(self) -> Type[Tabs]:
return Tabs
@document()
class Tab(TabItem):
"""
Tab is a layout element. Components defined within the Tab will be visible when this tab is selected tab.
Example:
with gradio.Blocks() as demo:
with gradio.Tab("Lion"):
gr.Image("lion.jpg")
gr.Button("New Lion")
with gradio.Tab("Tiger"):
gr.Image("tiger.jpg")
gr.Button("New Tiger")
Guides: controlling_layout
"""
pass
def get_block_name(self):
return "tabitem"
Tab = TabItem # noqa: F811
TabItem = Tab
class Group(BlockContext):
@ -258,7 +258,7 @@ class Group(BlockContext):
self,
*,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -273,7 +273,7 @@ class Group(BlockContext):
@staticmethod
def update(
visible: Optional[bool] = None,
visible: bool | None = None,
):
return {
"visible": visible,
@ -296,7 +296,7 @@ class Box(BlockContext):
self,
*,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -311,7 +311,7 @@ class Box(BlockContext):
@staticmethod
def update(
visible: Optional[bool] = None,
visible: bool | None = None,
):
return {
"visible": visible,
@ -342,7 +342,7 @@ class Accordion(BlockContext):
*,
open: bool = True,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
"""
@ -365,9 +365,9 @@ class Accordion(BlockContext):
@staticmethod
def update(
open: Optional[bool] = None,
label: Optional[str] = None,
visible: Optional[bool] = None,
open: bool | None = None,
label: str | None = None,
visible: bool | None = None,
):
return {
"visible": visible,

View File

@ -1,3 +1,4 @@
# type: ignore
"""
This module defines various classes that can serve as the `output` to an interface. Each class must inherit from
`OutputComponent`, and each class must define a path to its template. All of the subclasses of `OutputComponent` are

View File

@ -12,6 +12,7 @@ import tempfile
import urllib.request
import warnings
from io import BytesIO
from pathlib import Path
from typing import Dict, Tuple
import numpy as np
@ -304,8 +305,9 @@ def dict_or_str_to_json_file(jsn, dir=None):
return file_obj
def file_to_json(file_path: str) -> Dict:
return json.load(open(file_path))
def file_to_json(file_path: str | Path) -> Dict:
with open(file_path) as f:
return json.load(f)
class TempFileManager:

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict
@ -11,7 +10,7 @@ from gradio import processing_utils, utils
class Serializable(ABC):
@abstractmethod
def serialize(
self, x: Any, load_dir: str = "", encryption_key: bytes | None = None
self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
):
"""
Convert data from human-readable format to serialized format for a browser.
@ -20,7 +19,10 @@ class Serializable(ABC):
@abstractmethod
def deserialize(
x: Any, save_dir: str | None = None, encryption_key: bytes | None = None
self,
x: Any,
save_dir: str | Path | None = None,
encryption_key: bytes | None = None,
):
"""
Convert data from serialized format for a browser to human-readable format.
@ -30,7 +32,7 @@ class Serializable(ABC):
class SimpleSerializable(Serializable):
def serialize(
self, x: Any, load_dir: str = "", encryption_key: bytes | None = None
self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
) -> Any:
"""
Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op.
@ -42,7 +44,10 @@ class SimpleSerializable(Serializable):
return x
def deserialize(
self, x: Any, save_dir: str | None = None, encryption_key: bytes | None = None
self,
x: Any,
save_dir: str | Path | None = None,
encryption_key: bytes | None = None,
):
"""
Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
@ -56,8 +61,11 @@ class SimpleSerializable(Serializable):
class ImgSerializable(Serializable):
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
) -> str:
self,
x: str | None,
load_dir: str | Path = "",
encryption_key: bytes | None = None,
) -> str | None:
"""
Convert from human-friendly version of a file (string filepath) to a seralized
representation (base64).
@ -69,12 +77,15 @@ class ImgSerializable(Serializable):
if x is None or x == "":
return None
return processing_utils.encode_url_or_file_to_base64(
os.path.join(load_dir, x), encryption_key=encryption_key
Path(load_dir) / x, encryption_key=encryption_key
)
def deserialize(
self, x: str, save_dir: str | None = None, encryption_key: bytes | None = None
) -> str:
self,
x: str | None,
save_dir: str | Path | None = None,
encryption_key: bytes | None = None,
) -> str | None:
"""
Convert from serialized representation of a file (base64) to a human-friendly
version (string filepath). Optionally, save the file to the directory specified by save_dir
@ -93,8 +104,11 @@ class ImgSerializable(Serializable):
class FileSerializable(Serializable):
def serialize(
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
) -> Any:
self,
x: str | None,
load_dir: str | Path = "",
encryption_key: bytes | None = None,
) -> Dict | None:
"""
Convert from human-friendly version of a file (string filepath) to a
seralized representation (base64)
@ -105,7 +119,7 @@ class FileSerializable(Serializable):
"""
if x is None or x == "":
return None
filename = os.path.join(load_dir, x)
filename = Path(load_dir) / x
return {
"name": filename,
"data": processing_utils.encode_url_or_file_to_base64(
@ -120,7 +134,7 @@ class FileSerializable(Serializable):
x: str | Dict | None,
save_dir: Path | str | None = None,
encryption_key: bytes | None = None,
):
) -> str | None:
"""
Convert from serialized representation of a file (base64) to a human-friendly
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
@ -158,7 +172,10 @@ class FileSerializable(Serializable):
class JSONSerializable(Serializable):
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
self,
x: str | None,
load_dir: str | Path = "",
encryption_key: bytes | None = None,
) -> Dict | None:
"""
Convert from a a human-friendly version (string path to json file) to a
@ -170,14 +187,14 @@ class JSONSerializable(Serializable):
"""
if x is None or x == "":
return None
return processing_utils.file_to_json(os.path.join(load_dir, x))
return processing_utils.file_to_json(Path(load_dir) / x)
def deserialize(
self,
x: str | Dict,
save_dir: str | None = None,
save_dir: str | Path | None = None,
encryption_key: bytes | None = None,
) -> str:
) -> str | None:
"""
Convert from serialized representation (json string) to a human-friendly
version (string path to json file). Optionally, save the file to the directory specified by `save_dir`

View File

@ -634,7 +634,7 @@ class AsyncRequest:
@contextmanager
def set_directory(path: Path):
def set_directory(path: Path | str):
"""Context manager that sets the working directory to the given path."""
origin = Path().absolute()
try:
@ -673,9 +673,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number:
return value
def sanitize_list_for_csv(
values: List[str | Number] | List[List[str | Number]],
) -> List[str | Number] | List[List[str | Number]]:
def sanitize_list_for_csv(values: T) -> T:
"""
Sanitizes a list of values (or a list of list of values) that is being written to a
CSV file to prevent CSV injection attacks.

View File

@ -6,4 +6,4 @@ pip_required
pip install --upgrade pip
pip install pyright
cd gradio
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py