Updating typing in utils.py and templates.py (#2904)

* typing

* typing

* typing

* updated CHANGELOG.md

* typing

* typing

* typing

* typing

* Update max_args count to be -1 if positional param type

max_args would be set to the value of "infinity" when the param was of VAR_POSITIONAL type. That makes max_args hold both int and string values.
Updated the max_args to be set to -1 value.

* Updated to use same response and not previous response for validation

* Keep self._response uninitialized

* Updated to use empty BaseModel when not validation_model not present

* Reference `BlockContext` directly

* initializing end as 0 in format_net_list

* Updating docs for error_analytics

* Update file version check in utils

* Update gradio.strings import

* typing

* typing

* typing

* typing

* typing

* typing

* typing

* typing

* typing

* typing

* Add utils.py and templates.py to type check script

* fixed changelog

* Update formatting of utils.py

* rerun ci

* flagging fix

* fix typing

* formatting

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Pradyumna Rahul 2023-01-03 22:50:02 +05:30 committed by GitHub
parent d77b0702d1
commit 6a6e9175e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 130 additions and 115 deletions

View File

@ -43,12 +43,12 @@ No changes to highlight.
* The `default_enabled` parameter of the `Blocks.queue` method has no effect by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2876](https://github.com/gradio-app/gradio/pull/2876)
* Added typing to several Python files in codebase by [@abidlabs](https://github.com/abidlabs) in [PR 2887](https://github.com/gradio-app/gradio/pull/2887)
* Excluding untracked files from demo notebook check action by [@aliabd](https://github.com/aliabd) in [PR 2897](https://github.com/gradio-app/gradio/pull/2897)
* Updated typing by [@1nF0rmed](https://github.com/1nF0rmed) in [PR 2904](https://github.com/gradio-app/gradio/pull/2904)
## Contributors Shoutout:
* @JaySmithWpg for making their first contribution to gradio!
* @MohamedAliRashad for making their first contribution to gradio!
# Version 3.15.0
## New Features:

View File

@ -273,7 +273,7 @@ class CSVLogger(FlaggingCallback):
with open(
log_filepath, "w", newline="", encoding="utf-8"
) as csvfile: # newline parameter needed for Windows
csvfile.write(utils.sanitize_list_for_csv(file_content))
csvfile.write(file_content)
with open(log_filepath, "r", encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count

View File

@ -1,10 +1,10 @@
from __future__ import annotations
import typing
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Tuple
import numpy as np
import PIL
from PIL.Image import Image
from gradio import components
@ -18,16 +18,16 @@ class TextArea(components.Textbox):
def __init__(
self,
value: Optional[str | Callable] = "",
value: str | Callable | None = "",
*,
lines: int = 7,
max_lines: int = 20,
placeholder: Optional[str] = None,
label: Optional[str] = None,
placeholder: str | None = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
super().__init__(
@ -53,20 +53,20 @@ class Webcam(components.Image):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = None,
shape: Tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "webcam",
tool: str = None,
tool: str | None = None,
type: str = "numpy",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = True,
interactive: bool | None = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
**kwargs,
):
@ -98,20 +98,20 @@ class Sketchpad(components.Image):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = (28, 28),
image_mode: str = "L",
invert_colors: bool = True,
source: str = "canvas",
tool: str = None,
tool: str | None = None,
type: str = "numpy",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = True,
interactive: bool | None = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
**kwargs,
):
@ -143,20 +143,20 @@ class Paint(components.Image):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = None,
shape: Tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "canvas",
tool: str = "color-sketch",
type: str = "numpy",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = True,
interactive: bool | None = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
**kwargs,
):
@ -188,20 +188,20 @@ class ImageMask(components.Image):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = None,
shape: Tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "sketch",
type: str = "numpy",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = True,
interactive: bool | None = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
**kwargs,
):
@ -233,20 +233,20 @@ class ImagePaint(components.Image):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = None,
shape: Tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = "color-sketch",
type: str = "numpy",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = True,
interactive: bool | None = True,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
**kwargs,
):
@ -278,20 +278,20 @@ class Pil(components.Image):
def __init__(
self,
value: Optional[str | PIL.Image | np.narray] = None,
value: str | Image | np.ndarray | None = None,
*,
shape: Tuple[int, int] = None,
shape: Tuple[int, int] | None = None,
image_mode: str = "RGB",
invert_colors: bool = False,
source: str = "upload",
tool: str = None,
tool: str | None = None,
type: str = "pil",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
**kwargs,
):
@ -323,17 +323,17 @@ class PlayableVideo(components.Video):
def __init__(
self,
value: Optional[str | Callable] = None,
value: str | Callable | None = None,
*,
format: Optional[str] = "mp4",
format: str | None = "mp4",
source: str = "upload",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
mirror_webcam: bool = True,
include_audio: Optional[bool] = None,
include_audio: bool | None = None,
**kwargs,
):
super().__init__(
@ -360,16 +360,16 @@ class Microphone(components.Audio):
def __init__(
self,
value: Optional[str | Tuple[int, np.array] | Callable] = None,
value: str | Tuple[int, np.ndarray] | Callable | None = None,
*,
source: str = "microphone",
type: str = "numpy",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
streaming: bool = False,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
super().__init__(
@ -395,15 +395,15 @@ class Files(components.File):
def __init__(
self,
value: Optional[str | typing.List[str] | Callable] = None,
value: str | typing.List[str] | Callable | None = None,
*,
file_count: str = "multiple",
type: str = "file",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
**kwargs,
):
super().__init__(
@ -428,21 +428,21 @@ class Numpy(components.Dataframe):
def __init__(
self,
value: Optional[typing.List[typing.List[Any]] | Callable] = None,
value: typing.List[typing.List[Any]] | Callable | None = None,
*,
headers: Optional[typing.List[str]] = None,
headers: typing.List[str] | None = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: Optional[int | Tuple[int, str]] = None,
col_count: int | Tuple[int, str] | None = None,
datatype: str | typing.List[str] = "str",
type: str = "numpy",
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
max_rows: int | None = 20,
max_cols: int | None = None,
overflow_row_behaviour: str = "paginate",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
wrap: bool = False,
**kwargs,
):
@ -475,21 +475,21 @@ class Matrix(components.Dataframe):
def __init__(
self,
value: Optional[typing.List[typing.List[Any]] | Callable] = None,
value: typing.List[typing.List[Any]] | Callable | None = None,
*,
headers: Optional[typing.List[str]] = None,
headers: typing.List[str] | None = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: Optional[int | Tuple[int, str]] = None,
col_count: int | Tuple[int, str] | None = None,
datatype: str | typing.List[str] = "str",
type: str = "array",
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
max_rows: int | None = 20,
max_cols: int | None = None,
overflow_row_behaviour: str = "paginate",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
wrap: bool = False,
**kwargs,
):
@ -522,21 +522,21 @@ class List(components.Dataframe):
def __init__(
self,
value: Optional[typing.List[typing.List[Any]] | Callable] = None,
value: typing.List[typing.List[Any]] | Callable | None = None,
*,
headers: Optional[typing.List[str]] = None,
headers: typing.List[str] | None = None,
row_count: int | Tuple[int, str] = (1, "dynamic"),
col_count: Optional[int | Tuple[int, str]] = 1,
col_count: int | Tuple[int, str] = 1,
datatype: str | typing.List[str] = "str",
type: str = "array",
max_rows: Optional[int] = 20,
max_cols: Optional[int] = None,
max_rows: int | None = 20,
max_cols: int | None = None,
overflow_row_behaviour: str = "paginate",
label: Optional[str] = None,
label: str | None = None,
show_label: bool = True,
interactive: Optional[bool] = None,
interactive: bool | None = None,
visible: bool = True,
elem_id: Optional[str] = None,
elem_id: str | None = None,
wrap: bool = False,
**kwargs,
):

View File

@ -32,6 +32,7 @@ from typing import (
Tuple,
Type,
TypeVar,
Union,
)
import aiohttp
@ -43,6 +44,7 @@ from pydantic import BaseModel, Json, parse_obj_as
import gradio
from gradio.context import Context
from gradio.strings import en
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import BlockContext
@ -57,9 +59,10 @@ T = TypeVar("T")
def version_check():
try:
current_pkg_version = (
pkgutil.get_data(__name__, "version.txt").decode("ascii").strip()
)
version_data = pkgutil.get_data(__name__, "version.txt")
if not version_data:
raise FileNotFoundError
current_pkg_version = version_data.decode("ascii").strip()
latest_pkg_version = requests.get(url=PKG_VERSION_URL, timeout=3).json()[
"version"
]
@ -91,7 +94,7 @@ def get_local_ip_address() -> str:
return ip_address
def initiated_analytics(data: Dict[str:Any]) -> None:
def initiated_analytics(data: Dict[str, Any]) -> None:
try:
requests.post(
analytics_url + "gradio-initiated-analytics/", data=data, timeout=3
@ -121,7 +124,8 @@ def integration_analytics(data: Dict[str, Any]) -> None:
def error_analytics(ip_address: str, message: str) -> None:
"""
Send error analytics if there is network
:param type: RuntimeError or NameError
:param ip_address: IP address where error occurred
:param message: Details about error
"""
data = {"ip_address": ip_address, "error": message}
try:
@ -187,7 +191,7 @@ def readme_to_html(article: str) -> str:
def show_tip(interface: gradio.Blocks) -> None:
if interface.show_tips and random.random() < 1.5:
tip: str = random.choice(gradio.strings.en["TIPS"])
tip: str = random.choice(en["TIPS"])
print(f"Tip: {tip}")
@ -202,7 +206,7 @@ def launch_counter() -> None:
launches = json.load(j)
launches["launches"] += 1
if launches["launches"] in [25, 50, 150, 500, 1000]:
print(gradio.strings.en["BETA_INVITE"])
print(en["BETA_INVITE"])
with open(JSON_PATH, "w") as j:
j.write(json.dumps(launches))
except:
@ -271,11 +275,12 @@ def assert_configs_are_equivalent_besides_ids(
return True
def format_ner_list(input_string: str, ner_groups: Dict[str : str | int]):
def format_ner_list(input_string: str, ner_groups: List[Dict[str, str | int]]):
if len(ner_groups) == 0:
return [(input_string, None)]
output = []
end = 0
prev_end = 0
for group in ner_groups:
@ -325,6 +330,7 @@ def component_or_layout_class(cls_name: str) -> Type[Component] | Type[BlockCont
Returns:
cls: the component class
"""
import gradio.blocks
import gradio.components
import gradio.layouts
import gradio.templates
@ -450,8 +456,8 @@ class AsyncRequest:
method: Method,
url: str,
*,
validation_model: Type[BaseModel] = None,
validation_function: Callable = None,
validation_model: Type[BaseModel] | None = None,
validation_function: Union[Callable, None] = None,
exception_type: Type[Exception] = Exception,
raise_for_status: bool = False,
**kwargs,
@ -467,8 +473,7 @@ class AsyncRequest:
exception_class(Type[Exception]): a exception type to throw with its type
raise_for_status(bool): a flag that determines to raise httpx.Request.raise_for_status() exceptions.
"""
self._response = None
self._exception = None
self._exception: Union[Exception, None] = None
self._status = None
self._raise_for_status = raise_for_status
self._validation_model = validation_model
@ -517,7 +522,7 @@ class AsyncRequest:
return self
@staticmethod
def _create_request(method: Method, url: str, **kwargs) -> AsyncRequest:
def _create_request(method: Method, url: str, **kwargs) -> httpx.Request:
"""
Create a request. This is a httpx request wrapper function.
Args:
@ -530,7 +535,9 @@ class AsyncRequest:
request = httpx.Request(method, url, **kwargs)
return request
def _validate_response_data(self, response: ResponseJson) -> ResponseJson:
def _validate_response_data(
self, response: ResponseJson
) -> Union[BaseModel, ResponseJson | None]:
"""
Validate response using given validation methods. If there is a validation method and response is not valid,
validation functions will raise an exception for them.
@ -546,13 +553,11 @@ class AsyncRequest:
try:
# If a validation model is provided, validate response using the validation model.
if self._validation_model:
validated_response = self._validate_response_by_model(
validated_response
)
validated_response = self._validate_response_by_model(response)
# Then, If a validation function is provided, validate response using the validation function.
if self._validation_function:
validated_response = self._validate_response_by_validation_function(
validated_response
response
)
except Exception as exception:
# If one of the validation methods does not confirm, raised exception will be silently handled.
@ -561,7 +566,7 @@ class AsyncRequest:
return validated_response
def _validate_response_by_model(self, response: ResponseJson) -> ResponseJson:
def _validate_response_by_model(self, response: ResponseJson) -> BaseModel:
"""
Validate response json using the validation model.
Args:
@ -569,12 +574,14 @@ class AsyncRequest:
Returns:
ResponseJson: Validated Json object.
"""
validated_data = parse_obj_as(self._validation_model, response)
validated_data = BaseModel()
if self._validation_model:
validated_data = parse_obj_as(self._validation_model, response)
return validated_data
def _validate_response_by_validation_function(
self, response: ResponseJson
) -> ResponseJson:
) -> ResponseJson | None:
"""
Validate response json using the validation function.
Args:
@ -582,7 +589,11 @@ class AsyncRequest:
Returns:
ResponseJson: Validated Json object.
"""
validated_data = self._validation_function(response)
validated_data = None
if self._validation_function:
validated_data = self._validation_function(response)
return validated_data
def is_valid(self, raise_exceptions: bool = False) -> bool:
@ -593,7 +604,7 @@ class AsyncRequest:
Returns:
bool: validity of the data
"""
if self.has_exception:
if self.has_exception and self._exception:
if raise_exceptions:
raise self._exception
return False
@ -618,7 +629,7 @@ class AsyncRequest:
@property
def raise_exceptions(self):
if self.has_exception:
if self.has_exception and self._exception:
raise self._exception
@property
@ -666,7 +677,7 @@ def sanitize_value_for_csv(value: str | Number) -> str | Number:
return value
def sanitize_list_for_csv(values: T) -> T:
def sanitize_list_for_csv(values: List[Any]) -> List[Any]:
"""
Sanitizes a list of values (or a list of list of values) that is being written to a
CSV file to prevent CSV injection attacks.
@ -684,13 +695,13 @@ def sanitize_list_for_csv(values: T) -> T:
def append_unique_suffix(name: str, list_of_names: List[str]):
"""Appends a numerical suffix to `name` so that it does not appear in `list_of_names`."""
list_of_names = set(list_of_names) # for O(1) lookup
if name not in list_of_names:
set_of_names: set[str] = set(list_of_names) # for O(1) lookup
if name not in set_of_names:
return name
else:
suffix_counter = 1
new_name = name + f"_{suffix_counter}"
while new_name in list_of_names:
while new_name in set_of_names:
suffix_counter += 1
new_name = name + f"_{suffix_counter}"
return new_name
@ -718,7 +729,7 @@ def get_continuous_fn(fn: Callable, every: float) -> Callable:
return continuous_fn
async def cancel_tasks(task_ids: List[str]):
async def cancel_tasks(task_ids: set[str]):
if sys.version_info < (3, 8):
return None
@ -742,10 +753,13 @@ def get_cancel_function(
) -> Tuple[Callable, List[int]]:
fn_to_comp = {}
for dep in dependencies:
fn_index = next(
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
)
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
if Context.root_block:
fn_index = next(
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
)
fn_to_comp[fn_index] = [
Context.root_block.blocks[o] for o in dep["outputs"]
]
async def cancel(session_hash: str) -> None:
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
@ -773,6 +787,7 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
parameter_types = typing.get_type_hints(fn) if inspect.isfunction(fn) else {}
min_args = 0
max_args = 0
infinity = -1
for name, param in signature.parameters.items():
has_default = param.default != param.empty
if param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
@ -781,7 +796,7 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
min_args += 1
max_args += 1
elif param.kind == param.VAR_POSITIONAL:
max_args = "infinity"
max_args = infinity
elif param.kind == param.KEYWORD_ONLY:
if not has_default:
return f"Keyword-only args must have default values for function {fn}"
@ -794,7 +809,7 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
warnings.warn(
f"Expected at least {min_args} arguments for function {fn}, received {arg_count}."
)
if max_args != "infinity" and arg_count > max_args:
if max_args != infinity and arg_count > max_args:
warnings.warn(
f"Expected maximum {max_args} arguments for function {fn}, received {arg_count}."
)

View File

@ -6,4 +6,4 @@ pip_required
pip install --upgrade pip
pip install pyright
cd gradio
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py
pyright blocks.py components.py context.py data_classes.py deprecation.py documentation.py encryptor.py events.py examples.py exceptions.py external.py external_utils.py serializing.py layouts.py flagging.py interface.py utils.py templates.py