upgrade pyright==1.1.372 (#8837)

* upgrade typing

* more fixes

* add stubs folder

* add changeset

* fix docstring

* docstring

* type

* fix tests

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-07-19 00:21:51 -07:00 committed by GitHub
parent 7e8c829aad
commit 0d76169e46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 162 additions and 85 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:upgrade pyright==1.1.372

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ __pycache__/
build/
__tmp/*
*.pyi
!gradio/stubs/**/*.pyi
py.typed
.ipynb_checkpoints/
.python-version

View File

@ -1,6 +1,6 @@
pytest-asyncio
pytest==7.1.2
ruff==0.4.1
pyright==1.1.327
pyright==1.1.372
gradio
pydub==0.25.1

View File

@ -116,5 +116,5 @@ class SimpleDropdown(FormComponent):
self._warn_if_invalid_choice(value)
return value
def process_example(self, input_data):
return next((c[0] for c in self.choices if c[1] == input_data), None)
def process_example(self, value):
return next((c[0] for c in self.choices if c[1] == value), None)

View File

@ -138,11 +138,11 @@ class Block:
self.render()
@property
def stateful(self):
def stateful(self) -> bool:
return False
@property
def skip_api(self):
def skip_api(self) -> bool:
return False
@property
@ -236,7 +236,7 @@ class Block:
if hasattr(self, parameter.name):
value = getattr(self, parameter.name)
if dataclasses.is_dataclass(value):
value = dataclasses.asdict(value)
value = dataclasses.asdict(value) # type: ignore
config[parameter.name] = value
for e in self.events:
to_add = e.config_data()
@ -1153,11 +1153,16 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
# ID 0 should be the root Blocks component
original_mapping[0] = root_block = Context.root_block or blocks
iterate_over_children(config["layout"]["children"])
if "layout" in config:
iterate_over_children(config["layout"]["children"]) #
first_dependency = None
# add the event triggers
if "dependencies" not in config:
raise ValueError(
"This config is missing the 'dependencies' field and cannot be loaded."
)
for dependency, fn in zip(config["dependencies"], fns):
# We used to add a "fake_event" to the config to cache examples
# without removing it. This was causing bugs in calling gr.load
@ -1480,7 +1485,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
if inspect.iscoroutinefunction(fn):
prediction = await fn(*processed_input)
else:
prediction = await anyio.to_thread.run_sync(
prediction = await anyio.to_thread.run_sync( # type: ignore
fn, *processed_input, limiter=self.limiter
)
else:

View File

@ -105,7 +105,7 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
@property
@abstractmethod
def skip_api(self):
def skip_api(self) -> bool:
"""Whether this component should be skipped from the api return value"""
@classmethod
@ -368,7 +368,7 @@ class StreamingOutput(metaclass=abc.ABCMeta):
@abc.abstractmethod
def stream_output(
self, value, output_id: str, first_chunk: bool
) -> tuple[bytes, Any]:
) -> tuple[bytes | None, Any]:
pass

View File

@ -57,27 +57,6 @@ class MessageDict(TypedDict):
metadata: NotRequired[MetadataDict]
TupleFormat = List[List[Union[str, Tuple[str], Tuple[str, str], None]]]
if TYPE_CHECKING:
from gradio.components import Timer
def import_component_and_data(
component_name: str,
) -> GradioComponent | ComponentMeta | Any | None:
try:
for component in utils.get_all_components():
if component_name == component.__name__ and isinstance(
component, ComponentMeta
):
return component
except ModuleNotFoundError as e:
raise ValueError(f"Error importing {component_name}: {e}") from e
except AttributeError:
pass
class FileMessage(GradioModel):
file: FileData
alt_text: Optional[str] = None
@ -120,6 +99,27 @@ class ChatbotDataMessages(GradioRootModel):
root: List[Message]
TupleFormat = List[List[Union[str, Tuple[str], Tuple[str, str], None]]]
if TYPE_CHECKING:
from gradio.components import Timer
def import_component_and_data(
component_name: str,
) -> GradioComponent | ComponentMeta | Any | None:
try:
for component in utils.get_all_components():
if component_name == component.__name__ and isinstance(
component, ComponentMeta
):
return component
except ModuleNotFoundError as e:
raise ValueError(f"Error importing {component_name}: {e}") from e
except AttributeError:
pass
@document()
class Chatbot(Component):
"""

View File

@ -205,7 +205,7 @@ class Dataframe(Component):
if payload.headers is not None:
return pd.DataFrame(
[] if payload.data == [[]] else payload.data,
columns=payload.headers,
columns=payload.headers, # type: ignore
)
else:
return pd.DataFrame(payload.data)
@ -294,8 +294,8 @@ class Dataframe(Component):
)
elif _is_polars_available() and isinstance(value, _import_polars().DataFrame):
if len(value) == 0:
return DataframeData(headers=list(value.to_dict().keys()), data=[[]])
df_dict = value.to_dict()
return DataframeData(headers=list(value.to_dict().keys()), data=[[]]) # type: ignore
df_dict = value.to_dict() # type: ignore
headers = list(df_dict.keys())
data = list(zip(*df_dict.values()))
return DataframeData(headers=headers, data=data)
@ -383,7 +383,7 @@ class Dataframe(Component):
if value is None:
return ""
value_df_data = self.postprocess(value)
value_df = pd.DataFrame(value_df_data.data, columns=value_df_data.headers)
value_df = pd.DataFrame(value_df_data.data, columns=value_df_data.headers) # type: ignore
return value_df.head(n=5).to_dict(orient="split")["data"]
def example_payload(self) -> Any:

View File

@ -155,18 +155,18 @@ class Dataset(Component):
elif self.type == "tuple":
return payload, self.raw_samples[payload]
def postprocess(self, sample: int | list | None) -> int | None:
def postprocess(self, value: int | list | None) -> int | None:
"""
Parameters:
sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
value: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
Returns:
Returns the index of the sample in the dataset.
"""
if sample is None or isinstance(sample, int):
return sample
if isinstance(sample, list):
if value is None or isinstance(value, int):
return value
if isinstance(value, list):
try:
index = self.samples.index(sample)
index = self.samples.index(value)
except ValueError:
index = None
warnings.warn(

View File

@ -205,13 +205,13 @@ class File(Component):
size=Path(value).stat().st_size,
)
def process_example(self, input_data: str | list | None) -> str:
if input_data is None:
def process_example(self, value: str | list | None) -> str:
if value is None:
return ""
elif isinstance(input_data, list):
return ", ".join([Path(file).name for file in input_data])
elif isinstance(value, list):
return ", ".join([Path(file).name for file in value])
else:
return Path(input_data).name
return Path(value).name
def example_payload(self) -> Any:
if self.file_count == "single":

View File

@ -126,8 +126,8 @@ class Model3D(Component):
return value
return FileData(path=str(value), orig_name=Path(value).name)
def process_example(self, input_data: str | Path | None) -> str:
return Path(input_data).name if input_data else ""
def process_example(self, value: str | Path | None) -> str:
return Path(value).name if value else ""
def example_payload(self):
return handle_file(

View File

@ -24,7 +24,7 @@ class PlotData(GradioModel):
class AltairPlotData(PlotData):
chart: Literal["bar", "line", "scatter"]
type: Literal["altair"] = "altair"
type: Literal["altair"] = "altair" # type: ignore
@document()

View File

@ -53,7 +53,7 @@ class State(Component):
super().__init__(value=self.value, render=render)
@property
def stateful(self):
def stateful(self) -> bool:
return True
def preprocess(self, payload: Any) -> Any:

View File

@ -48,14 +48,14 @@ class Timer(Component):
"""
return payload
def postprocess(self, payload: float | None) -> float | None:
def postprocess(self, value: float | None) -> float | None:
"""
Parameters:
payload: The interval of the timer as a float or None.
value: The interval of the timer as a float or None.
Returns:
The interval of the timer as a float.
"""
return payload
return value
def api_info(self) -> dict:
return {"type": "number"}

View File

@ -213,7 +213,7 @@ class Video(Component):
raise wasm_utils.WasmUnsupportedError(
"Video formatting is not supported in the Wasm mode."
)
ff = FFmpeg(
ff = FFmpeg( # type: ignore
inputs={str(file_name): None},
outputs={output_file_name: output_options},
)
@ -227,7 +227,7 @@ class Video(Component):
raise wasm_utils.WasmUnsupportedError(
"include_audio=False is not supported in the Wasm mode."
)
ff = FFmpeg(
ff = FFmpeg( # type: ignore
inputs={str(file_name): None},
outputs={output_file_name: ["-an"]},
)
@ -315,7 +315,7 @@ class Video(Component):
"Returning a video in a different format is not supported in the Wasm mode."
)
output_file_name = video[0 : video.rindex(".") + 1] + self.format
ff = FFmpeg(
ff = FFmpeg( # type: ignore
inputs={video: None},
outputs={output_file_name: None},
global_options="-y",

View File

@ -8,7 +8,7 @@ import secrets
import shutil
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Any, Iterator, List, Literal, Optional, Tuple, TypedDict, Union
from fastapi import Request
from gradio_client.documentation import document
@ -260,7 +260,7 @@ class ListFiles(GradioRootModel):
def __getitem__(self, index):
return self.root[index]
def __iter__(self):
def __iter__(self) -> Iterator[FileData]: # type: ignore[override]
return iter(self.root)

View File

@ -532,9 +532,9 @@ class EventListener(str):
_callback(block)
return Dependency(block, dep.get_config(), dep_index, fn, timer)
event_trigger.event_name = _event_name
event_trigger.has_trigger = _has_trigger
event_trigger.callback = _callback
event_trigger.event_name = _event_name # type: ignore
event_trigger.has_trigger = _has_trigger # type: ignore
event_trigger.callback = _callback # type: ignore
return event_trigger
@ -659,8 +659,8 @@ def on(
]
if triggers:
for trigger in triggers:
if trigger.callback:
trigger.callback(trigger.__self__)
if trigger.callback: # type: ignore
trigger.callback(trigger.__self__) # type: ignore
if every is not None:
from gradio.components import Timer

View File

@ -57,17 +57,17 @@ with gr.Blocks() as demo:
df_filtered = df if function == "All" else df[df["function"] == function]
if timespan != "All Time":
df_filtered = df_filtered[
df_filtered["time"] > pd.Timestamp.now() - pd.Timedelta(timespan)
df_filtered["time"] > pd.Timestamp.now() - pd.Timedelta(timespan) # type: ignore
]
df_filtered["time"] = df_filtered["time"].dt.floor("min")
df_filtered["time"] = df_filtered["time"].dt.floor("min") # type: ignore
plot = df_filtered.groupby(["time", "status"]).size().reset_index(name="count") # type: ignore
mean_process_time_for_success = df_filtered[df_filtered["status"] == "success"][
"process_time"
].mean()
return (
df_filtered["session_hash"].nunique(),
df_filtered["session_hash"].nunique(), # type: ignore
df_filtered.shape[0],
round(mean_process_time_for_success, 2),
plot,

View File

@ -12,8 +12,8 @@ from gradio.pipelines_utils import (
)
if TYPE_CHECKING:
import diffusers
import transformers
import diffusers # type: ignore
import transformers # type: ignore
def load_from_pipeline(

View File

@ -188,7 +188,7 @@ def handle_transformers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
def handle_diffusers_pipeline(pipeline: Any) -> Optional[Dict[str, Any]]:
try:
import diffusers
import diffusers # type: ignore
except ImportError as ie:
raise ImportError(
"diffusers not installed. Please try `pip install diffusers`"

View File

@ -46,7 +46,7 @@ if wasm_utils.IS_WASM:
def handle_request(self, request: httpx.Request) -> httpx.Response:
url = str(request.url)
method = request.method
method = str(request.method)
headers = dict(request.headers)
body = None if method in ["GET", "HEAD"] else request.read()

View File

@ -12,7 +12,7 @@ import aiofiles
from aiofiles.os import stat as aio_stat
from starlette.datastructures import Headers
from starlette.exceptions import HTTPException
from starlette.responses import Response, guess_type
from starlette.responses import Response, guess_type # type: ignore
from starlette.staticfiles import StaticFiles
from starlette.types import Receive, Scope, Send

View File

@ -510,7 +510,7 @@ class GradioMultiPartParser:
self.items.append(
(
self._current_part.field_name,
_user_safe_decode(self._current_part.data, self._charset),
_user_safe_decode(self._current_part.data, str(self._charset)),
)
)
else:

View File

@ -325,7 +325,7 @@ class App(FastAPI):
not callable(app.auth)
and username in app.auth
and compare_passwords_securely(password, app.auth[username]) # type: ignore
) or (callable(app.auth) and app.auth.__call__(username, password)):
) or (callable(app.auth) and app.auth.__call__(username, password)): # type: ignore
token = secrets.token_urlsafe(16)
app.tokens[token] = username
response = JSONResponse(content={"success": True})

View File

@ -18,12 +18,12 @@ class ProgressUnit(BaseModel):
class ProgressMessage(BaseMessage):
msg: Literal[ServerMessage.progress] = ServerMessage.progress
msg: Literal[ServerMessage.progress] = ServerMessage.progress # type: ignore
progress_data: List[ProgressUnit] = []
class LogMessage(BaseMessage):
msg: Literal[ServerMessage.log] = ServerMessage.log
msg: Literal[ServerMessage.log] = ServerMessage.log # type: ignore
log: str
level: Literal["info", "warning"]
duration: Optional[float] = 10
@ -31,39 +31,39 @@ class LogMessage(BaseMessage):
class EstimationMessage(BaseMessage):
msg: Literal[ServerMessage.estimation] = ServerMessage.estimation
msg: Literal[ServerMessage.estimation] = ServerMessage.estimation # type: ignore
rank: Optional[int] = None
queue_size: int
rank_eta: Optional[float] = None
class ProcessStartsMessage(BaseMessage):
msg: Literal[ServerMessage.process_starts] = ServerMessage.process_starts
msg: Literal[ServerMessage.process_starts] = ServerMessage.process_starts # type: ignore
eta: Optional[float] = None
class ProcessCompletedMessage(BaseMessage):
msg: Literal[ServerMessage.process_completed] = ServerMessage.process_completed
msg: Literal[ServerMessage.process_completed] = ServerMessage.process_completed # type: ignore
output: dict
success: bool
class ProcessGeneratingMessage(BaseMessage):
msg: Literal[ServerMessage.process_generating] = ServerMessage.process_generating
msg: Literal[ServerMessage.process_generating] = ServerMessage.process_generating # type: ignore
output: dict
success: bool
class HeartbeatMessage(BaseModel):
msg: Literal[ServerMessage.heartbeat] = ServerMessage.heartbeat
class HeartbeatMessage(BaseMessage):
msg: Literal[ServerMessage.heartbeat] = ServerMessage.heartbeat # type: ignore
class CloseStreamMessage(BaseModel):
msg: Literal[ServerMessage.close_stream] = ServerMessage.close_stream
class CloseStreamMessage(BaseMessage):
msg: Literal[ServerMessage.close_stream] = ServerMessage.close_stream # type: ignore
class UnexpectedErrorMessage(BaseModel):
msg: Literal[ServerMessage.unexpected_error] = ServerMessage.unexpected_error
class UnexpectedErrorMessage(BaseMessage):
msg: Literal[ServerMessage.unexpected_error] = ServerMessage.unexpected_error # type: ignore
message: str
success: Literal[False] = False

62
gradio/stubs/anyio.pyi Normal file
View File

@ -0,0 +1,62 @@
""" This module contains type hints for the anyio library. It was auto-generated so may include errors."""
from typing import Any, Callable, Coroutine, TypeVar, overload, Optional, Union
from types import TracebackType
T = TypeVar('T')
T_Retval = TypeVar('T_Retval')
class CapacityLimiter:
def __init__(self, total_tokens: float): ...
async def acquire(self) -> None: ...
async def acquire_nowait(self) -> None: ...
def release(self) -> None: ...
@property
def total_tokens(self) -> float: ...
@property
def available_tokens(self) -> float: ...
def __enter__(self) -> 'CapacityLimiter': ...
def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> None: ...
class to_thread:
@staticmethod
def run_sync(func: Callable[..., T], *args: Any, cancellable: bool = False,
limiter: Optional[CapacityLimiter] = None, **kwargs: Any) -> Coroutine[Any, Any, T]: ...
@overload
def run(func: Callable[[], T_Retval], *, backend: Optional[str] = ...,
backend_options: Optional[dict] = ...) -> T_Retval: ...
@overload
def run(func: Callable[..., T_Retval], *args: Any, backend: Optional[str] = ...,
backend_options: Optional[dict] = ..., **kwargs: Any) -> T_Retval: ...
def sleep(delay: Union[int, float]) -> Coroutine[Any, Any, None]: ...
async def sleep_forever() -> None: ...
def current_time() -> float: ...
def get_cancelled_exc_class() -> type[BaseException]: ...
def create_task_group() -> 'TaskGroup': ...
class TaskGroup:
async def __aenter__(self) -> 'TaskGroup': ...
async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> bool: ...
def start_soon(self, func: Callable[..., Coroutine[Any, Any, Any]], *args: Any,
name: Optional[str] = None, **kwargs: Any) -> None: ...
def create_memory_object_stream(
max_buffer_size: int = 0
) -> tuple['MemoryObjectSender', 'MemoryObjectReceiver']: ...
class MemoryObjectSender:
async def send(self, item: Any) -> None: ...
def send_nowait(self, item: Any) -> None: ...
class MemoryObjectReceiver:
async def receive(self) -> Any: ...
def receive_nowait(self) -> Any: ...

View File

@ -107,6 +107,9 @@ exclude = [
"gradio/node/*.py",
"gradio/_frontend_code/*.py",
]
stubPath = "gradio/stubs"
extraPaths = ["gradio/stubs"]
[tool.ruff]
exclude = ["gradio/node/*.py", ".venv/*", "gradio/_frontend_code/*.py"]

View File

@ -1,5 +1,6 @@
aiofiles>=22.0,<24.0
altair>=5.0,<6.0
anyio>=3.0,<5.0
fastapi
ffmpy
gradio_client==1.1.0