mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Added typing to even more files (#2900)
* started pathlib * blocks.py * more changes * fixes * typing * formatting * typing * renaming files * changelog * script * changelog * lint * routes * renamed * state * formatting * state * type check script * remove strictness * switched to pyright * switched to pyright * fixed flaky tests * fixed test xray * fixed load test * fixed blocks tests * formatting * fixed components test * uncomment tests * fixed interpretation tests * formatting * last tests hopefully * argh lint * component * fixed based on review * refactor * components.py t yping * components.py * formatting * lint script * merge * merge * lint * pathlib * lint * events too * lint script * fixing tests * lint * examples * serializing * more files * formatting * flagging.py * added to lint script * fixed tab * attempt fix * serialize fix * formatting * all demos queue * addressed review comments * formatting
This commit is contained in:
parent
09ebf00332
commit
5310782ed9
@ -32,4 +32,4 @@ with gr.Blocks() as mega_demo:
|
||||
with gr.Tab(demo_name):
|
||||
demo.render()
|
||||
|
||||
mega_demo.launch()
|
||||
mega_demo.queue().launch()
|
||||
|
@ -2350,9 +2350,11 @@ class File(
|
||||
}
|
||||
|
||||
def serialize(
|
||||
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> Dict:
|
||||
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> Dict | None:
|
||||
serialized = FileSerializable.serialize(self, x, load_dir, encryption_key)
|
||||
if serialized is None:
|
||||
return None
|
||||
serialized["size"] = Path(serialized["name"]).stat().st_size
|
||||
return serialized
|
||||
|
||||
@ -3019,9 +3021,11 @@ class UploadButton(
|
||||
return deepcopy(media_data.BASE64_FILE)
|
||||
|
||||
def serialize(
|
||||
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> Dict:
|
||||
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> Dict | None:
|
||||
serialized = FileSerializable.serialize(self, x, load_dir, encryption_key)
|
||||
if serialized is None:
|
||||
return None
|
||||
serialized["size"] = Path(serialized["name"]).stat().st_size
|
||||
return serialized
|
||||
|
||||
|
@ -8,7 +8,7 @@ import csv
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, List
|
||||
|
||||
from gradio import utils
|
||||
from gradio.components import Dataset
|
||||
@ -77,13 +77,13 @@ class Examples:
|
||||
self,
|
||||
examples: List[Any] | List[List[Any]] | str,
|
||||
inputs: IOComponent | List[IOComponent],
|
||||
outputs: Optional[IOComponent | List[IOComponent]] = None,
|
||||
fn: Optional[Callable] = None,
|
||||
outputs: IOComponent | List[IOComponent] | None = None,
|
||||
fn: Callable | None = None,
|
||||
cache_examples: bool = False,
|
||||
examples_per_page: int = 10,
|
||||
_api_mode: bool = False,
|
||||
label: str = "Examples",
|
||||
elem_id: Optional[str] = None,
|
||||
label: str | None = "Examples",
|
||||
elem_id: str | None = None,
|
||||
run_on_click: bool = False,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
@ -115,7 +115,7 @@ class Examples:
|
||||
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(outputs, list):
|
||||
if outputs and not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
working_directory = Path().absolute()
|
||||
@ -131,12 +131,12 @@ class Examples:
|
||||
): # If there is only one input component, examples can be provided as a regular list instead of a list of lists
|
||||
examples = [[e] for e in examples]
|
||||
elif isinstance(examples, str):
|
||||
if not os.path.exists(examples):
|
||||
if not Path(examples).exists():
|
||||
raise FileNotFoundError(
|
||||
"Could not find examples directory: " + examples
|
||||
)
|
||||
working_directory = examples
|
||||
if not os.path.exists(os.path.join(examples, LOG_FILE)):
|
||||
if not (Path(examples) / LOG_FILE).exists():
|
||||
if len(inputs) == 1:
|
||||
examples = [[e] for e in os.listdir(examples)]
|
||||
else:
|
||||
@ -145,7 +145,7 @@ class Examples:
|
||||
+ LOG_FILE
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(examples, LOG_FILE)) as logs:
|
||||
with open(Path(examples) / LOG_FILE) as logs:
|
||||
examples = list(csv.reader(logs))
|
||||
examples = [
|
||||
examples[i][: len(inputs)] for i in range(1, len(examples))
|
||||
@ -221,8 +221,8 @@ class Examples:
|
||||
elem_id=elem_id,
|
||||
)
|
||||
|
||||
self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id))
|
||||
self.cached_file = os.path.join(self.cached_folder, "log.csv")
|
||||
self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
|
||||
self.cached_file = Path(self.cached_folder) / "log.csv"
|
||||
self.cache_examples = cache_examples
|
||||
self.run_on_click = run_on_click
|
||||
|
||||
@ -240,19 +240,24 @@ class Examples:
|
||||
return utils.resolve_singleton(processed_example)
|
||||
|
||||
if Context.root_block:
|
||||
if self.cache_examples and self.outputs:
|
||||
targets = self.inputs_with_examples
|
||||
else:
|
||||
targets = self.inputs
|
||||
self.dataset.click(
|
||||
load_example,
|
||||
inputs=[self.dataset],
|
||||
outputs=self.inputs_with_examples
|
||||
+ (self.outputs if self.cache_examples else []),
|
||||
outputs=targets, # type: ignore
|
||||
postprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
if self.run_on_click and not self.cache_examples:
|
||||
if self.fn is None:
|
||||
raise ValueError("Cannot run_on_click if no function is provided")
|
||||
self.dataset.click(
|
||||
self.fn,
|
||||
inputs=self.inputs,
|
||||
outputs=self.outputs,
|
||||
inputs=self.inputs, # type: ignore
|
||||
outputs=self.outputs, # type: ignore
|
||||
)
|
||||
|
||||
if self.cache_examples:
|
||||
@ -262,29 +267,30 @@ class Examples:
|
||||
"""
|
||||
Caches all of the examples so that their predictions can be shown immediately.
|
||||
"""
|
||||
if os.path.exists(self.cached_file):
|
||||
if Path(self.cached_file).exists():
|
||||
print(
|
||||
f"Using cache from '{os.path.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
||||
f"Using cache from '{Path(self.cached_folder).resolve()}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
||||
)
|
||||
else:
|
||||
if Context.root_block is None:
|
||||
raise ValueError("Cannot cache examples if not in a Blocks context")
|
||||
|
||||
print(f"Caching examples at: '{os.path.abspath(self.cached_file)}'")
|
||||
print(f"Caching examples at: '{Path(self.cached_file).resolve()}'")
|
||||
cache_logger = CSVLogger()
|
||||
|
||||
# create a fake dependency to process the examples and get the predictions
|
||||
dependency = Context.root_block.set_event_trigger(
|
||||
event_name="fake_event",
|
||||
fn=self.fn,
|
||||
inputs=self.inputs_with_examples,
|
||||
outputs=self.outputs,
|
||||
inputs=self.inputs_with_examples, # type: ignore
|
||||
outputs=self.outputs, # type: ignore
|
||||
preprocess=self.preprocess and not self._api_mode,
|
||||
postprocess=self.postprocess and not self._api_mode,
|
||||
batch=self.batch,
|
||||
)
|
||||
|
||||
fn_index = Context.root_block.dependencies.index(dependency)
|
||||
assert self.outputs is not None
|
||||
cache_logger.setup(self.outputs, self.cached_folder)
|
||||
for example_id, _ in enumerate(self.examples):
|
||||
processed_input = self.processed_examples[example_id]
|
||||
@ -310,6 +316,7 @@ class Examples:
|
||||
examples = list(csv.reader(cache))
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
assert self.outputs is not None
|
||||
for component, value in zip(self.outputs, example):
|
||||
try:
|
||||
value_as_dict = ast.literal_eval(value)
|
||||
|
@ -63,7 +63,7 @@ def load_blocks_from_repo(
|
||||
return blocks
|
||||
|
||||
|
||||
def from_model(model_name: str, api_key: str | None, alias: str, **kwargs):
|
||||
def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs):
|
||||
model_url = "https://huggingface.co/{}".format(model_name)
|
||||
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
|
||||
print("Fetching model from: {}".format(model_url))
|
||||
@ -316,7 +316,9 @@ def from_model(model_name: str, api_key: str | None, alias: str, **kwargs):
|
||||
return interface
|
||||
|
||||
|
||||
def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
|
||||
def from_spaces(
|
||||
space_name: str, api_key: str | None, alias: str | None, **kwargs
|
||||
) -> Blocks:
|
||||
space_url = "https://huggingface.co/spaces/{}".format(space_name)
|
||||
|
||||
print("Fetching Space from: {}".format(space_url))
|
||||
@ -344,7 +346,7 @@ def from_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> B
|
||||
r"window.gradio_config = (.*?);[\s]*</script>", r.text
|
||||
) # some basic regex to extract the config
|
||||
try:
|
||||
config = json.loads(result.group(1))
|
||||
config = json.loads(result.group(1)) # type: ignore
|
||||
except AttributeError:
|
||||
raise ValueError("Could not load the Space: {}".format(space_name))
|
||||
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
|
||||
@ -416,7 +418,7 @@ def from_spaces_blocks(config: Dict, api_key: str | None, iframe_url: str) -> Bl
|
||||
def from_spaces_interface(
|
||||
model_name: str,
|
||||
config: Dict,
|
||||
alias: str,
|
||||
alias: str | None,
|
||||
api_key: str | None,
|
||||
iframe_url: str,
|
||||
**kwargs,
|
||||
|
@ -3,7 +3,6 @@
|
||||
import base64
|
||||
import json
|
||||
import math
|
||||
import numbers
|
||||
import operator
|
||||
import re
|
||||
import warnings
|
||||
@ -13,6 +12,7 @@ import requests
|
||||
import websockets
|
||||
import yaml
|
||||
from packaging import version
|
||||
from websockets.legacy.protocol import WebSocketCommonProtocol
|
||||
|
||||
from gradio import components, exceptions
|
||||
|
||||
@ -30,8 +30,13 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
|
||||
yaml_regex = re.search(
|
||||
"(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", readme.text
|
||||
)
|
||||
example_yaml = next(yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]))
|
||||
example_data = example_yaml.get("widget", {}).get("structuredData", {})
|
||||
if yaml_regex is None:
|
||||
example_data = {}
|
||||
else:
|
||||
example_yaml = next(
|
||||
yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]])
|
||||
)
|
||||
example_data = example_yaml.get("widget", {}).get("structuredData", {})
|
||||
if not example_data:
|
||||
raise ValueError(
|
||||
f"No example data found in README.md of {model_name} - Cannot build gradio demo. "
|
||||
@ -41,7 +46,7 @@ def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
|
||||
# replace nan with string NaN for inference API
|
||||
for data in example_data.values():
|
||||
for i, val in enumerate(data):
|
||||
if isinstance(val, numbers.Number) and math.isnan(val):
|
||||
if isinstance(val, float) and math.isnan(val):
|
||||
data[i] = "NaN"
|
||||
return example_data
|
||||
|
||||
@ -76,7 +81,7 @@ def rows_to_cols(incoming_data: Dict) -> Dict[str, Dict[str, Dict[str, List[str]
|
||||
##################
|
||||
|
||||
|
||||
def postprocess_label(scores):
|
||||
def postprocess_label(scores: Dict) -> Dict:
|
||||
sorted_pred = sorted(scores.items(), key=operator.itemgetter(1), reverse=True)
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
@ -117,9 +122,10 @@ def encode_to_base64(r: requests.Response) -> str:
|
||||
|
||||
|
||||
async def get_pred_from_ws(
|
||||
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
|
||||
websocket: WebSocketCommonProtocol, data: str, hash_data: str
|
||||
) -> Dict[str, Any]:
|
||||
completed = False
|
||||
resp = {}
|
||||
while not completed:
|
||||
msg = await websocket.recv()
|
||||
resp = json.loads(msg)
|
||||
@ -135,7 +141,7 @@ async def get_pred_from_ws(
|
||||
|
||||
def get_ws_fn(ws_url, headers):
|
||||
async def ws_fn(data, hash_data):
|
||||
async with websockets.connect(
|
||||
async with websockets.connect( # type: ignore
|
||||
ws_url, open_timeout=10, extra_headers=headers
|
||||
) as websocket:
|
||||
return await get_pred_from_ws(websocket, data, hash_data)
|
||||
|
@ -7,7 +7,8 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
import gradio as gr
|
||||
from gradio import encryptor, utils
|
||||
@ -51,9 +52,9 @@ def _get_dataset_features_info(is_new, components):
|
||||
headers.append(component.label + " file")
|
||||
for _component, _type in file_preview_types.items():
|
||||
if isinstance(component, _component):
|
||||
infos["flagged"]["features"][component.label + " file"] = {
|
||||
"_type": _type
|
||||
}
|
||||
infos["flagged"]["features"][
|
||||
(component.label or "") + " file"
|
||||
] = {"_type": _type}
|
||||
break
|
||||
|
||||
headers.append("flag")
|
||||
@ -85,9 +86,9 @@ class FlaggingCallback(ABC):
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
flag_option: str | None = None,
|
||||
flag_index: int | None = None,
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
|
||||
@ -121,7 +122,7 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def setup(self, components: List[IOComponent], flagging_dir: str):
|
||||
def setup(self, components: List[IOComponent], flagging_dir: str | Path):
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
os.makedirs(flagging_dir, exist_ok=True)
|
||||
@ -129,17 +130,17 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
flag_option: str | None = None,
|
||||
flag_index: int | None = None,
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_filepath = os.path.join(flagging_dir, "log.csv")
|
||||
log_filepath = Path(flagging_dir) / "log.csv"
|
||||
|
||||
csv_data = []
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
save_dir = os.path.join(
|
||||
flagging_dir, utils.strip_invalid_filename_characters(component.label)
|
||||
save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
|
||||
component.label or ""
|
||||
)
|
||||
csv_data.append(
|
||||
component.deserialize(
|
||||
@ -178,8 +179,8 @@ class CSVLogger(FlaggingCallback):
|
||||
def setup(
|
||||
self,
|
||||
components: List[IOComponent],
|
||||
flagging_dir: str,
|
||||
encryption_key: Optional[str] = None,
|
||||
flagging_dir: str | Path,
|
||||
encryption_key: bytes | None = None,
|
||||
):
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
@ -189,54 +190,49 @@ class CSVLogger(FlaggingCallback):
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
flag_option: str | None = None,
|
||||
flag_index: int | None = None,
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_filepath = os.path.join(flagging_dir, "log.csv")
|
||||
is_new = not os.path.exists(log_filepath)
|
||||
log_filepath = Path(flagging_dir) / "log.csv"
|
||||
is_new = not Path(log_filepath).exists()
|
||||
headers = [
|
||||
component.label or f"component {idx}"
|
||||
for idx, component in enumerate(self.components)
|
||||
] + [
|
||||
"flag",
|
||||
"username",
|
||||
"timestamp",
|
||||
]
|
||||
|
||||
if flag_index is None:
|
||||
csv_data = []
|
||||
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
||||
save_dir = os.path.join(
|
||||
flagging_dir,
|
||||
utils.strip_invalid_filename_characters(
|
||||
component.label or f"component {idx}"
|
||||
),
|
||||
)
|
||||
if utils.is_update(sample):
|
||||
csv_data.append(str(sample))
|
||||
else:
|
||||
csv_data.append(
|
||||
component.deserialize(
|
||||
sample,
|
||||
save_dir=save_dir,
|
||||
encryption_key=self.encryption_key,
|
||||
)
|
||||
if sample is not None
|
||||
else ""
|
||||
csv_data = []
|
||||
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
||||
save_dir = Path(flagging_dir) / utils.strip_invalid_filename_characters(
|
||||
component.label or f"component {idx}"
|
||||
)
|
||||
if utils.is_update(sample):
|
||||
csv_data.append(str(sample))
|
||||
else:
|
||||
csv_data.append(
|
||||
component.deserialize(
|
||||
sample,
|
||||
save_dir=save_dir,
|
||||
encryption_key=self.encryption_key,
|
||||
)
|
||||
csv_data.append(flag_option if flag_option is not None else "")
|
||||
csv_data.append(username if username is not None else "")
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
if is_new:
|
||||
headers = [
|
||||
component.label or f"component {idx}"
|
||||
for idx, component in enumerate(self.components)
|
||||
] + [
|
||||
"flag",
|
||||
"username",
|
||||
"timestamp",
|
||||
]
|
||||
if sample is not None
|
||||
else ""
|
||||
)
|
||||
csv_data.append(flag_option if flag_option is not None else "")
|
||||
csv_data.append(username if username is not None else "")
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
|
||||
def replace_flag_at_index(file_content):
|
||||
file_content = io.StringIO(file_content)
|
||||
content = list(csv.reader(file_content))
|
||||
def replace_flag_at_index(file_content: str, flag_index: int):
|
||||
file_content_ = io.StringIO(file_content)
|
||||
content = list(csv.reader(file_content_))
|
||||
header = content[0]
|
||||
flag_col_index = header.index("flag")
|
||||
content[flag_index][flag_col_index] = flag_option
|
||||
content[flag_index][flag_col_index] = flag_option # type: ignore
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerows(utils.sanitize_list_for_csv(content))
|
||||
@ -252,7 +248,7 @@ class CSVLogger(FlaggingCallback):
|
||||
)
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
file_content = replace_flag_at_index(file_content, flag_index)
|
||||
output.write(file_content)
|
||||
writer = csv.writer(output)
|
||||
if flag_index is None:
|
||||
@ -273,7 +269,7 @@ class CSVLogger(FlaggingCallback):
|
||||
else:
|
||||
with open(log_filepath, encoding="utf-8") as csvfile:
|
||||
file_content = csvfile.read()
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
file_content = replace_flag_at_index(file_content, flag_index)
|
||||
with open(
|
||||
log_filepath, "w", newline="", encoding="utf-8"
|
||||
) as csvfile: # newline parameter needed for Windows
|
||||
@ -302,7 +298,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
self,
|
||||
hf_token: str,
|
||||
dataset_name: str,
|
||||
organization: Optional[str] = None,
|
||||
organization: str | None = None,
|
||||
private: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -340,28 +336,28 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
|
||||
self.dataset_dir = Path(flagging_dir) / self.dataset_name
|
||||
self.repo = huggingface_hub.Repository(
|
||||
local_dir=self.dataset_dir,
|
||||
local_dir=str(self.dataset_dir),
|
||||
clone_from=path_to_dataset_repo,
|
||||
use_auth_token=self.hf_token,
|
||||
)
|
||||
self.repo.git_pull(lfs=True)
|
||||
|
||||
# Should filename be user-specified?
|
||||
self.log_file = os.path.join(self.dataset_dir, "data.csv")
|
||||
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
|
||||
self.log_file = Path(self.dataset_dir) / "data.csv"
|
||||
self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
|
||||
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
flag_option: str | None = None,
|
||||
flag_index: int | None = None,
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
self.repo.git_pull(lfs=True)
|
||||
|
||||
is_new = not os.path.exists(self.log_file)
|
||||
is_new = not Path(self.log_file).exists()
|
||||
|
||||
with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
@ -378,10 +374,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
# Generate the row corresponding to the flagged sample
|
||||
csv_data = []
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
save_dir = os.path.join(
|
||||
self.dataset_dir,
|
||||
utils.strip_invalid_filename_characters(component.label),
|
||||
)
|
||||
save_dir = Path(
|
||||
self.dataset_dir
|
||||
) / utils.strip_invalid_filename_characters(component.label or "")
|
||||
filepath = component.deserialize(sample, save_dir, None)
|
||||
csv_data.append(filepath)
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
@ -416,7 +411,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
self,
|
||||
hf_foken: str,
|
||||
dataset_name: str,
|
||||
organization: Optional[str] = None,
|
||||
organization: str | None = None,
|
||||
private: bool = False,
|
||||
verbose: bool = True,
|
||||
):
|
||||
@ -463,34 +458,34 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
|
||||
self.components = components
|
||||
self.flagging_dir = flagging_dir
|
||||
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
|
||||
self.dataset_dir = Path(flagging_dir) / self.dataset_name
|
||||
self.repo = huggingface_hub.Repository(
|
||||
local_dir=self.dataset_dir,
|
||||
local_dir=str(self.dataset_dir),
|
||||
clone_from=path_to_dataset_repo,
|
||||
use_auth_token=self.hf_foken,
|
||||
)
|
||||
self.repo.git_pull(lfs=True)
|
||||
|
||||
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
|
||||
self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
|
||||
|
||||
def flag(
|
||||
self,
|
||||
flag_data: List[Any],
|
||||
flag_option: Optional[str] = None,
|
||||
flag_index: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
flag_option: str | None = None,
|
||||
flag_index: int | None = None,
|
||||
username: str | None = None,
|
||||
) -> str:
|
||||
self.repo.git_pull(lfs=True)
|
||||
|
||||
# Generate unique folder for the flagged sample
|
||||
unique_name = self.get_unique_name() # unique name for folder
|
||||
folder_name = os.path.join(
|
||||
self.dataset_dir, unique_name
|
||||
folder_name = (
|
||||
Path(self.dataset_dir) / unique_name
|
||||
) # unique folder for specific example
|
||||
os.makedirs(folder_name)
|
||||
|
||||
# Now uses the existence of `dataset_infos.json` to determine if new
|
||||
is_new = not os.path.exists(self.infos_file)
|
||||
is_new = not Path(self.infos_file).exists()
|
||||
|
||||
# File previews for certain input and output types
|
||||
infos, file_preview_types, _ = _get_dataset_features_info(
|
||||
@ -505,9 +500,10 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
headers.append(component.label)
|
||||
|
||||
try:
|
||||
filepath = component.save_flagged(
|
||||
folder_name, component.label, sample, None
|
||||
save_dir = Path(folder_name) / utils.strip_invalid_filename_characters(
|
||||
component.label or ""
|
||||
)
|
||||
filepath = component.deserialize(sample, save_dir, None)
|
||||
except Exception:
|
||||
# Could not parse 'sample' (mostly) because it was None and `component.save_flagged`
|
||||
# does not handle None cases.
|
||||
@ -515,7 +511,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
filepath = None
|
||||
|
||||
if isinstance(component, tuple(file_preview_types)):
|
||||
headers.append(component.label + " file")
|
||||
headers.append(component.label or "" + " file")
|
||||
|
||||
csv_data.append(
|
||||
"{}/resolve/main/{}/{}".format(
|
||||
@ -533,7 +529,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
metadata_dict = {
|
||||
header: _csv_data for header, _csv_data in zip(headers, csv_data)
|
||||
}
|
||||
self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl"))
|
||||
self.dump_json(metadata_dict, Path(folder_name) / "metadata.jsonl")
|
||||
|
||||
if is_new:
|
||||
json.dump(infos, open(self.infos_file, "w"))
|
||||
@ -545,7 +541,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
id = uuid.uuid4()
|
||||
return str(id)
|
||||
|
||||
def dump_json(self, thing: dict, file_path: str) -> None:
|
||||
def dump_json(self, thing: dict, file_path: str | Path) -> None:
|
||||
with open(file_path, "w+", encoding="utf8") as f:
|
||||
json.dump(thing, f)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
# type: ignore
|
||||
"""
|
||||
This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
|
||||
`InputComponent`, and each class must define a path to its template. All of the subclasses of `InputComponent` are
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Callable, List, Type
|
||||
|
||||
from gradio.blocks import BlockContext
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
@ -30,7 +30,7 @@ class Row(BlockContext):
|
||||
*,
|
||||
variant: str = "default",
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -49,7 +49,7 @@ class Row(BlockContext):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
visible: Optional[bool] = None,
|
||||
visible: bool | None = None,
|
||||
):
|
||||
return {
|
||||
"visible": visible,
|
||||
@ -59,8 +59,8 @@ class Row(BlockContext):
|
||||
def style(
|
||||
self,
|
||||
*,
|
||||
equal_height: Optional[bool] = None,
|
||||
mobile_collapse: Optional[bool] = None,
|
||||
equal_height: bool | None = None,
|
||||
mobile_collapse: bool | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -100,7 +100,7 @@ class Column(BlockContext):
|
||||
min_width: int = 320,
|
||||
variant: str = "default",
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -129,8 +129,8 @@ class Column(BlockContext):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
variant: Optional[str] = None,
|
||||
visible: Optional[bool] = None,
|
||||
variant: str | None = None,
|
||||
visible: bool | None = None,
|
||||
):
|
||||
return {
|
||||
"variant": variant,
|
||||
@ -147,9 +147,9 @@ class Tabs(BlockContext):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
selected: Optional[int | str] = None,
|
||||
selected: int | str | None = None,
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -164,8 +164,9 @@ class Tabs(BlockContext):
|
||||
def get_config(self):
|
||||
return {"selected": self.selected, **super().get_config()}
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
selected: Optional[int | str] = None,
|
||||
selected: int | str | None = None,
|
||||
):
|
||||
return {
|
||||
"selected": selected,
|
||||
@ -183,13 +184,27 @@ class Tabs(BlockContext):
|
||||
self.set_event_trigger("change", fn, inputs, outputs)
|
||||
|
||||
|
||||
class TabItem(BlockContext):
|
||||
@document()
|
||||
class Tab(BlockContext):
|
||||
"""
|
||||
Tab (or its alias TabItem) is a layout element. Components defined within the Tab will be visible when this tab is selected tab.
|
||||
Example:
|
||||
with gradio.Blocks() as demo:
|
||||
with gradio.Tab("Lion"):
|
||||
gr.Image("lion.jpg")
|
||||
gr.Button("New Lion")
|
||||
with gradio.Tab("Tiger"):
|
||||
gr.Image("tiger.jpg")
|
||||
gr.Button("New Tiger")
|
||||
Guides: controlling_layout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str,
|
||||
*,
|
||||
id: Optional[int | str] = None,
|
||||
elem_id: Optional[str] = None,
|
||||
id: int | str | None = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -222,26 +237,11 @@ class TabItem(BlockContext):
|
||||
def get_expected_parent(self) -> Type[Tabs]:
|
||||
return Tabs
|
||||
|
||||
|
||||
@document()
|
||||
class Tab(TabItem):
|
||||
"""
|
||||
Tab is a layout element. Components defined within the Tab will be visible when this tab is selected tab.
|
||||
Example:
|
||||
with gradio.Blocks() as demo:
|
||||
with gradio.Tab("Lion"):
|
||||
gr.Image("lion.jpg")
|
||||
gr.Button("New Lion")
|
||||
with gradio.Tab("Tiger"):
|
||||
gr.Image("tiger.jpg")
|
||||
gr.Button("New Tiger")
|
||||
Guides: controlling_layout
|
||||
"""
|
||||
|
||||
pass
|
||||
def get_block_name(self):
|
||||
return "tabitem"
|
||||
|
||||
|
||||
Tab = TabItem # noqa: F811
|
||||
TabItem = Tab
|
||||
|
||||
|
||||
class Group(BlockContext):
|
||||
@ -258,7 +258,7 @@ class Group(BlockContext):
|
||||
self,
|
||||
*,
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -273,7 +273,7 @@ class Group(BlockContext):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
visible: Optional[bool] = None,
|
||||
visible: bool | None = None,
|
||||
):
|
||||
return {
|
||||
"visible": visible,
|
||||
@ -296,7 +296,7 @@ class Box(BlockContext):
|
||||
self,
|
||||
*,
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -311,7 +311,7 @@ class Box(BlockContext):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
visible: Optional[bool] = None,
|
||||
visible: bool | None = None,
|
||||
):
|
||||
return {
|
||||
"visible": visible,
|
||||
@ -342,7 +342,7 @@ class Accordion(BlockContext):
|
||||
*,
|
||||
open: bool = True,
|
||||
visible: bool = True,
|
||||
elem_id: Optional[str] = None,
|
||||
elem_id: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -365,9 +365,9 @@ class Accordion(BlockContext):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
open: Optional[bool] = None,
|
||||
label: Optional[str] = None,
|
||||
visible: Optional[bool] = None,
|
||||
open: bool | None = None,
|
||||
label: str | None = None,
|
||||
visible: bool | None = None,
|
||||
):
|
||||
return {
|
||||
"visible": visible,
|
||||
|
@ -1,3 +1,4 @@
|
||||
# type: ignore
|
||||
"""
|
||||
This module defines various classes that can serve as the `output` to an interface. Each class must inherit from
|
||||
`OutputComponent`, and each class must define a path to its template. All of the subclasses of `OutputComponent` are
|
||||
|
@ -12,6 +12,7 @@ import tempfile
|
||||
import urllib.request
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -304,8 +305,9 @@ def dict_or_str_to_json_file(jsn, dir=None):
|
||||
return file_obj
|
||||
|
||||
|
||||
def file_to_json(file_path: str) -> Dict:
|
||||
return json.load(open(file_path))
|
||||
def file_to_json(file_path: str | Path) -> Dict:
|
||||
with open(file_path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
class TempFileManager:
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
@ -11,7 +10,7 @@ from gradio import processing_utils, utils
|
||||
class Serializable(ABC):
|
||||
@abstractmethod
|
||||
def serialize(
|
||||
self, x: Any, load_dir: str = "", encryption_key: bytes | None = None
|
||||
self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
|
||||
):
|
||||
"""
|
||||
Convert data from human-readable format to serialized format for a browser.
|
||||
@ -20,7 +19,10 @@ class Serializable(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(
|
||||
x: Any, save_dir: str | None = None, encryption_key: bytes | None = None
|
||||
self,
|
||||
x: Any,
|
||||
save_dir: str | Path | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
):
|
||||
"""
|
||||
Convert data from serialized format for a browser to human-readable format.
|
||||
@ -30,7 +32,7 @@ class Serializable(ABC):
|
||||
|
||||
class SimpleSerializable(Serializable):
|
||||
def serialize(
|
||||
self, x: Any, load_dir: str = "", encryption_key: bytes | None = None
|
||||
self, x: Any, load_dir: str | Path = "", encryption_key: bytes | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op.
|
||||
@ -42,7 +44,10 @@ class SimpleSerializable(Serializable):
|
||||
return x
|
||||
|
||||
def deserialize(
|
||||
self, x: Any, save_dir: str | None = None, encryption_key: bytes | None = None
|
||||
self,
|
||||
x: Any,
|
||||
save_dir: str | Path | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
):
|
||||
"""
|
||||
Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
|
||||
@ -56,8 +61,11 @@ class SimpleSerializable(Serializable):
|
||||
|
||||
class ImgSerializable(Serializable):
|
||||
def serialize(
|
||||
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> str:
|
||||
self,
|
||||
x: str | None,
|
||||
load_dir: str | Path = "",
|
||||
encryption_key: bytes | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Convert from human-friendly version of a file (string filepath) to a seralized
|
||||
representation (base64).
|
||||
@ -69,12 +77,15 @@ class ImgSerializable(Serializable):
|
||||
if x is None or x == "":
|
||||
return None
|
||||
return processing_utils.encode_url_or_file_to_base64(
|
||||
os.path.join(load_dir, x), encryption_key=encryption_key
|
||||
Path(load_dir) / x, encryption_key=encryption_key
|
||||
)
|
||||
|
||||
def deserialize(
|
||||
self, x: str, save_dir: str | None = None, encryption_key: bytes | None = None
|
||||
) -> str:
|
||||
self,
|
||||
x: str | None,
|
||||
save_dir: str | Path | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Convert from serialized representation of a file (base64) to a human-friendly
|
||||
version (string filepath). Optionally, save the file to the directory specified by save_dir
|
||||
@ -93,8 +104,11 @@ class ImgSerializable(Serializable):
|
||||
|
||||
class FileSerializable(Serializable):
|
||||
def serialize(
|
||||
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> Any:
|
||||
self,
|
||||
x: str | None,
|
||||
load_dir: str | Path = "",
|
||||
encryption_key: bytes | None = None,
|
||||
) -> Dict | None:
|
||||
"""
|
||||
Convert from human-friendly version of a file (string filepath) to a
|
||||
seralized representation (base64)
|
||||
@ -105,7 +119,7 @@ class FileSerializable(Serializable):
|
||||
"""
|
||||
if x is None or x == "":
|
||||
return None
|
||||
filename = os.path.join(load_dir, x)
|
||||
filename = Path(load_dir) / x
|
||||
return {
|
||||
"name": filename,
|
||||
"data": processing_utils.encode_url_or_file_to_base64(
|
||||
@ -120,7 +134,7 @@ class FileSerializable(Serializable):
|
||||
x: str | Dict | None,
|
||||
save_dir: Path | str | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
):
|
||||
) -> str | None:
|
||||
"""
|
||||
Convert from serialized representation of a file (base64) to a human-friendly
|
||||
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
|
||||
@ -158,7 +172,10 @@ class FileSerializable(Serializable):
|
||||
|
||||
class JSONSerializable(Serializable):
|
||||
def serialize(
|
||||
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
|
||||
self,
|
||||
x: str | None,
|
||||
load_dir: str | Path = "",
|
||||
encryption_key: bytes | None = None,
|
||||
) -> Dict | None:
|
||||
"""
|
||||
Convert from a a human-friendly version (string path to json file) to a
|
||||
@ -170,14 +187,14 @@ class JSONSerializable(Serializable):
|
||||
"""
|
||||
if x is None or x == "":
|
||||
return None
|
||||
return processing_utils.file_to_json(os.path.join(load_dir, x))
|
||||
return processing_utils.file_to_json(Path(load_dir) / x)
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: str | Dict,
|
||||
save_dir: str | None = None,
|
||||
save_dir: str | Path | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
) -> str:
|
||||
) -> str | None:
|
||||
"""
|
||||
Convert from serialized representation (json string) to a human-friendly
|
||||
version (string path to json file). Optionally, save the file to the directory specified by `save_dir`
|
||||
|
@ -634,7 +634,7 @@ class AsyncRequest:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_directory(path: Path):
|
||||
def set_directory(path: Path | str):
|
||||
"""Context manager that sets the working directory to the given path."""
|
||||
origin = Path().absolute()
|
||||
try:
|
||||
@ -673,9 +673,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number:
|
||||
return value
|
||||
|
||||
|
||||
def sanitize_list_for_csv(
|
||||
values: List[str | Number] | List[List[str | Number]],
|
||||
) -> List[str | Number] | List[List[str | Number]]:
|
||||
def sanitize_list_for_csv(values: T) -> T:
|
||||
"""
|
||||
Sanitizes a list of values (or a list of list of values) that is being written to a
|
||||
CSV file to prevent CSV injection attacks.
|
||||
|
@ -6,4 +6,4 @@ pip_required
|
||||
pip install --upgrade pip
|
||||
pip install pyright
|
||||
cd gradio
|
||||
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py
|
||||
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py
|
||||
|
Loading…
x
Reference in New Issue
Block a user