return a Dependency instance from Blocks.load event listener (#4304)

* return a Dependency instance from Blocks.load event listener

* a test case for chaining then from load event

* update CHANGELOG

* add test for load.then with blocks re-used

* fixes

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Paul Garner 2023-05-29 21:52:23 +01:00 committed by GitHub
parent 4d163023ca
commit 5983836804
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 56 additions and 5 deletions

View File

@ -2,7 +2,7 @@
## New Features:
No changes to highlight.
* Make `Blocks.load` behave like other event listeners (allows chaining `then` off of it) [@anentropic](https://github.com/anentropic/) in [PR 4304](https://github.com/gradio-app/gradio/pull/4304)
## Bug Fixes:

View File

@ -114,7 +114,10 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
for x in blocks.dependencies:
targets_telemetry = targets_telemetry + [
str(blocks.blocks[y]) for y in x["targets"]
# Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks
str(blocks.blocks[y])
for y in x["targets"]
if y in blocks.blocks
]
inputs_telemetry = inputs_telemetry + [
str(blocks.blocks[y]) for y in x["inputs"]

View File

@ -1482,7 +1482,9 @@ Received outputs:
name=name, src=src, hf_token=api_key, alias=alias, **kwargs
)
else:
return self_or_cls.set_event_trigger(
from gradio.events import Dependency
dep, dep_index = self_or_cls.set_event_trigger(
event_name="load",
fn=fn,
inputs=inputs,
@ -1498,7 +1500,8 @@ Received outputs:
max_batch_size=max_batch_size,
every=every,
no_target=True,
)[0]
)
return Dependency(self_or_cls, dep, dep_index)
def clear(self):
"""Resets the layout of the Blocks object."""

View File

@ -411,7 +411,7 @@ class App(FastAPI):
dependency = app.get_blocks().dependencies[fn_index_inferred]
target = dependency["targets"][0] if len(dependency["targets"]) else None
event_data = EventData(
app.get_blocks().blocks[target] if target else None,
app.get_blocks().blocks.get(target) if target else None,
body.event_data,
)
batch = dependency["batch"]

View File

@ -1,8 +1,12 @@
import os
import pytest
from fastapi.testclient import TestClient
import gradio as gr
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestEvent:
def test_clear_event(self):
@ -69,6 +73,47 @@ class TestEvent:
assert not parent.config["dependencies"][2]["trigger_only_on_success"]
assert parent.config["dependencies"][3]["trigger_only_on_success"]
def test_load_chaining(self):
calls = 0
def increment():
nonlocal calls
calls += 1
return str(calls)
with gr.Blocks() as demo:
out = gr.Textbox(label="Call counter")
demo.load(increment, inputs=None, outputs=out).then(
increment, inputs=None, outputs=out
)
assert demo.config["dependencies"][0]["trigger"] == "load"
assert demo.config["dependencies"][0]["trigger_after"] is None
assert demo.config["dependencies"][1]["trigger"] == "then"
assert demo.config["dependencies"][1]["trigger_after"] == 0
def test_load_chaining_reuse(self):
calls = 0
def increment():
nonlocal calls
calls += 1
return str(calls)
with gr.Blocks() as demo:
out = gr.Textbox(label="Call counter")
demo.load(increment, inputs=None, outputs=out).then(
increment, inputs=None, outputs=out
)
with gr.Blocks() as demo2:
demo.render()
assert demo2.config["dependencies"][0]["trigger"] == "load"
assert demo2.config["dependencies"][0]["trigger_after"] is None
assert demo2.config["dependencies"][1]["trigger"] == "then"
assert demo2.config["dependencies"][1]["trigger_after"] == 0
class TestEventErrors:
def test_event_defined_invalid_scope(self):