Fix type hints for render and on (#8429)

* type hint

* add changeset

* Use union

* type check

* lint

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-06-03 19:33:25 -04:00 committed by GitHub
parent 341844f04e
commit d393a4a224
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 17 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Fix type hints for render and on

View File

@ -5,7 +5,17 @@ from __future__ import annotations
import dataclasses import dataclasses
from functools import partial, wraps from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Literal,
Sequence,
Union,
cast,
)
from gradio_client.documentation import document from gradio_client.documentation import document
from jinja2 import Template from jinja2 import Template
@ -145,6 +155,32 @@ class EventListenerMethod:
event_name: str event_name: str
if TYPE_CHECKING:
EventListenerCallable = Callable[
[
Union[Callable, None],
Union[Component, Sequence[Component], None],
Union[Block, Sequence[Block], Sequence[Component], Component, None],
Union[str, None, Literal[False]],
bool,
Literal["full", "minimal", "hidden"],
Union[bool, None],
bool,
int,
bool,
bool,
Union[Dict[str, Any], List[Dict[str, Any]], None],
Union[float, None],
Union[Literal["once", "multiple", "always_last"], None],
Union[str, None],
Union[int, None, Literal["default"]],
Union[str, None],
bool,
],
Dependency,
]
class EventListener(str): class EventListener(str):
def __new__(cls, event_name, *_args, **_kwargs): def __new__(cls, event_name, *_args, **_kwargs):
return super().__new__(cls, event_name) return super().__new__(cls, event_name)
@ -331,7 +367,7 @@ class EventListener(str):
def on( def on(
triggers: Sequence[Any] | Any | None = None, triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None,
fn: Callable | None | Literal["decorator"] = "decorator", fn: Callable | None | Literal["decorator"] = "decorator",
inputs: Component | list[Component] | set[Component] | None = None, inputs: Component | list[Component] | set[Component] | None = None,
outputs: Block | list[Block] | list[Component] | None = None, outputs: Block | list[Block] | list[Component] | None = None,
@ -376,8 +412,10 @@ def on(
""" """
from gradio.components.base import Component from gradio.components.base import Component
if isinstance(triggers, EventListener): triggers_typed = cast(EventListener, triggers)
triggers = [triggers]
if isinstance(triggers_typed, EventListener):
triggers_typed = [triggers_typed]
if isinstance(inputs, Component): if isinstance(inputs, Component):
inputs = [inputs] inputs = [inputs]
@ -418,18 +456,18 @@ def on(
if root_block is None: if root_block is None:
raise Exception("Cannot call on() outside of a gradio.Blocks context.") raise Exception("Cannot call on() outside of a gradio.Blocks context.")
if triggers is None: if triggers is None:
triggers = ( methods = (
[EventListenerMethod(input, "change") for input in inputs] [EventListenerMethod(input, "change") for input in inputs]
if inputs is not None if inputs is not None
else [] else []
) # type: ignore ) # type: ignore
else: else:
triggers = [ methods = [
EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name) EventListenerMethod(t.__self__ if t.has_trigger else None, t.event_name) # type: ignore
for t in triggers for t in triggers_typed
] # type: ignore ]
dep, dep_index = root_block.set_event_trigger( dep, dep_index = root_block.set_event_trigger(
triggers, methods,
fn, fn,
inputs, inputs,
outputs, outputs,
@ -448,7 +486,7 @@ def on(
show_api=show_api, show_api=show_api,
trigger_mode=trigger_mode, trigger_mode=trigger_mode,
) )
set_cancel_events(triggers, cancels) set_cancel_events(methods, cancels)
return Dependency(None, dep.get_config(), dep_index, fn) return Dependency(None, dep.get_config(), dep_index, fn)

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, Literal from typing import TYPE_CHECKING, Callable, List, Literal, Sequence, Union, cast
from gradio_client.documentation import document from gradio_client.documentation import document
@ -10,6 +10,9 @@ from gradio.context import Context, LocalContext
from gradio.events import EventListener, EventListenerMethod from gradio.events import EventListener, EventListenerMethod
from gradio.layouts import Column, Row from gradio.layouts import Column, Row
if TYPE_CHECKING:
from gradio.events import EventListenerCallable
class Renderable: class Renderable:
def __init__( def __init__(
@ -76,8 +79,8 @@ class Renderable:
@document() @document()
def render( def render(
inputs: list[Component] | None = None, inputs: list[Component] | Component | None = None,
triggers: list[EventListener] | EventListener | None = None, triggers: Sequence[EventListenerCallable] | EventListenerCallable | None = None,
*, *,
queue: bool = True, queue: bool = True,
trigger_mode: Literal["once", "multiple", "always_last"] | None = "always_last", trigger_mode: Literal["once", "multiple", "always_last"] | None = "always_last",
@ -116,6 +119,8 @@ def render(
btn = gr.Button("Clear") btn = gr.Button("Clear")
btn.click(lambda: gr.Textbox(value=""), None, text) btn.click(lambda: gr.Textbox(value=""), None, text)
""" """
new_triggers = cast(Union[List[EventListener], EventListener, None], triggers)
if Context.root_block is None: if Context.root_block is None:
raise ValueError("Reactive render must be inside a Blocks context.") raise ValueError("Reactive render must be inside a Blocks context.")
@ -123,16 +128,18 @@ def render(
[inputs] if isinstance(inputs, Component) else [] if inputs is None else inputs [inputs] if isinstance(inputs, Component) else [] if inputs is None else inputs
) )
_triggers: list[tuple[Block | None, str]] = [] _triggers: list[tuple[Block | None, str]] = []
if triggers is None: if new_triggers is None:
_triggers = [(Context.root_block, "load")] _triggers = [(Context.root_block, "load")]
for input in inputs: for input in inputs:
if hasattr(input, "change"): if hasattr(input, "change"):
_triggers.append((input, "change")) _triggers.append((input, "change"))
else: else:
triggers = [triggers] if isinstance(triggers, EventListener) else triggers new_triggers = (
[new_triggers] if isinstance(new_triggers, EventListener) else new_triggers
)
_triggers = [ _triggers = [
(getattr(t, "__self__", None) if t.has_trigger else None, t.event_name) (getattr(t, "__self__", None) if t.has_trigger else None, t.event_name)
for t in triggers for t in new_triggers
] ]
def wrapper_function(fn): def wrapper_function(fn):