Fix type hints for render and on (#8429)

* type hint

* add changeset

* Use union

* type check

* lint

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-06-03 19:33:25 -04:00 committed by GitHub
parent 341844f04e
commit d393a4a224
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 17 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Fix type hints for render and on

View File

@ -5,7 +5,17 @@ from __future__ import annotations
import dataclasses
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Sequence,
Union,
cast,
)
from gradio_client.documentation import document
from jinja2 import Template
@ -145,6 +155,32 @@ class EventListenerMethod:
event_name: str
if TYPE_CHECKING:
EventListenerCallable = Callable[
[
Union[Callable, None],
Union[Component, Sequence[Component], None],
Union[Block, Sequence[Block], Sequence[Component], Component, None],
Union[str, None, Literal[False]],
bool,
Literal["full", "minimal", "hidden"],
Union[bool, None],
bool,
int,
bool,
bool,
Union[Dict[str, Any], List[Dict[str, Any]], None],
Union[float, None],
Union[Literal["once", "multiple", "always_last"], None],
Union[str, None],
Union[int, None, Literal["default"]],
Union[str, None],
bool,
],
Dependency,
]
class EventListener(str):
def __new__(cls, event_name, *_args, **_kwargs):
return super().__new__(cls, event_name)
@ -331,7 +367,7 @@ class EventListener(str):
def on(
triggers: Sequence[Any] | Any | None = None,
triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None,
fn: Callable | None | Literal["decorator"] = "decorator",
inputs: Component | list[Component] | set[Component] | None = None,
outputs: Block | list[Block] | list[Component] | None = None,
@ -376,8 +412,10 @@ def on(
"""
from gradio.components.base import Component
if isinstance(triggers, EventListener):
triggers = [triggers]
triggers_typed = cast(EventListener, triggers)
if isinstance(triggers_typed, EventListener):
triggers_typed = [triggers_typed]
if isinstance(inputs, Component):
inputs = [inputs]
@ -418,18 +456,18 @@ def on(
if root_block is None:
raise Exception("Cannot call on() outside of a gradio.Blocks context.")
if triggers is None:
triggers = (
methods = (
[EventListenerMethod(input, "change") for input in inputs]
if inputs is not None
else []
) # type: ignore
else:
triggers = [
EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name)
for t in triggers
] # type: ignore
methods = [
EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name) # type: ignore
for t in triggers_typed
]
dep, dep_index = root_block.set_event_trigger(
triggers,
methods,
fn,
inputs,
outputs,
@ -448,7 +486,7 @@ def on(
show_api=show_api,
trigger_mode=trigger_mode,
)
set_cancel_events(triggers, cancels)
set_cancel_events(methods, cancels)
return Dependency(None, dep.get_config(), dep_index, fn)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Callable, Literal
from typing import TYPE_CHECKING, Callable, List, Literal, Sequence, Union, cast
from gradio_client.documentation import document
@ -10,6 +10,9 @@ from gradio.context import Context, LocalContext
from gradio.events import EventListener, EventListenerMethod
from gradio.layouts import Column, Row
if TYPE_CHECKING:
from gradio.events import EventListenerCallable
class Renderable:
def __init__(
@ -76,8 +79,8 @@ class Renderable:
@document()
def render(
inputs: list[Component] | None = None,
triggers: list[EventListener] | EventListener | None = None,
inputs: list[Component] | Component | None = None,
triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None,
*,
queue: bool = True,
trigger_mode: Literal["once", "multiple", "always_last"] | None = "always_last",
@ -116,6 +119,8 @@ def render(
btn = gr.Button("Clear")
btn.click(lambda: gr.Textbox(value=""), None, text)
"""
new_triggers = cast(Union[List[EventListener], EventListener, None], triggers)
if Context.root_block is None:
raise ValueError("Reactive render must be inside a Blocks context.")
@ -123,16 +128,18 @@ def render(
[inputs] if isinstance(inputs, Component) else [] if inputs is None else inputs
)
_triggers: list[tuple[Block | None, str]] = []
if triggers is None:
if new_triggers is None:
_triggers = [(Context.root_block, "load")]
for input in inputs:
if hasattr(input, "change"):
_triggers.append((input, "change"))
else:
triggers = [triggers] if isinstance(triggers, EventListener) else triggers
new_triggers = (
[new_triggers] if isinstance(new_triggers, EventListener) else new_triggers
)
_triggers = [
(getattr(t, "__self__", None) if t.has_trigger else None, t.event_name)
for t in triggers
for t in new_triggers
]
def wrapper_function(fn):