diff --git a/CHANGELOG.md b/CHANGELOG.md index c1cb61049d..e472419714 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## New Features: -No changes to highlight. +* Make `Blocks.load` behave like other event listeners (allows chaining `then` off of it) [@anentropic](https://github.com/anentropic/) in [PR 4304](https://github.com/gradio-app/gradio/pull/4304) ## Bug Fixes: diff --git a/gradio/analytics.py b/gradio/analytics.py index c1c1a908bb..c4b710aa26 100644 --- a/gradio/analytics.py +++ b/gradio/analytics.py @@ -114,7 +114,10 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None: for x in blocks.dependencies: targets_telemetry = targets_telemetry + [ - str(blocks.blocks[y]) for y in x["targets"] + # Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks + str(blocks.blocks[y]) + for y in x["targets"] + if y in blocks.blocks ] inputs_telemetry = inputs_telemetry + [ str(blocks.blocks[y]) for y in x["inputs"] diff --git a/gradio/blocks.py b/gradio/blocks.py index 7b35cd8468..9731fb739d 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1482,7 +1482,9 @@ Received outputs: name=name, src=src, hf_token=api_key, alias=alias, **kwargs ) else: - return self_or_cls.set_event_trigger( + from gradio.events import Dependency + + dep, dep_index = self_or_cls.set_event_trigger( event_name="load", fn=fn, inputs=inputs, @@ -1498,7 +1500,8 @@ Received outputs: max_batch_size=max_batch_size, every=every, no_target=True, - )[0] + ) + return Dependency(self_or_cls, dep, dep_index) def clear(self): """Resets the layout of the Blocks object.""" diff --git a/gradio/routes.py b/gradio/routes.py index aad4270d16..753c55c7c6 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -411,7 +411,7 @@ class App(FastAPI): dependency = app.get_blocks().dependencies[fn_index_inferred] target = dependency["targets"][0] if len(dependency["targets"]) else None event_data = EventData( - app.get_blocks().blocks[target] if target else None, + app.get_blocks().blocks.get(target) if target else None, body.event_data, ) batch = dependency["batch"] diff --git a/test/test_events.py b/test/test_events.py index 79f8b63b87..2c72af136c 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -1,8 +1,12 @@ +import os + import pytest from fastapi.testclient import TestClient import gradio as gr +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + class TestEvent: def test_clear_event(self): @@ -69,6 +73,47 @@ class TestEvent: assert not parent.config["dependencies"][2]["trigger_only_on_success"] assert parent.config["dependencies"][3]["trigger_only_on_success"] + def test_load_chaining(self): + calls = 0 + + def increment(): + nonlocal calls + calls += 1 + return str(calls) + + with gr.Blocks() as demo: + out = gr.Textbox(label="Call counter") + demo.load(increment, inputs=None, outputs=out).then( + increment, inputs=None, outputs=out + ) + + assert demo.config["dependencies"][0]["trigger"] == "load" + assert demo.config["dependencies"][0]["trigger_after"] is None + assert demo.config["dependencies"][1]["trigger"] == "then" + assert demo.config["dependencies"][1]["trigger_after"] == 0 + + def test_load_chaining_reuse(self): + calls = 0 + + def increment(): + nonlocal calls + calls += 1 + return str(calls) + + with gr.Blocks() as demo: + out = gr.Textbox(label="Call counter") + demo.load(increment, inputs=None, outputs=out).then( + increment, inputs=None, outputs=out + ) + + with gr.Blocks() as demo2: + demo.render() + + assert demo2.config["dependencies"][0]["trigger"] == "load" + assert demo2.config["dependencies"][0]["trigger_after"] is None + assert demo2.config["dependencies"][1]["trigger"] == "then" + assert demo2.config["dependencies"][1]["trigger_after"] == 0 + class TestEventErrors: def test_event_defined_invalid_scope(self):