More Ruff rules (#4038)

* Bump ruff to 0.0.264

* Enable Ruff Naming rules and fix most errors

* Move `clean_html` to utils (to fix an N lint error)

* Changelog

* Clean up possibly leaking file handles

* Enable and autofix Ruff SIM

* Fix remaining Ruff SIMs

* Enable and autofix Ruff UP issues

* Fix misordered import from #4048

* Fix bare except from #4048

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Aarni Koskela 2023-05-05 05:54:23 +03:00 committed by GitHub
parent 71f1e654ab
commit d1853625fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
44 changed files with 694 additions and 710 deletions

View File

@ -120,6 +120,7 @@ No changes to highlight.
- CI: Simplified Python CI workflow by [@akx](https://github.com/akx) in [PR 3982](https://github.com/gradio-app/gradio/pull/3982)
- Upgrade pyright to 1.1.305 by [@akx](https://github.com/akx) in [PR 4042](https://github.com/gradio-app/gradio/pull/4042)
- More Ruff rules are enabled and lint errors fixed by [@akx](https://github.com/akx) in [PR 4038](https://github.com/gradio-app/gradio/pull/4038)
## Breaking Changes:

View File

@ -13,7 +13,7 @@ from concurrent.futures import Future, TimeoutError
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable
import huggingface_hub
import requests
@ -132,7 +132,7 @@ class Client:
hf_token: str | None = None,
private: bool = True,
hardware: str | None = None,
secrets: Dict[str, str] | None = None,
secrets: dict[str, str] | None = None,
sleep_timeout: int = 5,
max_workers: int = 40,
verbose: bool = True,
@ -216,7 +216,7 @@ class Client:
current_info.hardware or huggingface_hub.SpaceHardware.CPU_BASIC
)
hardware = hardware or original_info.hardware
if not current_hardware == hardware:
if current_hardware != hardware:
huggingface_hub.request_space_hardware(space_id, hardware) # type: ignore
print(
f"-------\nNOTE: this Space uses upgraded hardware: {hardware}... see billing info at https://huggingface.co/settings/billing\n-------"
@ -262,7 +262,7 @@ class Client:
*args,
api_name: str | None = None,
fn_index: int | None = None,
result_callbacks: Callable | List[Callable] | None = None,
result_callbacks: Callable | list[Callable] | None = None,
) -> Job:
"""
Creates and returns a Job object which calls the Gradio API in a background thread. The job can be used to retrieve the status and result of the remote API call.
@ -323,7 +323,7 @@ class Client:
all_endpoints: bool | None = None,
print_info: bool = True,
return_format: Literal["dict", "str"] | None = None,
) -> Dict | str | None:
) -> dict | str | None:
"""
Prints the usage info for the API. If the Gradio app has multiple API endpoints, the usage info for each endpoint will be printed separately. If return_format="dict" the info is returned in dictionary format, as shown in the example below.
@ -449,7 +449,7 @@ class Client:
def _render_endpoints_info(
self,
name_or_index: str | int,
endpoints_info: Dict[str, List[Dict[str, str]]],
endpoints_info: dict[str, list[dict[str, str]]],
) -> str:
parameter_names = [p["label"] for p in endpoints_info["parameters"]]
parameter_names = [utils.sanitize_parameter_names(p) for p in parameter_names]
@ -542,7 +542,7 @@ class Client:
def _space_name_to_src(self, space) -> str | None:
return huggingface_hub.space_info(space, token=self.hf_token).host # type: ignore
def _get_config(self) -> Dict:
def _get_config(self) -> dict:
r = requests.get(
urllib.parse.urljoin(self.src, utils.CONFIG_URL), headers=self.headers
)
@ -568,7 +568,7 @@ class Client:
class Endpoint:
"""Helper class for storing all the information about a single API endpoint."""
def __init__(self, client: Client, fn_index: int, dependency: Dict):
def __init__(self, client: Client, fn_index: int, dependency: dict):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
@ -615,7 +615,7 @@ class Endpoint:
return _inner
def make_predict(self, helper: Communicator | None = None):
def _predict(*data) -> Tuple:
def _predict(*data) -> tuple:
data = json.dumps(
{
"data": data,
@ -669,8 +669,8 @@ class Endpoint:
return outputs
def _upload(
self, file_paths: List[str | List[str]]
) -> List[str | List[str]] | List[Dict[str, Any] | List[Dict[str, Any]]]:
self, file_paths: list[str | list[str]]
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
if not file_paths:
return []
# Put all the filepaths in one file
@ -683,7 +683,7 @@ class Endpoint:
if not isinstance(fs, list):
fs = [fs]
for f in fs:
files.append(("files", (Path(f).name, open(f, "rb"))))
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
indices.append(i)
r = requests.post(
self.client.upload_url, headers=self.client.headers, files=files
@ -718,8 +718,8 @@ class Endpoint:
def _add_uploaded_files_to_data(
self,
files: List[str | List[str]] | List[Dict[str, Any] | List[Dict[str, Any]]],
data: List[Any],
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
data: list[Any],
) -> None:
"""Helper function to modify the input data with the uploaded files."""
file_counter = 0
@ -728,7 +728,7 @@ class Endpoint:
data[i] = files[file_counter]
file_counter += 1
def serialize(self, *data) -> Tuple:
def serialize(self, *data) -> tuple:
data = list(data)
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type == utils.STATE_COMPONENT:
@ -748,7 +748,7 @@ class Endpoint:
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o
def deserialize(self, *data) -> Tuple | Any:
def deserialize(self, *data) -> tuple | Any:
assert len(data) == len(
self.deserializers
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
@ -758,7 +758,7 @@ class Endpoint:
for s, d, oct in zip(
self.deserializers, data, self.output_component_types
)
if not oct == utils.STATE_COMPONENT
if oct != utils.STATE_COMPONENT
]
)
if (
@ -766,7 +766,7 @@ class Endpoint:
[
oct
for oct in self.output_component_types
if not oct == utils.STATE_COMPONENT
if oct != utils.STATE_COMPONENT
]
)
== 1
@ -776,7 +776,7 @@ class Endpoint:
output = outputs
return output
def _setup_serializers(self) -> Tuple[List[Serializable], List[Serializable]]:
def _setup_serializers(self) -> tuple[list[Serializable], list[Serializable]]:
inputs = self.dependency["inputs"]
serializers = []
@ -820,7 +820,7 @@ class Endpoint:
return serializers, deserializers
def _use_websocket(self, dependency: Dict) -> bool:
def _use_websocket(self, dependency: dict) -> bool:
queue_enabled = self.client.config.get("enable_queue", False)
queue_uses_websocket = version.parse(
self.client.config.get("version", "2.0")
@ -873,7 +873,7 @@ class Job(Future):
def __iter__(self) -> Job:
return self
def __next__(self) -> Tuple | Any:
def __next__(self) -> tuple | Any:
if not self.communicator:
raise StopIteration()
@ -925,7 +925,7 @@ class Job(Future):
else:
return super().result(timeout=timeout)
def outputs(self) -> List[Tuple | Any]:
def outputs(self) -> list[tuple | Any]:
"""
Returns a list containing the latest outputs from the Job.

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import inspect
from typing import Callable, Dict, List, Tuple
from typing import Callable
classes_to_document = {}
classes_inherit_documentation = {}
@ -61,7 +61,7 @@ def document(*fns, inherit=False):
return inner_doc
def document_fn(fn: Callable, cls) -> Tuple[str, List[Dict], Dict, str | None]:
def document_fn(fn: Callable, cls) -> tuple[str, list[dict], dict, str | None]:
"""
Generates documentation for any function.
Parameters:

View File

@ -4,21 +4,21 @@ import json
import os
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any
from gradio_client import media_data, utils
from gradio_client.data_classes import FileData
class Serializable:
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
"""
The typing information for this component as a dictionary whose values are a list of 2 strings: [Python type, language-agnostic description].
Keys of the dictionary are: raw_input, raw_output, serialized_input, serialized_output
"""
raise NotImplementedError()
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
"""
The example inputs for this component as a dictionary whose values are example inputs compatible with this component.
Keys of the dictionary are: raw, serialized
@ -26,12 +26,12 @@ class Serializable:
raise NotImplementedError()
# For backwards compatibility
def input_api_info(self) -> Tuple[str, str]:
def input_api_info(self) -> tuple[str, str]:
api_info = self.api_info()
return (api_info["serialized_input"][0], api_info["serialized_input"][1])
# For backwards compatibility
def output_api_info(self) -> Tuple[str, str]:
def output_api_info(self) -> tuple[str, str]:
api_info = self.api_info()
return (api_info["serialized_output"][0], api_info["serialized_output"][1])
@ -57,7 +57,7 @@ class Serializable:
class SimpleSerializable(Serializable):
"""General class that does not perform any serialization or deserialization."""
def api_info(self) -> Dict[str, str | List[str]]:
def api_info(self) -> dict[str, str | list[str]]:
return {
"raw_input": ["Any", ""],
"raw_output": ["Any", ""],
@ -65,7 +65,7 @@ class SimpleSerializable(Serializable):
"serialized_output": ["Any", ""],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": None,
"serialized": None,
@ -75,7 +75,7 @@ class SimpleSerializable(Serializable):
class StringSerializable(Serializable):
"""Expects a string as input/output but performs no serialization."""
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": ["str", "string value"],
"raw_output": ["str", "string value"],
@ -83,7 +83,7 @@ class StringSerializable(Serializable):
"serialized_output": ["str", "string value"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": "Howdy!",
"serialized": "Howdy!",
@ -93,7 +93,7 @@ class StringSerializable(Serializable):
class ListStringSerializable(Serializable):
"""Expects a list of strings as input/output but performs no serialization."""
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": ["List[str]", "list of string values"],
"raw_output": ["List[str]", "list of string values"],
@ -101,7 +101,7 @@ class ListStringSerializable(Serializable):
"serialized_output": ["List[str]", "list of string values"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": ["Howdy!", "Merhaba"],
"serialized": ["Howdy!", "Merhaba"],
@ -111,7 +111,7 @@ class ListStringSerializable(Serializable):
class BooleanSerializable(Serializable):
"""Expects a boolean as input/output but performs no serialization."""
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": ["bool", "boolean value"],
"raw_output": ["bool", "boolean value"],
@ -119,7 +119,7 @@ class BooleanSerializable(Serializable):
"serialized_output": ["bool", "boolean value"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": True,
"serialized": True,
@ -129,7 +129,7 @@ class BooleanSerializable(Serializable):
class NumberSerializable(Serializable):
"""Expects a number (int/float) as input/output but performs no serialization."""
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": ["int | float", "numeric value"],
"raw_output": ["int | float", "numeric value"],
@ -137,7 +137,7 @@ class NumberSerializable(Serializable):
"serialized_output": ["int | float", "numeric value"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": 5,
"serialized": 5,
@ -147,7 +147,7 @@ class NumberSerializable(Serializable):
class ImgSerializable(Serializable):
"""Expects a base64 string as input/output which is serialized to a filepath."""
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": ["str", "base64 representation of image"],
"raw_output": ["str", "base64 representation of image"],
@ -155,7 +155,7 @@ class ImgSerializable(Serializable):
"serialized_output": ["str", "filepath or URL to image"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": media_data.BASE64_IMAGE,
"serialized": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
@ -204,7 +204,7 @@ class ImgSerializable(Serializable):
class FileSerializable(Serializable):
"""Expects a dict with base64 representation of object as input/output which is serialized to a filepath."""
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": [
"str | Dict",
@ -218,7 +218,7 @@ class FileSerializable(Serializable):
"serialized_output": ["str", "filepath or URL to file"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": {"is_file": False, "data": media_data.BASE64_FILE},
"serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf",
@ -280,9 +280,9 @@ class FileSerializable(Serializable):
def serialize(
self,
x: str | FileData | None | List[str | FileData | None],
x: str | FileData | None | list[str | FileData | None],
load_dir: str | Path = "",
) -> FileData | None | List[FileData | None]:
) -> FileData | None | list[FileData | None]:
"""
Convert from human-friendly version of a file (string filepath) to a
seralized representation (base64)
@ -299,11 +299,11 @@ class FileSerializable(Serializable):
def deserialize(
self,
x: str | FileData | None | List[str | FileData | None],
x: str | FileData | None | list[str | FileData | None],
save_dir: Path | str | None = None,
root_url: str | None = None,
hf_token: str | None = None,
) -> str | None | List[str | None]:
) -> str | None | list[str | None]:
"""
Convert from serialized representation of a file (base64) to a human-friendly
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
@ -331,7 +331,7 @@ class FileSerializable(Serializable):
class VideoSerializable(FileSerializable):
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": [
"str | Dict",
@ -345,7 +345,7 @@ class VideoSerializable(FileSerializable):
"serialized_output": ["str", "filepath or URL to file"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": {"is_file": False, "data": media_data.BASE64_VIDEO},
"serialized": "https://github.com/gradio-app/gradio/raw/main/test/test_files/video_sample.mp4",
@ -353,16 +353,16 @@ class VideoSerializable(FileSerializable):
def serialize(
self, x: str | None, load_dir: str | Path = ""
) -> Tuple[FileData | None, None]:
) -> tuple[FileData | None, None]:
return (super().serialize(x, load_dir), None) # type: ignore
def deserialize(
self,
x: Tuple[FileData | None, FileData | None] | None,
x: tuple[FileData | None, FileData | None] | None,
save_dir: Path | str | None = None,
root_url: str | None = None,
hf_token: str | None = None,
) -> str | Tuple[str | None, str | None] | None:
) -> str | tuple[str | None, str | None] | None:
"""
Convert from serialized representation of a file (base64) to a human-friendly
version (string filepath). Optionally, save the file to the directory specified by `save_dir`
@ -378,7 +378,7 @@ class VideoSerializable(FileSerializable):
class JSONSerializable(Serializable):
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": ["str | Dict | List", "JSON-serializable object or a string"],
"raw_output": ["Dict | List", "dictionary- or list-like object"],
@ -386,7 +386,7 @@ class JSONSerializable(Serializable):
"serialized_output": ["str", "filepath to JSON file"],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": {"a": 1, "b": 2},
"serialized": None,
@ -396,7 +396,7 @@ class JSONSerializable(Serializable):
self,
x: str | None,
load_dir: str | Path = "",
) -> Dict | List | None:
) -> dict | list | None:
"""
Convert from a a human-friendly version (string path to json file) to a
serialized representation (json string)
@ -410,7 +410,7 @@ class JSONSerializable(Serializable):
def deserialize(
self,
x: str | Dict | List,
x: str | dict | list,
save_dir: str | Path | None = None,
root_url: str | None = None,
hf_token: str | None = None,
@ -430,7 +430,7 @@ class JSONSerializable(Serializable):
class GallerySerializable(Serializable):
def api_info(self) -> Dict[str, List[str]]:
def api_info(self) -> dict[str, list[str]]:
return {
"raw_input": [
"List[List[str | None]]",
@ -450,7 +450,7 @@ class GallerySerializable(Serializable):
],
}
def example_inputs(self) -> Dict[str, Any]:
def example_inputs(self) -> dict[str, Any]:
return {
"raw": [media_data.BASE64_IMAGE] * 2,
"serialized": [
@ -461,7 +461,7 @@ class GallerySerializable(Serializable):
def serialize(
self, x: str | None, load_dir: str | Path = ""
) -> List[List[str | None]] | None:
) -> list[list[str | None]] | None:
if x is None or x == "":
return None
files = []
@ -475,7 +475,7 @@ class GallerySerializable(Serializable):
def deserialize(
self,
x: List[List[str | None]] | None,
x: list[list[str | None]] | None,
save_dir: str = "",
root_url: str | None = None,
hf_token: str | None = None,
@ -486,7 +486,7 @@ class GallerySerializable(Serializable):
gallery_path.mkdir(exist_ok=True, parents=True)
captions = {}
for img_data in x:
if isinstance(img_data, list) or isinstance(img_data, tuple):
if isinstance(img_data, (list, tuple)):
img_data, caption = img_data
else:
caption = None
@ -510,7 +510,7 @@ SERIALIZER_MAPPING["Serializable"] = SimpleSerializable
SERIALIZER_MAPPING["File"] = FileSerializable
SERIALIZER_MAPPING["UploadButton"] = FileSerializable
COMPONENT_MAPPING: Dict[str, type] = {
COMPONENT_MAPPING: dict[str, type] = {
"textbox": StringSerializable,
"number": NumberSerializable,
"slider": NumberSerializable,

View File

@ -14,7 +14,7 @@ from datetime import datetime
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import fsspec.asyn
import httpx
@ -84,7 +84,7 @@ class Status(Enum):
CANCELLED = "CANCELLED"
@staticmethod
def ordering(status: "Status") -> int:
def ordering(status: Status) -> int:
"""Order of messages. Helpful for testing."""
order = [
Status.STARTING,
@ -100,11 +100,11 @@ class Status(Enum):
]
return order.index(status)
def __lt__(self, other: "Status"):
def __lt__(self, other: Status):
return self.ordering(self) < self.ordering(other)
@staticmethod
def msg_to_status(msg: str) -> "Status":
def msg_to_status(msg: str) -> Status:
"""Map the raw message from the backend to the status code presented to users."""
return {
"send_hash": Status.JOINING_QUEUE,
@ -127,7 +127,7 @@ class ProgressUnit:
desc: Optional[str]
@classmethod
def from_ws_msg(cls, data: List[Dict]) -> List["ProgressUnit"]:
def from_ws_msg(cls, data: list[dict]) -> list[ProgressUnit]:
return [
cls(
index=d.get("index"),
@ -150,7 +150,7 @@ class StatusUpdate:
eta: float | None
success: bool | None
time: datetime | None
progress_data: List[ProgressUnit] | None
progress_data: list[ProgressUnit] | None
def create_initial_status_update():
@ -173,7 +173,7 @@ class JobStatus:
"""
latest_status: StatusUpdate = field(default_factory=create_initial_status_update)
outputs: List[Any] = field(default_factory=list)
outputs: list[Any] = field(default_factory=list)
@dataclass
@ -182,7 +182,7 @@ class Communicator:
lock: Lock
job: JobStatus
deserialize: Callable[..., Tuple]
deserialize: Callable[..., tuple]
reset_url: str
should_cancel: bool = False
@ -208,7 +208,7 @@ async def get_pred_from_ws(
data: str,
hash_data: str,
helper: Communicator | None = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
completed = False
resp = {}
while not completed:
@ -285,9 +285,10 @@ def download_tmp_copy_of_file(
suffix=suffix,
dir=dir,
)
with requests.get(url_path, headers=headers, stream=True) as r:
with open(file_obj.name, "wb") as f:
shutil.copyfileobj(r.raw, f)
with requests.get(url_path, headers=headers, stream=True) as r, open(
file_obj.name, "wb"
) as f:
shutil.copyfileobj(r.raw, f)
return file_obj
@ -360,7 +361,7 @@ def encode_url_or_file_to_base64(path: str | Path):
return encode_file_to_base64(path)
def decode_base64_to_binary(encoding: str) -> Tuple[bytes, str | None]:
def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]:
extension = get_extension(encoding)
data = encoding.rsplit(",", 1)[-1]
return base64.b64decode(data), extension
@ -421,7 +422,7 @@ def decode_base64_to_file(
return file_obj
def dict_or_str_to_json_file(jsn: str | Dict | List, dir: str | Path | None = None):
def dict_or_str_to_json_file(jsn: str | dict | list, dir: str | Path | None = None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
@ -435,7 +436,7 @@ def dict_or_str_to_json_file(jsn: str | Dict | List, dir: str | Path | None = No
return file_obj
def file_to_json(file_path: str | Path) -> Dict | List:
def file_to_json(file_path: str | Path) -> dict | list:
with open(file_path) as f:
return json.load(f)

View File

@ -1,6 +1,6 @@
black==22.6.0
pytest-asyncio
pytest==7.1.2
ruff==0.0.263
ruff==0.0.264
pyright==1.1.305
gradio

View File

@ -48,8 +48,8 @@ class TestPredictionsFromSpaces:
@pytest.mark.flaky
def test_numerical_to_label_space(self):
client = Client("gradio-tests/titanic-survival")
output = client.predict("male", 77, 10, api_name="/predict")
assert json.load(open(output))["label"] == "Perishes"
with open(client.predict("male", 77, 10, api_name="/predict")) as f:
assert json.load(f)["label"] == "Perishes"
with pytest.raises(
ValueError,
match="This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`",
@ -258,7 +258,8 @@ class TestPredictionsFromSpaces:
f.write("Hello from private space!")
output = client.submit(1, "foo", f.name, api_name="/file_upload").result()
assert open(output).read() == "Hello from private space!"
with open(output) as f:
assert f.read() == "Hello from private space!"
upload.assert_called_once()
with patch.object(
@ -267,24 +268,28 @@ class TestPredictionsFromSpaces:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("Hello from private space!")
output = client.submit(f.name, api_name="/upload_btn").result()
assert open(output).read() == "Hello from private space!"
with open(client.submit(f.name, api_name="/upload_btn").result()) as f:
assert f.read() == "Hello from private space!"
upload.assert_called_once()
with patch.object(
client.endpoints[2], "_upload", wraps=client.endpoints[0]._upload
) as upload:
# `delete=False` is required for Windows compat
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2:
f1.write("File1")
f2.write("File2")
output = client.submit(
3, [f1.name, f2.name], "hello", api_name="/upload_multiple"
r1, r2 = client.submit(
3,
[f1.name, f2.name],
"hello",
api_name="/upload_multiple",
).result()
assert open(output[0]).read() == "File1"
assert open(output[1]).read() == "File2"
with open(r1) as f:
assert f.read() == "File1"
with open(r2) as f:
assert f.read() == "File2"
upload.assert_called_once()
@pytest.mark.flaky
@ -588,10 +593,8 @@ class TestAPIInfo:
def test_serializable_in_mapping(self, calculator_demo):
with connect(calculator_demo) as client:
assert all(
[
isinstance(c, SimpleSerializable)
for c in client.endpoints[0].serializers
]
isinstance(c, SimpleSerializable)
for c in client.endpoints[0].serializers
)
@pytest.mark.flaky

View File

@ -34,8 +34,10 @@ def test_file_serializing():
assert serializing.serialize(output) == output
files = serializing.deserialize(output)
assert open(files[0]).read() == "Hello World!"
assert open(files[1]).read() == "Greetings!"
with open(files[0]) as f:
assert f.read() == "Hello World!"
with open(files[1]) as f:
assert f.read() == "Greetings!"
finally:
os.remove(f1.name)
os.remove(f2.name)

View File

@ -65,7 +65,7 @@ from gradio.flagging import (
SimpleCSVLogger,
)
from gradio.helpers import EventData, Progress, make_waveform, skip, update
from gradio.helpers import create_examples as Examples
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.interface import Interface, TabbedInterface, close_all
from gradio.ipython_ext import load_ipython_extension
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs

View File

@ -12,7 +12,7 @@ import warnings
import webbrowser
from abc import abstractmethod
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Iterator
import anyio
import requests
@ -34,7 +34,7 @@ from gradio import (
)
from gradio.context import Context
from gradio.deprecation import check_deprecated_parameters
from gradio.exceptions import DuplicateBlockError, InvalidApiName
from gradio.exceptions import DuplicateBlockError, InvalidApiNameError
from gradio.helpers import EventData, create_tracker, skip, special_args
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass as Theme
@ -56,7 +56,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.components import Component
BUILT_IN_THEMES: Dict[str, Theme] = {
BUILT_IN_THEMES: dict[str, Theme] = {
t.name: t
for t in [
themes.Base(),
@ -74,7 +74,7 @@ class Block:
*,
render: bool = True,
elem_id: str | None = None,
elem_classes: List[str] | str | None = None,
elem_classes: list[str] | str | None = None,
visible: bool = True,
root_url: str | None = None, # URL that is prepended to all file paths
_skip_init_processing: bool = False, # Used for loading from Spaces
@ -145,15 +145,15 @@ class Block:
else self.__class__.__name__.lower()
)
def get_expected_parent(self) -> Type[BlockContext] | None:
def get_expected_parent(self) -> type[BlockContext] | None:
return None
def set_event_trigger(
self,
event_name: str,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None,
outputs: Component | List[Component] | None,
inputs: Component | list[Component] | set[Component] | None,
outputs: Component | list[Component] | None,
preprocess: bool = True,
postprocess: bool = True,
scroll_to_output: bool = False,
@ -164,12 +164,12 @@ class Block:
queue: bool | None = None,
batch: bool = False,
max_batch_size: int = 4,
cancels: List[int] | None = None,
cancels: list[int] | None = None,
every: float | None = None,
collects_event_data: bool | None = None,
trigger_after: int | None = None,
trigger_only_on_success: bool = False,
) -> Tuple[Dict[str, Any], int]:
) -> tuple[dict[str, Any], int]:
"""
Adds an event to the component's dependencies.
Parameters:
@ -251,7 +251,7 @@ class Block:
api_name_ = utils.append_unique_suffix(
api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
)
if not (api_name == api_name_):
if api_name != api_name_:
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
api_name = api_name_
@ -295,11 +295,11 @@ class Block:
@staticmethod
@abstractmethod
def update(**kwargs) -> Dict:
def update(**kwargs) -> dict:
return {}
@classmethod
def get_specific_update(cls, generic_update: Dict[str, Any]) -> Dict:
def get_specific_update(cls, generic_update: dict[str, Any]) -> dict:
generic_update = generic_update.copy()
del generic_update["__type__"]
specific_update = cls.update(**generic_update)
@ -318,7 +318,7 @@ class BlockContext(Block):
visible: If False, this will be hidden but included in the Blocks config file (its visibility can later be updated).
render: If False, this will not be included in the Blocks config file at all.
"""
self.children: List[Block] = []
self.children: list[Block] = []
Block.__init__(self, visible=visible, render=render, **kwargs)
def __enter__(self):
@ -368,8 +368,8 @@ class BlockFunction:
def __init__(
self,
fn: Callable | None,
inputs: List[Component],
outputs: List[Component],
inputs: list[Component],
outputs: list[Component],
preprocess: bool,
postprocess: bool,
inputs_as_dict: bool,
@ -399,13 +399,13 @@ class BlockFunction:
return str(self)
class class_or_instancemethod(classmethod):
class class_or_instancemethod(classmethod): # noqa: N801
def __get__(self, instance, type_):
descr_get = super().__get__ if instance is None else self.__func__.__get__
return descr_get(instance, type_)
def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool = True):
def postprocess_update_dict(block: Block, update_dict: dict, postprocess: bool = True):
"""
Converts a dictionary of updates into a format that can be sent to the frontend.
E.g. {"__type__": "generic_update", "value": "2", "interactive": False}
@ -433,15 +433,15 @@ def postprocess_update_dict(block: Block, update_dict: Dict, postprocess: bool =
def convert_component_dict_to_list(
outputs_ids: List[int], predictions: Dict
) -> List | Dict:
outputs_ids: list[int], predictions: dict
) -> list | dict:
"""
Converts a dictionary of component updates into a list of updates in the order of
the outputs_ids and including every output component. Leaves other types of dictionaries unchanged.
E.g. {"textbox": "hello", "number": {"__type__": "generic_update", "value": "2"}}
Into -> ["hello", {"__type__": "generic_update"}, {"__type__": "generic_update", "value": "2"}]
"""
keys_are_blocks = [isinstance(key, Block) for key in predictions.keys()]
keys_are_blocks = [isinstance(key, Block) for key in predictions]
if all(keys_are_blocks):
reordered_predictions = [skip() for _ in outputs_ids]
for component, value in predictions.items():
@ -459,7 +459,7 @@ def convert_component_dict_to_list(
return predictions
def get_api_info(config: Dict, serialize: bool = True):
def get_api_info(config: dict, serialize: bool = True):
"""
Gets the information needed to generate the API docs from a Blocks config.
Parameters:
@ -662,8 +662,8 @@ class Blocks(BlockContext):
if not self.analytics_enabled:
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
super().__init__(render=False, **kwargs)
self.blocks: Dict[int, Block] = {}
self.fns: List[BlockFunction] = []
self.blocks: dict[int, Block] = {}
self.fns: list[BlockFunction] = []
self.dependencies = []
self.mode = mode
@ -674,7 +674,7 @@ class Blocks(BlockContext):
self.height = None
self.api_open = True
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
self.is_space = os.getenv("SYSTEM") == "spaces"
self.favicon_path = None
self.auth = None
self.dev_mode = True
@ -713,7 +713,7 @@ class Blocks(BlockContext):
def from_config(
cls,
config: dict,
fns: List[Callable],
fns: list[Callable],
root_url: str | None = None,
) -> Blocks:
"""
@ -727,7 +727,7 @@ class Blocks(BlockContext):
config = copy.deepcopy(config)
components_config = config["components"]
theme = config.get("theme", "default")
original_mapping: Dict[int, Block] = {}
original_mapping: dict[int, Block] = {}
def get_block_instance(id: int) -> Block:
for block_config in components_config:
@ -862,7 +862,7 @@ class Blocks(BlockContext):
api_name,
[dep["api_name"] for dep in Context.root_block.dependencies],
)
if not (api_name == api_name_):
if api_name != api_name_:
warnings.warn(
f"api_name {api_name} already exists, using {api_name_}"
)
@ -937,7 +937,9 @@ class Blocks(BlockContext):
None,
)
if inferred_fn_index is None:
raise InvalidApiName(f"Cannot find a function with api_name {api_name}")
raise InvalidApiNameError(
f"Cannot find a function with api_name {api_name}"
)
fn_index = inferred_fn_index
if not (self.is_callable(fn_index)):
raise ValueError(
@ -970,9 +972,9 @@ class Blocks(BlockContext):
async def call_function(
self,
fn_index: int,
processed_input: List[Any],
processed_input: list[Any],
iterator: Iterator[Any] | None = None,
requests: routes.Request | List[routes.Request] | None = None,
requests: routes.Request | list[routes.Request] | None = None,
event_id: str | None = None,
event_data: EventData | None = None,
):
@ -993,10 +995,7 @@ class Blocks(BlockContext):
if block_fn.inputs_as_dict:
processed_input = [dict(zip(block_fn.inputs, processed_input))]
if isinstance(requests, list):
request = requests[0]
else:
request = requests
request = requests[0] if isinstance(requests, list) else requests
processed_input, progress_index, _ = special_args(
block_fn.fn, processed_input, request, event_data
)
@ -1054,7 +1053,7 @@ class Blocks(BlockContext):
"iterator": iterator,
}
def serialize_data(self, fn_index: int, inputs: List[Any]) -> List[Any]:
def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
dependency = self.dependencies[fn_index]
processed_input = []
@ -1068,7 +1067,7 @@ class Blocks(BlockContext):
return processed_input
def deserialize_data(self, fn_index: int, outputs: List[Any]) -> List[Any]:
def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
dependency = self.dependencies[fn_index]
predictions = []
@ -1084,7 +1083,7 @@ class Blocks(BlockContext):
return predictions
def validate_inputs(self, fn_index: int, inputs: List[Any]):
def validate_inputs(self, fn_index: int, inputs: list[Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
@ -1106,10 +1105,7 @@ class Blocks(BlockContext):
block = self.blocks[input_id]
wanted_args.append(str(block))
for inp in inputs:
if isinstance(inp, str):
v = f'"{inp}"'
else:
v = str(inp)
v = f'"{inp}"' if isinstance(inp, str) else str(inp)
received_args.append(v)
wanted = ", ".join(wanted_args)
@ -1125,7 +1121,7 @@ Received inputs:
[{received}]"""
)
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
def preprocess_data(self, fn_index: int, inputs: list[Any], state: dict[int, Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
@ -1146,7 +1142,7 @@ Received inputs:
processed_input = inputs
return processed_input
def validate_outputs(self, fn_index: int, predictions: Any | List[Any]):
def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
@ -1168,10 +1164,7 @@ Received inputs:
block = self.blocks[output_id]
wanted_args.append(str(block))
for pred in predictions:
if isinstance(pred, str):
v = f'"{pred}"'
else:
v = str(pred)
v = f'"{pred}"' if isinstance(pred, str) else str(pred)
received_args.append(v)
wanted = ", ".join(wanted_args)
@ -1186,7 +1179,7 @@ Received outputs:
)
def postprocess_data(
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
self, fn_index: int, predictions: list | dict, state: dict[int, Any]
):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
@ -1241,13 +1234,13 @@ Received outputs:
async def process_api(
self,
fn_index: int,
inputs: List[Any],
state: Dict[int, Any],
request: routes.Request | List[routes.Request] | None = None,
iterators: Dict[int, Any] | None = None,
inputs: list[Any],
state: dict[int, Any],
request: routes.Request | list[routes.Request] | None = None,
iterators: dict[int, Any] | None = None,
event_id: str | None = None,
event_data: EventData | None = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Processes API calls from the frontend. First preprocesses the data,
then runs the relevant function, then postprocesses the output.
@ -1342,15 +1335,15 @@ Received outputs:
"theme": self.theme.name,
}
def getLayout(block):
def get_layout(block):
if not isinstance(block, BlockContext):
return {"id": block._id}
children_layout = []
for child in block.children:
children_layout.append(getLayout(child))
children_layout.append(get_layout(child))
return {"id": block._id, "children": children_layout}
config["layout"] = getLayout(self)
config["layout"] = get_layout(self)
for _id, block in self.blocks.items():
props = block.get_config() if hasattr(block, "get_config") else {}
@ -1393,10 +1386,10 @@ Received outputs:
@class_or_instancemethod
def load(
self_or_cls,
self_or_cls, # noqa: N805
fn: Callable | None = None,
inputs: List[Component] | None = None,
outputs: List[Component] | None = None,
inputs: list[Component] | None = None,
outputs: list[Component] | None = None,
api_name: str | None = None,
scroll_to_output: bool = False,
show_progress: bool = True,
@ -1413,7 +1406,7 @@ Received outputs:
api_key: str | None = None,
alias: str | None = None,
**kwargs,
) -> Blocks | Dict[str, Any] | None:
) -> Blocks | dict[str, Any] | None:
"""
For reverse compatibility reasons, this is both a class method and an instance
method, the two of which, confusingly, do two completely different things.
@ -1571,7 +1564,7 @@ Received outputs:
debug: bool = False,
enable_queue: bool | None = None,
max_threads: int = 40,
auth: Callable | Tuple[str, str] | List[Tuple[str, str]] | None = None,
auth: Callable | tuple[str, str] | list[tuple[str, str]] | None = None,
auth_message: str | None = None,
prevent_thread_lock: bool = False,
show_error: bool = False,
@ -1588,11 +1581,11 @@ Received outputs:
ssl_verify: bool = True,
quiet: bool = False,
show_api: bool = True,
file_directories: List[str] | None = None,
allowed_paths: List[str] | None = None,
blocked_paths: List[str] | None = None,
file_directories: list[str] | None = None,
allowed_paths: list[str] | None = None,
blocked_paths: list[str] | None = None,
_frontend: bool = True,
) -> Tuple[FastAPI, str, str]:
) -> tuple[FastAPI, str, str]:
"""
Launches a simple web server that serves the demo. Can also be used to create a
public link used by anyone to access the demo from their browser by setting share=True.

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,7 @@ of the on-page-load event, which is defined in gr.Blocks().load()."""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Any, Callable
from gradio_client.documentation import document, set_documentation_group
@ -19,7 +19,7 @@ set_documentation_group("events")
def set_cancel_events(
block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
block: Block, event_name: str, cancels: None | dict[str, Any] | list[dict[str, Any]]
):
if cancels:
if not isinstance(cancels, list):
@ -91,8 +91,8 @@ class EventListenerMethod:
def __call__(
self,
fn: Callable | None,
inputs: Component | List[Component] | Set[Component] | None = None,
outputs: Component | List[Component] | None = None,
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Component | list[Component] | None = None,
api_name: str | None = None,
status_tracker: StatusTracker | None = None,
scroll_to_output: bool = False,
@ -102,7 +102,7 @@ class EventListenerMethod:
max_batch_size: int = 4,
preprocess: bool = True,
postprocess: bool = True,
cancels: Dict[str, Any] | List[Dict[str, Any]] | None = None,
cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
every: float | None = None,
_js: str | None = None,
) -> Dependency:
@ -290,7 +290,7 @@ class Selectable(EventListener):
class SelectData(EventData):
def __init__(self, target: Block | None, data: Any):
super().__init__(target, data)
self.index: int | Tuple[int, int] = data["index"]
self.index: int | tuple[int, int] = data["index"]
"""
The index of the selected item. Is a tuple if the component is two dimensional or selection is a range.
"""

View File

@ -15,10 +15,13 @@ class TooManyRequestsError(Exception):
pass
class InvalidApiName(ValueError):
class InvalidApiNameError(ValueError):
pass
InvalidApiName = InvalidApiNameError # backwards compatibility
@document()
class Error(Exception):
"""

View File

@ -6,7 +6,7 @@ from __future__ import annotations
import json
import re
import warnings
from typing import TYPE_CHECKING, Callable, Dict
from typing import TYPE_CHECKING, Callable
import requests
from gradio_client import Client
@ -87,7 +87,7 @@ def load_blocks_from_repo(
src = tokens[0]
name = "/".join(tokens[1:])
factory_methods: Dict[str, Callable] = {
factory_methods: dict[str, Callable] = {
# for each repo type, we have a method that returns the Interface given the model name & optionally an api_key
"huggingface": from_model,
"models": from_model,
@ -393,7 +393,7 @@ def from_model(model_name: str, api_key: str | None, alias: str | None, **kwargs
data.update({"options": {"wait_for_model": True}})
data = json.dumps(data)
response = requests.request("POST", api_url, headers=headers, data=data)
if not (response.status_code == 200):
if response.status_code != 200:
errors_json = response.json()
errors, warns = "", ""
if errors_json.get("error"):
@ -494,7 +494,7 @@ def from_spaces_blocks(space: str, api_key: str | None) -> Blocks:
def from_spaces_interface(
model_name: str,
config: Dict,
config: dict,
alias: str | None,
api_key: str | None,
iframe_url: str,

View File

@ -10,7 +10,7 @@ import warnings
from abc import ABC, abstractmethod
from distutils.version import StrictVersion
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any
import filelock
import huggingface_hub
@ -33,7 +33,7 @@ class FlaggingCallback(ABC):
"""
@abstractmethod
def setup(self, components: List[IOComponent], flagging_dir: str):
def setup(self, components: list[IOComponent], flagging_dir: str):
"""
This method should be overridden and ensure that everything is set up correctly for flag().
This method gets called once at the beginning of the Interface.launch() method.
@ -46,7 +46,7 @@ class FlaggingCallback(ABC):
@abstractmethod
def flag(
self,
flag_data: List[Any],
flag_data: list[Any],
flag_option: str = "",
username: str | None = None,
) -> int:
@ -81,14 +81,14 @@ class SimpleCSVLogger(FlaggingCallback):
def __init__(self):
pass
def setup(self, components: List[IOComponent], flagging_dir: str | Path):
def setup(self, components: list[IOComponent], flagging_dir: str | Path):
self.components = components
self.flagging_dir = flagging_dir
os.makedirs(flagging_dir, exist_ok=True)
def flag(
self,
flag_data: List[Any],
flag_data: list[Any],
flag_option: str = "",
username: str | None = None,
) -> int:
@ -112,7 +112,7 @@ class SimpleCSVLogger(FlaggingCallback):
writer = csv.writer(csvfile)
writer.writerow(utils.sanitize_list_for_csv(csv_data))
with open(log_filepath, "r") as csvfile:
with open(log_filepath) as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@ -136,7 +136,7 @@ class CSVLogger(FlaggingCallback):
def setup(
self,
components: List[IOComponent],
components: list[IOComponent],
flagging_dir: str | Path,
):
self.components = components
@ -145,7 +145,7 @@ class CSVLogger(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_data: list[Any],
flag_option: str = "",
username: str | None = None,
) -> int:
@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback):
writer.writerow(utils.sanitize_list_for_csv(headers))
writer.writerow(utils.sanitize_list_for_csv(csv_data))
with open(log_filepath, "r", encoding="utf-8") as csvfile:
with open(log_filepath, encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@ -235,7 +235,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
self.info_filename = info_filename
self.separate_dirs = separate_dirs
def setup(self, components: List[IOComponent], flagging_dir: str):
def setup(self, components: list[IOComponent], flagging_dir: str):
"""
Params:
flagging_dir (str): local directory where the dataset is cloned,
@ -286,7 +286,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
except huggingface_hub.utils.EntryNotFoundError:
pass
def flag(self, flag_data: List[Any], flag_option: str = "") -> int:
def flag(self, flag_data: list[Any], flag_option: str = "") -> int:
if self.separate_dirs:
# JSONL files to support dataset preview on the Hub
unique_id = str(uuid.uuid4())
@ -312,7 +312,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
data_file: Path,
components_dir: Path,
path_in_repo: str | None,
flag_data: List[Any],
flag_data: list[Any],
flag_option: str = "",
) -> int:
# Deserialize components (write images/audio to files)
@ -368,7 +368,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
return sample_nb
@staticmethod
def _save_as_csv(data_file: Path, headers: List[str], row: List[Any]) -> int:
def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
"""Save data as CSV and return the sample name (row number)."""
is_new = not data_file.exists()
@ -386,7 +386,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
return sum(1 for _ in csv.reader(csvfile)) - 1
@staticmethod
def _save_as_jsonl(data_file: Path, headers: List[str], row: List[Any]) -> str:
def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
"""Save data as JSONL and return the sample name (uuid)."""
Path.mkdir(data_file.parent, parents=True, exist_ok=True)
with open(data_file, "w") as f:
@ -394,15 +394,15 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
return data_file.parent.name
def _deserialize_components(
self, data_dir: Path, flag_data: List[Any], flag_option: str = ""
) -> Tuple[Dict[Any, Any], List[Any]]:
self, data_dir: Path, flag_data: list[Any], flag_option: str = ""
) -> tuple[dict[Any, Any], list[Any]]:
"""Deserialize components and return the corresponding row for the flagged sample.
Images/audio are saved to disk as individual files.
"""
# Components that can have a preview on dataset repos
# NOTE: not at root level to avoid circular imports
FILE_PREVIEW_TYPES = {gr.Audio: "Audio", gr.Image: "Image"}
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
# Generate the row corresponding to the flagged sample
features = {}
@ -418,8 +418,8 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
row.append(Path(deserialized).name)
# If component is eligible for a preview, add the URL of the file
if isinstance(component, tuple(FILE_PREVIEW_TYPES)): # type: ignore
for _component, _type in FILE_PREVIEW_TYPES.items():
if isinstance(component, tuple(file_preview_types)): # type: ignore
for _component, _type in file_preview_types.items():
if isinstance(component, _component):
features[label + " file"] = {"_type": _type}
break

View File

@ -12,7 +12,7 @@ import tempfile
import threading
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Iterable
import matplotlib.pyplot as plt
import numpy as np
@ -36,9 +36,9 @@ set_documentation_group("helpers")
def create_examples(
examples: List[Any] | List[List[Any]] | str,
inputs: IOComponent | List[IOComponent],
outputs: IOComponent | List[IOComponent] | None = None,
examples: list[Any] | list[list[Any]] | str,
inputs: IOComponent | list[IOComponent],
outputs: IOComponent | list[IOComponent] | None = None,
fn: Callable | None = None,
cache_examples: bool = False,
examples_per_page: int = 10,
@ -85,9 +85,9 @@ class Examples:
def __init__(
self,
examples: List[Any] | List[List[Any]] | str,
inputs: IOComponent | List[IOComponent],
outputs: IOComponent | List[IOComponent] | None = None,
examples: list[Any] | list[list[Any]] | str,
inputs: IOComponent | list[IOComponent],
outputs: IOComponent | list[IOComponent] | None = None,
fn: Callable | None = None,
cache_examples: bool = False,
examples_per_page: int = 10,
@ -323,7 +323,7 @@ class Examples:
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)
async def load_from_cache(self, example_id: int) -> List[Any]:
async def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.
Parameters:
example_id: The id of the example to process (zero-indexed).
@ -396,7 +396,7 @@ class Progress(Iterable):
self.track_tqdm = track_tqdm
self._callback = _callback
self._event_id = _event_id
self.iterables: List[TrackedIterable] = []
self.iterables: list[TrackedIterable] = []
def __len__(self):
return self.iterables[-1].length
@ -431,7 +431,7 @@ class Progress(Iterable):
def __call__(
self,
progress: float | Tuple[int, int | None] | None,
progress: float | tuple[int, int | None] | None,
desc: str | None = None,
total: int | None = None,
unit: str = "steps",
@ -541,7 +541,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm):
if self._progress is not None:
self._progress.event_id = event_id
self._progress.tqdm(iterable, desc, _tqdm=self)
kwargs["file"] = open(os.devnull, "w")
kwargs["file"] = open(os.devnull, "w") # noqa: SIM115
self.__init__orig__(iterable, desc, *args, **kwargs)
def iter_tqdm(self):
@ -595,7 +595,7 @@ def create_tracker(root_blocks, event_id, fn, track_tqdm):
def special_args(
fn: Callable,
inputs: List[Any] | None = None,
inputs: list[Any] | None = None,
request: routes.Request | None = None,
event_data: EventData | None = None,
):
@ -632,9 +632,10 @@ def special_args(
event_data_index = i
if inputs is not None and event_data is not None:
inputs.insert(i, param.annotation(event_data.target, event_data._data))
elif param.default is not param.empty:
if inputs is not None and len(inputs) <= i:
inputs.insert(i, param.default)
elif (
param.default is not param.empty and inputs is not None and len(inputs) <= i
):
inputs.insert(i, param.default)
if inputs is not None:
while len(inputs) < len(positional_args):
i = len(inputs)
@ -696,12 +697,12 @@ def skip() -> dict:
@document()
def make_waveform(
audio: str | Tuple[int, np.ndarray],
audio: str | tuple[int, np.ndarray],
*,
bg_color: str = "#f3f4f6",
bg_image: str | None = None,
fg_alpha: float = 0.75,
bars_color: str | Tuple[str, str] = ("#fbbf24", "#ea580c"),
bars_color: str | tuple[str, str] = ("#fbbf24", "#ea580c"),
bar_count: int = 50,
bar_width: float = 0.6,
):
@ -728,13 +729,13 @@ def make_waveform(
duration = round(len(audio[1]) / audio[0], 4)
# Helper methods to create waveform
def hex_to_RGB(hex_str):
def hex_to_rgb(hex_str):
return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
def get_color_gradient(c1, c2, n):
assert n > 1
c1_rgb = np.array(hex_to_RGB(c1)) / 255
c2_rgb = np.array(hex_to_RGB(c2)) / 255
c1_rgb = np.array(hex_to_rgb(c1)) / 255
c2_rgb = np.array(hex_to_rgb(c2)) / 255
mix_pcts = [x / (n - 1) for x in range(n)]
rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
return [
@ -770,7 +771,7 @@ def make_waveform(
plt.axis("off")
plt.margins(x=0)
tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"}
savefig_kwargs: dict[str, Any] = {"bbox_inches": "tight"}
if bg_image is not None:
savefig_kwargs["transparent"] = True
else:

View File

@ -8,7 +8,7 @@ automatically added to a registry, which allows them to be easily referenced in
from __future__ import annotations
import warnings
from typing import Any, List, Optional, Tuple
from typing import Any, Optional
from gradio import components
@ -132,8 +132,8 @@ class CheckboxGroup(components.CheckboxGroup):
def __init__(
self,
choices: List[str],
default: List[str] | None = None,
choices: list[str],
default: list[str] | None = None,
type: str = "value",
label: Optional[str] = None,
optional: bool = False,
@ -168,7 +168,7 @@ class Radio(components.Radio):
def __init__(
self,
choices: List[str],
choices: list[str],
type: str = "value",
default: Optional[str] = None,
label: Optional[str] = None,
@ -202,7 +202,7 @@ class Dropdown(components.Dropdown):
def __init__(
self,
choices: List[str],
choices: list[str],
type: str = "value",
default: Optional[str] = None,
label: Optional[str] = None,
@ -236,7 +236,7 @@ class Image(components.Image):
def __init__(
self,
shape: Tuple[int, int] = None,
shape: tuple[int, int] = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
@ -366,12 +366,12 @@ class Dataframe(components.Dataframe):
def __init__(
self,
headers: Optional[List[str]] = None,
headers: Optional[list[str]] = None,
row_count: int = 3,
col_count: Optional[int] = 3,
datatype: str | List[str] = "str",
col_width: int | List[int] = None,
default: Optional[List[List[Any]]] = None,
datatype: str | list[str] = "str",
col_width: int | list[int] = None,
default: Optional[list[list[Any]]] = None,
type: str = "pandas",
label: Optional[str] = None,
optional: bool = False,
@ -413,7 +413,7 @@ class Timeseries(components.Timeseries):
def __init__(
self,
x: Optional[str] = None,
y: str | List[str] = None,
y: str | list[str] = None,
label: Optional[str] = None,
optional: bool = False,
):

View File

@ -8,10 +8,9 @@ from __future__ import annotations
import inspect
import json
import os
import re
import warnings
import weakref
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from typing import TYPE_CHECKING, Any, Callable
from gradio_client.documentation import document, set_documentation_group
@ -63,7 +62,7 @@ class Interface(Blocks):
instances: weakref.WeakSet = weakref.WeakSet()
@classmethod
def get_instances(cls) -> List[Interface]:
def get_instances(cls) -> list[Interface]:
"""
:return: list of all current instances.
"""
@ -119,9 +118,9 @@ class Interface(Blocks):
def __init__(
self,
fn: Callable,
inputs: str | IOComponent | List[str | IOComponent] | None,
outputs: str | IOComponent | List[str | IOComponent] | None,
examples: List[Any] | List[List[Any]] | str | None = None,
inputs: str | IOComponent | list[str | IOComponent] | None,
outputs: str | IOComponent | list[str | IOComponent] | None,
examples: list[Any] | list[list[Any]] | str | None = None,
cache_examples: bool | None = None,
examples_per_page: int = 10,
live: bool = False,
@ -134,7 +133,7 @@ class Interface(Blocks):
theme: Theme | str | None = None,
css: str | None = None,
allow_flagging: str | None = None,
flagging_options: List[str] | List[Tuple[str, str]] | None = None,
flagging_options: list[str] | list[tuple[str, str]] | None = None,
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
analytics_enabled: bool | None = None,
@ -287,17 +286,11 @@ class Interface(Blocks):
self.live = live
self.title = title
CLEANER = re.compile("<.*?>")
def clean_html(raw_html):
cleantext = re.sub(CLEANER, "", raw_html)
return cleantext
md = utils.get_markdown_parser()
simple_description = None
simple_description: str | None = None
if description is not None:
description = md.render(description)
simple_description = clean_html(description)
simple_description = utils.remove_html_tags(description)
self.simple_description = simple_description
self.description = description
if article is not None:
@ -466,19 +459,19 @@ class Interface(Blocks):
if self.description:
Markdown(self.description)
def render_flag_btns(self) -> List[Button]:
def render_flag_btns(self) -> list[Button]:
return [Button(label) for label, _ in self.flagging_options]
def render_input_column(
self,
) -> Tuple[
) -> tuple[
Button | None,
Button | None,
Button | None,
List[Button] | None,
list[Button] | None,
Column,
Column | None,
List[Interpretation] | None,
list[Interpretation] | None,
]:
submit_btn, clear_btn, stop_btn, flag_btns = None, None, None, None
interpret_component_column, interpretation_set = None, None
@ -531,7 +524,7 @@ class Interface(Blocks):
def render_output_column(
self,
submit_btn_in: Button | None,
) -> Tuple[Button | None, Button | None, Button | None, List | None, Button | None]:
) -> tuple[Button | None, Button | None, Button | None, list | None, Button | None]:
submit_btn = submit_btn_in
interpretation_btn, clear_btn, flag_btns, stop_btn = None, None, None, None
@ -699,7 +692,7 @@ class Interface(Blocks):
def attach_interpretation_events(
self,
interpretation_btn: Button | None,
interpretation_set: List[Interpretation] | None,
interpretation_set: list[Interpretation] | None,
input_component_column: Column | None,
interpret_component_column: Column | None,
):
@ -711,53 +704,58 @@ class Interface(Blocks):
preprocess=False,
)
def attach_flagging_events(self, flag_btns: List[Button] | None, clear_btn: Button):
if flag_btns:
if self.interface_type in [
def attach_flagging_events(self, flag_btns: list[Button] | None, clear_btn: Button):
if not (
flag_btns
and self.interface_type
in (
InterfaceTypes.STANDARD,
InterfaceTypes.OUTPUT_ONLY,
InterfaceTypes.UNIFIED,
]:
if self.allow_flagging == "auto":
flag_method = FlagMethod(
self.flagging_callback, "", "", visual_feedback=False
)
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
flag_method,
inputs=self.input_components,
outputs=None,
preprocess=False,
queue=False,
)
return
)
):
return
if self.interface_type == InterfaceTypes.UNIFIED:
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components
if self.allow_flagging == "auto":
flag_method = FlagMethod(
self.flagging_callback, "", "", visual_feedback=False
)
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
flag_method,
inputs=self.input_components,
outputs=None,
preprocess=False,
queue=False,
)
return
for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
assert isinstance(value, str)
flag_method = FlagMethod(self.flagging_callback, label, value)
flag_btn.click(
lambda: Button.update(value="Saving...", interactive=False),
None,
flag_btn,
queue=False,
)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=flag_btn,
preprocess=False,
queue=False,
)
clear_btn.click(
flag_method.reset,
None,
flag_btn,
queue=False,
)
if self.interface_type == InterfaceTypes.UNIFIED:
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components
for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
assert isinstance(value, str)
flag_method = FlagMethod(self.flagging_callback, label, value)
flag_btn.click(
lambda: Button.update(value="Saving...", interactive=False),
None,
flag_btn,
queue=False,
)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=flag_btn,
preprocess=False,
queue=False,
)
clear_btn.click(
flag_method.reset,
None,
flag_btn,
queue=False,
)
def render_examples(self):
if self.examples:
@ -798,7 +796,7 @@ class Interface(Blocks):
Column.update(visible=True),
]
async def interpret(self, raw_input: List[Any]) -> List[Any]:
async def interpret(self, raw_input: list[Any]) -> list[Any]:
return [
{"original": raw_value, "interpretation": interpretation}
for interpretation, raw_value in zip(
@ -823,8 +821,8 @@ class TabbedInterface(Blocks):
def __init__(
self,
interface_list: List[Interface],
tab_names: List[str] | None = None,
interface_list: list[Interface],
tab_names: list[str] | None = None,
title: str | None = None,
theme: Theme | None = None,
analytics_enabled: bool | None = None,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import copy
import math
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any
import numpy as np
from gradio_client import utils as client_utils
@ -28,8 +28,8 @@ class Interpretable(ABC): # noqa: B024
pass
def get_interpretation_scores(
self, x: Any, neighbors: List[Any] | None, scores: List[float], **kwargs
) -> List:
self, x: Any, neighbors: list[Any] | None, scores: list[float], **kwargs
) -> list:
"""
Arrange the output values from the neighbors into interpretation scores for the interface to render.
Parameters:
@ -44,7 +44,7 @@ class Interpretable(ABC): # noqa: B024
class TokenInterpretable(Interpretable, ABC):
@abstractmethod
def tokenize(self, x: Any) -> Tuple[List, List, None]:
def tokenize(self, x: Any) -> tuple[list, list, None]:
"""
Interprets an input data point x by splitting it into a list of tokens (e.g
a string into words or an image into super-pixels).
@ -52,13 +52,13 @@ class TokenInterpretable(Interpretable, ABC):
return [], [], None
@abstractmethod
def get_masked_inputs(self, tokens: List, binary_mask_matrix: List[List]) -> List:
def get_masked_inputs(self, tokens: list, binary_mask_matrix: list[list]) -> list:
return []
class NeighborInterpretable(Interpretable, ABC):
@abstractmethod
def get_interpretation_neighbors(self, x: Any) -> Tuple[List, Dict]:
def get_interpretation_neighbors(self, x: Any) -> tuple[list, dict]:
"""
Generates values similar to input to be used to interpret the significance of the input in the final output.
Parameters:
@ -70,7 +70,7 @@ class NeighborInterpretable(Interpretable, ABC):
return [], {}
async def run_interpret(interface: Interface, raw_input: List):
async def run_interpret(interface: Interface, raw_input: list):
"""
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
interpretation for a certain set of UI component types, as well as the custom interpretation case.
@ -265,12 +265,12 @@ def diff(original: Any, perturbed: Any) -> int | float:
try: # try computing numerical difference
score = float(original) - float(perturbed)
except ValueError: # otherwise, look at strict difference in label
score = int(not (original == perturbed))
score = int(original != perturbed)
return score
def quantify_difference_in_label(
interface: Interface, original_output: List, perturbed_output: List
interface: Interface, original_output: list, perturbed_output: list
) -> int | float:
output_component = interface.output_components[0]
post_original_output = output_component.postprocess(original_output[0])
@ -300,7 +300,7 @@ def quantify_difference_in_label(
def get_regression_or_classification_value(
interface: Interface, original_output: List, perturbed_output: List
interface: Interface, original_output: list, perturbed_output: list
) -> int | float:
"""Used to combine regression/classification for Shap interpretation method."""
output_component = interface.output_components[0]

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import warnings
from typing import Type
from gradio_client.documentation import document, set_documentation_group
@ -216,7 +215,7 @@ class Tab(BlockContext, Selectable):
**super(BlockContext, self).get_config(),
}
def get_expected_parent(self) -> Type[Tabs]:
def get_expected_parent(self) -> type[Tabs]:
return Tabs
def get_block_name(self):

View File

@ -9,7 +9,7 @@ import socket
import threading
import time
import warnings
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import requests
import uvicorn
@ -89,7 +89,7 @@ def start_server(
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
) -> Tuple[str, int, str, App, Server]:
) -> tuple[str, int, str, App, Server]:
"""Launches a local server running the provided Interface
Parameters:
blocks: The Blocks object to run on the server

View File

@ -8,7 +8,7 @@ automatically added to a registry, which allows them to be easily referenced in
from __future__ import annotations
import warnings
from typing import Dict, List, Optional
from typing import Optional
from gradio import components
@ -109,7 +109,7 @@ class Dataframe(components.Dataframe):
def __init__(
self,
headers: Optional[List[str]] = None,
headers: Optional[list[str]] = None,
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
overflow_row_behaviour: str = "paginate",
@ -145,7 +145,7 @@ class Timeseries(components.Timeseries):
"""
def __init__(
self, x: str = None, y: str | List[str] = None, label: Optional[str] = None
self, x: str = None, y: str | list[str] = None, label: Optional[str] = None
):
"""
Parameters:
@ -227,7 +227,7 @@ class HighlightedText(components.HighlightedText):
def __init__(
self,
color_map: Dict[str, str] = None,
color_map: dict[str, str] = None,
label: Optional[str] = None,
show_legend: bool = False,
):
@ -281,7 +281,7 @@ class Carousel(components.Carousel):
def __init__(
self,
components: components.Component | List[components.Component],
components: components.Component | list[components.Component],
label: Optional[str] = None,
):
"""

View File

@ -3,7 +3,7 @@ please use the `gr.Interface.from_pipeline()` function."""
from __future__ import annotations
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from gradio import components
@ -11,7 +11,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from transformers import pipelines
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> Dict:
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
"""
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface

View File

@ -8,7 +8,6 @@ import tempfile
import warnings
from io import BytesIO
from pathlib import Path
from typing import Dict
import numpy as np
from ffmpy import FFmpeg, FFprobe, FFRuntimeError
@ -25,7 +24,7 @@ with warnings.catch_warnings():
#########################
def to_binary(x: str | Dict) -> bytes:
def to_binary(x: str | dict) -> bytes:
"""Converts a base64 string or dictionary to a binary string that can be sent in a POST."""
if isinstance(x, dict):
if x.get("data"):
@ -362,10 +361,7 @@ def _convert(image, dtype, force_copy=False, uniform=False):
image = np.asarray(image)
dtypeobj_in = image.dtype
if dtype is np.floating:
dtypeobj_out = np.dtype("float64")
else:
dtypeobj_out = np.dtype(dtype)
dtypeobj_out = np.dtype("float64") if dtype is np.floating else np.dtype(dtype)
dtype_in = dtypeobj_in.type
dtype_out = dtypeobj_out.type
kind_in = dtypeobj_in.kind

View File

@ -6,7 +6,7 @@ import sys
import time
from asyncio import TimeoutError as AsyncTimeOutError
from collections import deque
from typing import Any, Deque, Dict, List, Tuple
from typing import Any, Deque
import fastapi
import httpx
@ -44,14 +44,14 @@ class Queue:
concurrency_count: int,
update_intervals: float,
max_size: int | None,
blocks_dependencies: List,
blocks_dependencies: list,
):
self.event_queue: Deque[Event] = deque()
self.events_pending_reconnection = []
self.stopped = False
self.max_thread_count = concurrency_count
self.update_intervals = update_intervals
self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
self.active_jobs: list[None | list[Event]] = [None] * concurrency_count
self.delete_lock = asyncio.Lock()
self.server_path = None
self.duration_history_total = 0
@ -96,7 +96,7 @@ class Queue:
count += 1
return count
def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
def get_events_in_batch(self) -> tuple[list[Event] | None, bool]:
if not (self.event_queue):
return None, False
@ -158,7 +158,7 @@ class Queue:
def set_progress(
self,
event_id: str,
iterables: List[TrackedIterable] | None,
iterables: list[TrackedIterable] | None,
):
if iterables is None:
return
@ -167,7 +167,7 @@ class Queue:
continue
for evt in job:
if evt._id == event_id:
progress_data: List[ProgressUnit] = []
progress_data: list[ProgressUnit] = []
for iterable in iterables:
progress_unit = ProgressUnit(
index=iterable.index,
@ -303,7 +303,7 @@ class Queue:
queue_eta=self.queue_duration,
)
def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
def get_request_params(self, websocket: fastapi.WebSocket) -> dict[str, Any]:
return {
"url": str(websocket.url),
"headers": dict(websocket.headers),
@ -312,7 +312,7 @@ class Queue:
"client": {"host": websocket.client.host, "port": websocket.client.port}, # type: ignore
}
async def call_prediction(self, events: List[Event], batch: bool):
async def call_prediction(self, events: list[Event], batch: bool):
data = events[0].data
assert data is not None, "No event data"
token = events[0].token
@ -340,8 +340,8 @@ class Queue:
)
return response
async def process_events(self, events: List[Event], batch: bool) -> None:
awake_events: List[Event] = []
async def process_events(self, events: list[Event], batch: bool) -> None:
awake_events: list[Event] = []
try:
for event in events:
client_awake = await self.gather_event_data(event)
@ -438,7 +438,7 @@ class Queue:
# to start "from scratch"
await self.reset_iterators(event.session_hash, event.fn_index)
async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool:
async def send_message(self, event, data: dict, timeout: float | int = 1) -> bool:
try:
await asyncio.wait_for(
event.websocket.send_json(data=data), timeout=timeout
@ -448,7 +448,7 @@ class Queue:
await self.clean_event(event)
return False
async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]:
async def get_message(self, event, timeout=5) -> tuple[PredictBody | None, bool]:
try:
data = await asyncio.wait_for(
event.websocket.receive_json(), timeout=timeout

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import os
import re
import stat
from typing import Dict, NamedTuple
from typing import NamedTuple
from urllib.parse import quote
import aiofiles
@ -36,7 +36,7 @@ class OpenRange(NamedTuple):
def clamp(self, start: int, end: int) -> ClosedRange:
begin = max(self.start, start)
end = min((x for x in (self.end, end) if x))
end = min(x for x in (self.end, end) if x)
begin = min(begin, end)
end = max(begin, end)
@ -51,7 +51,7 @@ class RangedFileResponse(Response):
self,
path: str | os.PathLike,
range: OpenRange,
headers: Dict[str, str] | None = None,
headers: dict[str, str] | None = None,
media_type: str | None = None,
filename: str | None = None,
stat_result: os.stat_result | None = None,

View File

@ -18,10 +18,7 @@ def run_in_reload_mode():
args = sys.argv[1:]
if len(args) == 0:
raise ValueError("No file specified.")
if len(args) == 1:
demo_name = "demo"
else:
demo_name = args[1]
demo_name = "demo" if len(args) == 1 else args[1]
original_path = args[0]
abs_original_path = utils.abspath(original_path)

View File

@ -309,10 +309,8 @@ class App(FastAPI):
)
abs_path = utils.abspath(path_or_url)
in_blocklist = any(
(
utils.is_in_or_equal(abs_path, blocked_path)
for blocked_path in blocks.blocked_paths
)
utils.is_in_or_equal(abs_path, blocked_path)
for blocked_path in blocks.blocked_paths
)
if in_blocklist:
raise HTTPException(403, f"File not allowed: {path_or_url}.")
@ -320,10 +318,8 @@ class App(FastAPI):
in_app_dir = utils.abspath(app.cwd) in abs_path.parents
created_by_app = str(abs_path) in set().union(*blocks.temp_file_sets)
in_file_dir = any(
(
utils.is_in_or_equal(abs_path, allowed_path)
for allowed_path in blocks.allowed_paths
)
utils.is_in_or_equal(abs_path, allowed_path)
for allowed_path in blocks.allowed_paths
)
was_uploaded = utils.abspath(app.uploaded_file_dir) in abs_path.parents
@ -464,14 +460,15 @@ class App(FastAPI):
)
else:
fn_index_inferred = body.fn_index
if not app.get_blocks().api_open and app.get_blocks().queue_enabled_for_fn(
fn_index_inferred
if (
not app.get_blocks().api_open
and app.get_blocks().queue_enabled_for_fn(fn_index_inferred)
and f"Bearer {app.queue_token}" != request.headers.get("Authorization")
):
if f"Bearer {app.queue_token}" != request.headers.get("Authorization"):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authorized to skip the queue",
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authorized to skip the queue",
)
# If this fn_index cancels jobs, then the only input we need is the
# current session hash

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import typing
from typing import Any, Callable, Tuple
from typing import Any, Callable
import numpy as np
from PIL.Image import Image
@ -55,7 +54,7 @@ class Webcam(components.Image):
self,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] | None = None,
shape: tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "webcam",
@ -102,7 +101,7 @@ class Sketchpad(components.Image):
self,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = (28, 28),
shape: tuple[int, int] = (28, 28),
image_mode: str = "L",
invert_colors: bool = True,
source: str = "canvas",
@ -149,7 +148,7 @@ class Paint(components.Image):
self,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] | None = None,
shape: tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "canvas",
@ -196,7 +195,7 @@ class ImageMask(components.Image):
self,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] | None = None,
shape: tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
@ -243,7 +242,7 @@ class ImagePaint(components.Image):
self,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] | None = None,
shape: tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
@ -290,7 +289,7 @@ class Pil(components.Image):
self,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] | None = None,
shape: tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
@ -372,7 +371,7 @@ class Microphone(components.Audio):
def __init__(
self,
value: str | Tuple[int, np.ndarray] | Callable | None = None,
value: str | tuple[int, np.ndarray] | Callable | None = None,
*,
source: str = "microphone",
type: str = "numpy",
@ -407,7 +406,7 @@ class Files(components.File):
def __init__(
self,
value: str | typing.List[str] | Callable | None = None,
value: str | list[str] | Callable | None = None,
*,
file_count: str = "multiple",
type: str = "file",
@ -440,12 +439,12 @@ class Numpy(components.Dataframe):
def __init__(
self,
value: typing.List[typing.List[Any]] | Callable | None = None,
value: list[list[Any]] | Callable | None = None,
*,
headers: typing.List[str] | None = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: int | Tuple[int, str] | None = None,
datatype: str | typing.List[str] = "str",
headers: list[str] | None = None,
row_count: int | tuple[int, str] = (1, "dynamic"),
col_count: int | tuple[int, str] | None = None,
datatype: str | list[str] = "str",
type: str = "numpy",
max_rows: int | None = 20,
max_cols: int | None = None,
@ -487,12 +486,12 @@ class Matrix(components.Dataframe):
def __init__(
self,
value: typing.List[typing.List[Any]] | Callable | None = None,
value: list[list[Any]] | Callable | None = None,
*,
headers: typing.List[str] | None = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: int | Tuple[int, str] | None = None,
datatype: str | typing.List[str] = "str",
headers: list[str] | None = None,
row_count: int | tuple[int, str] = (1, "dynamic"),
col_count: int | tuple[int, str] | None = None,
datatype: str | list[str] = "str",
type: str = "array",
max_rows: int | None = 20,
max_cols: int | None = None,
@ -534,12 +533,12 @@ class List(components.Dataframe):
def __init__(
self,
value: typing.List[typing.List[Any]] | Callable | None = None,
value: list[list[Any]] | Callable | None = None,
*,
headers: typing.List[str] | None = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: int | Tuple[int, str] = 1,
datatype: str | typing.List[str] = "str",
headers: list[str] | None = None,
row_count: int | tuple[int, str] = (1, "dynamic"),
col_count: int | tuple[int, str] = 1,
datatype: str | list[str] = "str",
type: str = "array",
max_rows: int | None = 20,
max_cols: int | None = None,

View File

@ -5,7 +5,7 @@ import re
import tempfile
import textwrap
from pathlib import Path
from typing import Dict, Iterable
from typing import Iterable
import huggingface_hub
import requests
@ -108,17 +108,17 @@ class ThemeClass:
return schema
@classmethod
def load(cls, path: str) -> "ThemeClass":
def load(cls, path: str) -> ThemeClass:
"""Load a theme from a json file.
Parameters:
path: The filepath to read.
"""
theme = json.load(open(path), object_hook=fonts.as_font)
return cls.from_dict(theme)
with open(path) as fp:
return cls.from_dict(json.load(fp, object_hook=fonts.as_font))
@classmethod
def from_dict(cls, theme: Dict[str, Dict[str, str]]) -> "ThemeClass":
def from_dict(cls, theme: dict[str, dict[str, str]]) -> ThemeClass:
"""Create a theme instance from a dictionary representation.
Parameters:
@ -142,8 +142,7 @@ class ThemeClass:
Parameters:
filename: The path to write the theme too
"""
as_dict = self.to_dict()
json.dump(as_dict, open(Path(filename), "w"), cls=fonts.FontEncoder)
Path(filename).write_text(json.dumps(self.to_dict(), cls=fonts.FontEncoder))
@classmethod
def from_hub(cls, repo_name: str, hf_token: str | None = None):
@ -248,10 +247,7 @@ class ThemeClass:
# If no version, set the version to next patch release
if not version:
if space_exists:
version = self._get_next_version(space_info)
else:
version = "0.0.1"
version = self._get_next_version(space_info) if space_exists else "0.0.1"
else:
_ = semver.Version(version)
@ -279,7 +275,7 @@ class ThemeClass:
)
readme_file.write(textwrap.dedent(readme_content))
with tempfile.NamedTemporaryFile(mode="w", delete=False) as app_file:
contents = open(str(Path(__file__).parent / "app.py")).read()
contents = (Path(__file__).parent / "app.py").read_text()
contents = re.sub(
r"theme=gr.themes.Default\(\)",
f"theme='{space_id}'",

View File

@ -76,7 +76,11 @@ css = """
}
"""
with gr.Blocks(theme=gr.themes.Base(), css=css, title="Gradio Theme Builder") as demo:
with gr.Blocks( # noqa: SIM117
theme=gr.themes.Base(),
css=css,
title="Gradio Theme Builder",
) as demo:
with gr.Row():
with gr.Column(scale=1, elem_id="controls", min_width=400):
with gr.Row():

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List
import huggingface_hub
import semantic_version
@ -17,7 +16,7 @@ class ThemeAsset:
self.version = semver.Version(self.filename.split("@")[1].replace(".json", ""))
def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> List[ThemeAsset]:
def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> list[ThemeAsset]:
if "gradio-theme" not in getattr(space_info, "tags", []):
raise ValueError(f"{space_info.id} is not a valid gradio-theme space!")
@ -29,7 +28,7 @@ def get_theme_assets(space_info: huggingface_hub.hf_api.SpaceInfo) -> List[Theme
def get_matching_version(
assets: List[ThemeAsset], expression: str | None
assets: list[ThemeAsset], expression: str | None
) -> ThemeAsset | None:
expression = expression or "*"

View File

@ -27,11 +27,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Tuple,
Type,
TypeVar,
Union,
)
@ -104,10 +100,10 @@ def get_local_ip_address() -> str:
return ip_address
def initiated_analytics(data: Dict[str, Any]) -> None:
def initiated_analytics(data: dict[str, Any]) -> None:
data.update({"ip_address": get_local_ip_address()})
def initiated_analytics_thread(data: Dict[str, Any]) -> None:
def initiated_analytics_thread(data: dict[str, Any]) -> None:
try:
requests.post(
f"{analytics_url}gradio-initiated-analytics/", data=data, timeout=5
@ -118,10 +114,10 @@ def initiated_analytics(data: Dict[str, Any]) -> None:
threading.Thread(target=initiated_analytics_thread, args=(data,)).start()
def launch_analytics(data: Dict[str, Any]) -> None:
def launch_analytics(data: dict[str, Any]) -> None:
data.update({"ip_address": get_local_ip_address()})
def launch_analytics_thread(data: Dict[str, Any]) -> None:
def launch_analytics_thread(data: dict[str, Any]) -> None:
try:
requests.post(
f"{analytics_url}gradio-launched-analytics/", data=data, timeout=5
@ -132,7 +128,7 @@ def launch_analytics(data: Dict[str, Any]) -> None:
threading.Thread(target=launch_analytics_thread, args=(data,)).start()
def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None:
def launched_telemetry(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
blocks_telemetry, inputs_telemetry, outputs_telemetry, targets_telemetry = (
[],
[],
@ -180,7 +176,7 @@ def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None:
data.update(additional_data)
data.update({"ip_address": get_local_ip_address()})
def launched_telemtry_thread(data: Dict[str, Any]) -> None:
def launched_telemtry_thread(data: dict[str, Any]) -> None:
try:
requests.post(
f"{analytics_url}gradio-launched-telemetry/", data=data, timeout=5
@ -191,10 +187,10 @@ def launched_telemetry(blocks: gradio.Blocks, data: Dict[str, Any]) -> None:
threading.Thread(target=launched_telemtry_thread, args=(data,)).start()
def integration_analytics(data: Dict[str, Any]) -> None:
def integration_analytics(data: dict[str, Any]) -> None:
data.update({"ip_address": get_local_ip_address()})
def integration_analytics_thread(data: Dict[str, Any]) -> None:
def integration_analytics_thread(data: dict[str, Any]) -> None:
try:
requests.post(
f"{analytics_url}gradio-integration-analytics/", data=data, timeout=5
@ -213,7 +209,7 @@ def error_analytics(message: str) -> None:
"""
data = {"ip_address": get_local_ip_address(), "error": message}
def error_analytics_thread(data: Dict[str, Any]) -> None:
def error_analytics_thread(data: dict[str, Any]) -> None:
try:
requests.post(
f"{analytics_url}gradio-error-analytics/", data=data, timeout=5
@ -320,7 +316,7 @@ def launch_counter() -> None:
pass
def get_default_args(func: Callable) -> List[Any]:
def get_default_args(func: Callable) -> list[Any]:
signature = inspect.signature(func)
return [
v.default if v.default is not inspect.Parameter.empty else None
@ -329,7 +325,7 @@ def get_default_args(func: Callable) -> List[Any]:
def assert_configs_are_equivalent_besides_ids(
config1: Dict, config2: Dict, root_keys: Tuple = ("mode",)
config1: dict, config2: dict, root_keys: tuple = ("mode",)
):
"""Allows you to test if two different Blocks configs produce the same demo.
@ -382,7 +378,7 @@ def assert_configs_are_equivalent_besides_ids(
return True
def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]):
def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):
if len(ner_groups) == 0:
return [(input_string, None)]
@ -400,7 +396,7 @@ def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]):
return output
def delete_none(_dict: Dict, skip_value: bool = False) -> Dict:
def delete_none(_dict: dict, skip_value: bool = False) -> dict:
"""
Delete keys whose values are None from a dictionary
"""
@ -412,14 +408,14 @@ def delete_none(_dict: Dict, skip_value: bool = False) -> Dict:
return _dict
def resolve_singleton(_list: List[Any] | Any) -> Any:
def resolve_singleton(_list: list[Any] | Any) -> Any:
if len(_list) == 1:
return _list[0]
else:
return _list
def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockContext]:
def component_or_layout_class(cls_name: str) -> type[Component] | type[BlockContext]:
"""
Returns the component, template, or layout class with the given class name, or
raises a ValueError if not found.
@ -536,9 +532,9 @@ class AsyncRequest:
method: Method,
url: str,
*,
validation_model: Type[BaseModel] | None = None,
validation_model: type[BaseModel] | None = None,
validation_function: Union[Callable, None] = None,
exception_type: Type[Exception] = Exception,
exception_type: type[Exception] = Exception,
raise_for_status: bool = False,
client: httpx.AsyncClient | None = None,
**kwargs,
@ -565,7 +561,7 @@ class AsyncRequest:
self._request = self._create_request(method, url, **kwargs)
self.client_ = client or self.client
def __await__(self) -> Generator[None, Any, "AsyncRequest"]:
def __await__(self) -> Generator[None, Any, AsyncRequest]:
"""
Wrap Request's __await__ magic function to create request calls which are executed in one line.
"""
@ -740,7 +736,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number:
return value
def sanitize_list_for_csv(values: List[Any]) -> List[Any]:
def sanitize_list_for_csv(values: list[Any]) -> list[Any]:
"""
Sanitizes a list of values (or a list of list of values) that is being written to a
CSV file to prevent CSV injection attacks.
@ -756,7 +752,7 @@ def sanitize_list_for_csv(values: List[Any]) -> List[Any]:
return sanitized_values
def append_unique_suffix(name: str, list_of_names: List[str]):
def append_unique_suffix(name: str, list_of_names: list[str]):
"""Appends a numerical suffix to `name` so that it does not appear in `list_of_names`."""
set_of_names: set[str] = set(list_of_names) # for O(1) lookup
if name not in set_of_names:
@ -815,8 +811,8 @@ def set_task_name(task, session_hash: str, fn_index: int, batch: bool):
def get_cancel_function(
dependencies: List[Dict[str, Any]]
) -> Tuple[Callable, List[int]]:
dependencies: list[dict[str, Any]]
) -> tuple[Callable, list[int]]:
fn_to_comp = {}
for dep in dependencies:
if Context.root_block:
@ -858,7 +854,7 @@ def is_special_typed_parameter(name, parameter_types):
return is_request or is_event_data
def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool):
def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool):
"""
Checks if the input component set matches the function
Returns: None if valid, a string error message if mismatch
@ -878,9 +874,8 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
max_args += 1
elif param.kind == param.VAR_POSITIONAL:
max_args = infinity
elif param.kind == param.KEYWORD_ONLY:
if not has_default:
return f"Keyword-only args must have default values for function {fn}"
elif param.kind == param.KEYWORD_ONLY and not has_default:
return f"Keyword-only args must have default values for function {fn}"
arg_count = 1 if inputs_as_dict else len(inputs)
if min_args == max_args and max_args != arg_count:
warnings.warn(
@ -918,15 +913,15 @@ def tex2svg(formula, *args):
with MatplotlibBackendMananger():
import matplotlib.pyplot as plt
FONTSIZE = 20
DPI = 300
fontsize = 20
dpi = 300
plt.rc("mathtext", fontset="cm")
fig = plt.figure(figsize=(0.01, 0.01))
fig.text(0, 0, rf"${formula}$", fontsize=FONTSIZE)
fig.text(0, 0, rf"${formula}$", fontsize=fontsize)
output = BytesIO()
fig.savefig(
output,
dpi=DPI,
dpi=dpi,
transparent=True,
format="svg",
bbox_inches="tight",
@ -942,7 +937,7 @@ def tex2svg(formula, *args):
height_match = re.search(r'height="([\d.]+)pt"', svg_code)
if height_match:
height = float(height_match.group(1))
new_height = height / FONTSIZE # conversion from pt to em
new_height = height / fontsize # conversion from pt to em
svg_code = re.sub(
r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code
)
@ -1016,7 +1011,7 @@ def get_serializer_name(block: Block) -> str | None:
def highlight_code(code, name, attrs):
try:
lexer = get_lexer_by_name(name)
except:
except Exception:
lexer = get_lexer_by_name("text")
formatter = HtmlFormatter()
@ -1047,3 +1042,10 @@ def get_markdown_parser() -> MarkdownIt:
md.add_render_rule("link_open", render_blank_link)
return md
HTML_TAG_RE = re.compile("<.*?>")
def remove_html_tags(raw_html: str | None) -> str:
return re.sub(HTML_TAG_RE, "", raw_html or "")

View File

@ -67,10 +67,9 @@ extend-select = [
"B",
"C",
"I",
# Formatting-related UP rules
"UP030",
"UP031",
"UP032",
"N",
"SIM",
"UP",
]
ignore = [
"C901", # function is too complex (TODO: un-ignore this)
@ -79,6 +78,9 @@ ignore = [
"B017", # pytest.raises considered evil
"B028", # explicit stacklevel for warnings
"E501", # from scripts/lint_backend.sh
"SIM105", # contextlib.suppress (has a performance cost)
"SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc)
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
]
[tool.ruff.per-file-ignores]
@ -91,3 +93,6 @@ ignore = [
"gradio/__init__.py" = [
"F401", # "Imported but unused" (TODO: it would be better to be explicit and use __all__)
]
"gradio/routes.py" = [
"UP006", # Pydantic on Python 3.7 requires old-style type annotations (TODO: drop when Python 3.7 is dropped)
]

View File

@ -197,7 +197,7 @@ respx==0.19.2
# via -r requirements.in
rfc3986[idna2008]==1.5.0
# via httpx
ruff==0.0.263
ruff==0.0.264
# via -r requirements.in
s3transfer==0.6.0
# via boto3

View File

@ -185,7 +185,7 @@ requests==2.28.1
# transformers
respx==0.19.2
# via -r requirements.in
ruff==0.0.263
ruff==0.0.264
# via -r requirements.in
rfc3986[idna2008]==1.5.0
# via httpx

View File

@ -279,8 +279,7 @@ class TestBlocksMethods:
return 42
def generator_function():
for index in range(10):
yield index
yield from range(10)
with gr.Blocks() as demo:
@ -670,8 +669,7 @@ class TestCallFunction:
@pytest.mark.asyncio
async def test_call_generator(self):
def generator(x):
for i in range(x):
yield i
yield from range(x)
with gr.Blocks() as demo:
inp = gr.Number()
@ -1368,11 +1366,10 @@ class TestProgressBar:
await ws.send(json.dumps({"data": [0], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
if msg["msg"] == "progress":
if msg[
"progress_data"
]: # Ignore empty lists which sometimes appear on Windows
progress_updates.append(msg["progress_data"])
if (
msg["msg"] == "progress" and msg["progress_data"]
): # Ignore empty lists which sometimes appear on Windows
progress_updates.append(msg["progress_data"])
if msg["msg"] == "process_completed":
completed = True
break

View File

@ -1489,7 +1489,8 @@ class TestLabel:
label_output = gr.Label()
label = label_output.postprocess(y)
assert label == {"label": "happy"}
assert json.load(open(label_output.deserialize(label))) == label
with open(label_output.deserialize(label)) as f:
assert json.load(f) == label
y = {3: 0.7, 1: 0.2, 0: 0.1}
label = label_output.postprocess(y)

View File

@ -11,7 +11,7 @@ from gradio_client import media_data
import gradio as gr
from gradio.context import Context
from gradio.exceptions import InvalidApiName
from gradio.exceptions import InvalidApiNameError
from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples
"""
@ -190,16 +190,16 @@ class TestLoadInterface:
def test_sentiment_model(self):
io = gr.load("models/distilbert-base-uncased-finetuned-sst-2-english")
try:
output = io("I am happy, I love you")
assert json.load(open(output))["label"] == "POSITIVE"
with open(io("I am happy, I love you")) as f:
assert json.load(f)["label"] == "POSITIVE"
except TooManyRequestsError:
pass
def test_image_classification_model(self):
io = gr.Blocks.load(name="models/google/vit-base-patch16-224")
try:
output = io("gradio/test_data/lion.jpg")
assert json.load(open(output))["label"] == "lion"
with open(io("gradio/test_data/lion.jpg")) as f:
assert json.load(f)["label"] == "lion"
except TooManyRequestsError:
pass
@ -214,9 +214,9 @@ class TestLoadInterface:
def test_numerical_to_label_space(self):
io = gr.load("spaces/abidlabs/titanic-survival")
try:
output = io("male", 77, 10)
assert json.load(open(output))["label"] == "Perishes"
assert io.theme.name == "soft"
with open(io("male", 77, 10)) as f:
assert json.load(f)["label"] == "Perishes"
except TooManyRequestsError:
pass
@ -472,7 +472,7 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
def test_raise_value_error_when_api_name_invalid():
with pytest.raises(InvalidApiName):
with pytest.raises(InvalidApiNameError):
demo = gr.Blocks.load(name="spaces/gradio/hello_world")
demo("freddy", api_name="route does not exist")

View File

@ -26,7 +26,7 @@ class TestDefault:
"interpretation"
]
assert interpretation[0][1] > 0 # Checks to see if the first word has >0 score.
assert 0 == interpretation[-1][1] # Checks to see if the last word has 0 score.
assert interpretation[-1][1] == 0 # Checks to see if the last word has 0 score.
class TestShapley:
@ -92,9 +92,9 @@ class TestHelperMethods:
def test_quantify_difference_with_label(self):
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"])
assert -7 == diff
assert diff == -7
diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"])
assert -100 == diff
assert diff == -100
def test_quantify_difference_with_confidences(self):
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
@ -104,11 +104,11 @@ class TestHelperMethods:
diff = gradio.interpretation.quantify_difference_in_label(
iface, [output_1], [output_2]
)
assert 0.3 == pytest.approx(diff)
assert pytest.approx(diff) == 0.3
diff = gradio.interpretation.quantify_difference_in_label(
iface, [output_1], [output_3]
)
assert 0.8 == pytest.approx(diff)
assert pytest.approx(diff) == 0.8
def test_get_regression_value(self):
iface = Interface(lambda text: text, ["textbox"], ["label"])
@ -118,19 +118,19 @@ class TestHelperMethods:
diff = gradio.interpretation.get_regression_or_classification_value(
iface, [output_1], [output_2]
)
assert 0 == diff
assert diff == 0
diff = gradio.interpretation.get_regression_or_classification_value(
iface, [output_1], [output_3]
)
assert 0.1 == pytest.approx(diff)
assert pytest.approx(diff) == 0.1
def test_get_classification_value(self):
iface = Interface(lambda text: text, ["textbox"], ["label"])
diff = gradio.interpretation.get_regression_or_classification_value(
iface, ["cat"], ["test"]
)
assert 1 == diff
assert diff == 1
diff = gradio.interpretation.get_regression_or_classification_value(
iface, ["test"], ["test"]
)
assert 0 == diff
assert diff == 0

View File

@ -28,8 +28,8 @@ class TestSeries:
io2 = gr.load("spaces/abidlabs/image-classifier")
series = mix.Series(io1, io2)
try:
output = series("gradio/test_data/lion.jpg")
assert json.load(open(output))["label"] == "lion"
with open(series("gradio/test_data/lion.jpg")) as f:
assert json.load(f)["label"] == "lion"
except TooManyRequestsError:
pass

View File

@ -187,15 +187,16 @@ class TestVideoProcessing:
def test_video_has_playable_codecs_catches_exceptions(
self, exception_to_raise, test_file_dir
):
with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise):
with tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"),
tmp_not_playable_vid.name,
)
assert processing_utils.video_is_playable(tmp_not_playable_vid.name)
with patch(
"ffmpy.FFprobe.run", side_effect=exception_to_raise
), tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"),
tmp_not_playable_vid.name,
)
assert processing_utils.video_is_playable(tmp_not_playable_vid.name)
def test_convert_video_to_playable_mp4(self, test_file_dir):
with tempfile.NamedTemporaryFile(

View File

@ -56,9 +56,8 @@ class TestRoutes:
assert response.status_code == 200
def test_upload_path(self, test_client):
response = test_client.post(
"/upload", files={"files": open("test/test_files/alphabet.txt", "r")}
)
with open("test/test_files/alphabet.txt") as f:
response = test_client.post("/upload", files={"files": f})
assert response.status_code == 200
file = response.json()[0]
assert "alphabet" in file
@ -72,9 +71,8 @@ class TestRoutes:
app, _, _ = io.launch(prevent_thread_lock=True)
test_client = TestClient(app)
try:
response = test_client.post(
"/upload", files={"files": open("test/test_files/alphabet.txt", "r")}
)
with open("test/test_files/alphabet.txt") as f:
response = test_client.post("/upload", files={"files": f})
assert response.status_code == 200
file = response.json()[0]
assert "alphabet" in file
@ -399,8 +397,7 @@ class TestRoutes:
class TestGeneratorRoutes:
def test_generator(self):
def generator(string):
for char in string:
yield char
yield from string
io = Interface(generator, "text", "text")
app, _, _ = io.queue().launch(prevent_thread_lock=True)

View File

@ -350,7 +350,7 @@ class TestRequest:
name: str
job: str
id: str
createdAt: str
createdAt: str # noqa: N815
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
@ -434,7 +434,7 @@ async def test_validate_with_model(respx_mock):
name: str
job: str
id: str
createdAt: str
createdAt: str # noqa: N815
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
@ -469,7 +469,7 @@ async def test_validate_and_fail_with_model(respx_mock):
@mock.patch("gradio.utils.AsyncRequest._validate_response_data")
@pytest.mark.asyncio
async def test_exception_type(validate_response_data, respx_mock):
class ResponseValidationException(Exception):
class ResponseValidationError(Exception):
message = "Response object is not valid."
validate_response_data.side_effect = Exception()
@ -479,9 +479,9 @@ async def test_exception_type(validate_response_data, respx_mock):
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.GET,
url=MOCK_REQUEST_URL,
exception_type=ResponseValidationException,
exception_type=ResponseValidationError,
)
assert isinstance(client_response.exception, ResponseValidationException)
assert isinstance(client_response.exception, ResponseValidationError)
@pytest.mark.asyncio
@ -511,9 +511,8 @@ async def test_validate_with_function(respx_mock):
@pytest.mark.asyncio
async def test_validate_and_fail_with_function(respx_mock):
def has_name(response):
if response["name"] is not None:
if response["name"] == "Alex":
return response
if response["name"] is not None and response["name"] == "Alex":
return response
raise Exception
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response({"name": "morpheus"}))