From 831ae1405f0f2cddf8d508775731eef39c5a09df Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 14 Oct 2022 18:43:24 -0400 Subject: [PATCH] Cancel events from other events (#2433) * WIP * Use async iteration * Format + comment * Very hacky WIP * Fix synchronization * Add comments + tidy up implementation * Remove print * Fix rebase * Lint * Disconnect queue when cancelled * Add stop button for interface automaticallY * Unit tests + interface fixes * Skip some tests on 3.7 * Skip in 3.7 * Fix skip message * Fix for python 3.7 * Add stop variant to button variant type union * CHANGELOG * Add demos/gifs to the changelog --- CHANGELOG.md | 90 +++++++++++++++++++- demo/cancel_events/run.py | 49 +++++++++++ gradio/blocks.py | 19 ++++- gradio/events.py | 104 +++++++++++++++++++++--- gradio/external.py | 14 +++- gradio/interface.py | 29 ++++++- gradio/queue.py | 27 +++++- gradio/routes.py | 47 ++++++++++- gradio/test_data/blocks_configs.py | 8 ++ scripts/copy_demos.py | 1 + test/test_blocks.py | 55 +++++++++++++ test/test_external.py | 6 +- test/test_interfaces.py | 46 +++++++++++ test/test_queue.py | 34 +++++++- test/test_routes.py | 2 + ui/packages/app/src/Blocks.svelte | 7 +- ui/packages/app/src/api.ts | 27 +++++- ui/packages/app/src/components/types.ts | 1 + ui/packages/button/src/Button.svelte | 2 +- ui/packages/theme/src/tokens.css | 5 ++ 20 files changed, 538 insertions(+), 35 deletions(-) create mode 100644 demo/cancel_events/run.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 450c1e6511..b6b8a8d4a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,90 @@ -# Upcoming Release +# Upcoming Release + +## New Features: + +### Cancelling Running Events +Running events can be cancelled when other events are triggered! To test this feature, pass the `cancels` parameter to the event listener. +For this feature to work, the queue must be enabled. + +![cancel_on_change_rl](https://user-images.githubusercontent.com/41651716/195952623-61a606bd-e82b-4e1a-802e-223154cb8727.gif) + +Code: +```python +import time +import gradio as gr + +def fake_diffusion(steps): + for i in range(steps): + time.sleep(1) + yield str(i) + +def long_prediction(*args, **kwargs): + time.sleep(10) + return 42 + + +with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + n = gr.Slider(1, 10, value=9, step=1, label="Number Steps") + run = gr.Button() + output = gr.Textbox(label="Iterative Output") + stop = gr.Button(value="Stop Iterating") + with gr.Column(): + prediction = gr.Number(label="Expensive Calculation") + run_pred = gr.Button(value="Run Expensive Calculation") + with gr.Column(): + cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change") + + click_event = run.click(fake_diffusion, n, output) + stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event]) + pred_event = run_pred.click(fn=long_prediction, inputs=None, outputs=prediction) + + cancel_on_change.change(None, None, None, cancels=[click_event, pred_event]) + + +demo.queue(concurrency_count=1, max_size=20).launch() +``` + +For interfaces, a stop button will be added automatically if the function uses a `yield` statement. + +```python +import gradio as gr +import time + +def iteration(steps): + for i in range(steps): + time.sleep(0.5) + yield i + +gr.Interface(iteration, + inputs=gr.Slider(minimum=1, maximum=10, step=1, value=5), + outputs=gr.Number()).queue().launch() +``` + +![stop_interface_rl](https://user-images.githubusercontent.com/41651716/195952883-e7ca4235-aae3-4852-8f28-96d01d0c5822.gif) + + +## Bug Fixes: +No changes to highlight. + +## Documentation Changes: +No changes to highlight. + +## Testing and Infrastructure Changes: +No changes to highlight. + +## Breaking Changes: +No changes to highlight. + +## Full Changelog: +* Enable running events to be cancelled from other events by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2433](https://github.com/gradio-app/gradio/pull/2433) + +## Contributors Shoutout: +No changes to highlight. + + +# Version 3.5 ## Bug Fixes: @@ -42,8 +128,6 @@ No changes to highlight. * Fix embedded interfaces on touch screen devices by [@aliabd](https://github.com/aliabd) in [PR 2457](https://github.com/gradio-app/gradio/pull/2457) * Upload all demos to spaces by [@aliabd](https://github.com/aliabd) in [PR 2281](https://github.com/gradio-app/gradio/pull/2281) - - ## Contributors Shoutout: No changes to highlight. diff --git a/demo/cancel_events/run.py b/demo/cancel_events/run.py new file mode 100644 index 0000000000..4f3267c166 --- /dev/null +++ b/demo/cancel_events/run.py @@ -0,0 +1,49 @@ +import time +import gradio as gr + + +def fake_diffusion(steps): + for i in range(steps): + print(f"Current step: {i}") + time.sleep(1) + yield str(i) + + +def long_prediction(*args, **kwargs): + time.sleep(10) + return 42 + + +with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(): + n = gr.Slider(1, 10, value=9, step=1, label="Number Steps") + run = gr.Button() + output = gr.Textbox(label="Iterative Output") + stop = gr.Button(value="Stop Iterating") + with gr.Column(): + textbox = gr.Textbox(label="Prompt") + prediction = gr.Number(label="Expensive Calculation") + run_pred = gr.Button(value="Run Expensive Calculation") + with gr.Column(): + cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change") + cancel_on_submit = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Submit") + echo = gr.Textbox(label="Echo") + with gr.Row(): + with gr.Column(): + image = gr.Image(source="webcam", tool="editor", label="Cancel on edit", interactive=True) + with gr.Column(): + video = gr.Video(source="webcam", label="Cancel on play", interactive=True) + + click_event = run.click(fake_diffusion, n, output) + stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event]) + pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction) + + cancel_on_change.change(None, None, None, cancels=[click_event, pred_event]) + cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event]) + image.edit(None, None, None, cancels=[click_event, pred_event]) + video.play(None, None, None, cancels=[click_event, pred_event]) + + +if __name__ == "__main__": + demo.queue(concurrency_count=2, max_size=20).launch() diff --git a/gradio/blocks.py b/gradio/blocks.py index 34c0b15b30..a13736195b 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -119,7 +119,8 @@ class Block: js: Optional[str] = None, no_target: bool = False, queue: Optional[bool] = None, - ) -> None: + cancels: List[int] | None = None, + ) -> Dict[str, Any]: """ Adds an event to the component's dependencies. Parameters: @@ -166,6 +167,7 @@ class Block: "api_name": api_name, "scroll_to_output": scroll_to_output, "show_progress": show_progress, + "cancels": cancels if cancels else [], } if api_name is not None: dependency["documentation"] = [ @@ -179,6 +181,7 @@ class Block: ], ] Context.root_block.dependencies.append(dependency) + return dependency def get_config(self): return { @@ -1062,6 +1065,20 @@ class Blocks(BlockContext): if self.enable_queue and not hasattr(self, "_queue"): self.queue() + for dep in self.dependencies: + for i in dep["cancels"]: + queue_status = self.dependencies[i]["queue"] + if queue_status is False or ( + queue_status is None and not self.enable_queue + ): + raise ValueError( + "In order to cancel an event, the queue for that event must be enabled! " + "You may get this error by either 1) passing a function that uses the yield keyword " + "into an interface without enabling the queue or 2) defining an event that cancels " + "another event without enabling the queue. Both can be solved by calling .queue() " + "before .launch()" + ) + self.config = self.get_config_file() self.share = share self.encrypt = encrypt diff --git a/gradio/events.py b/gradio/events.py index fc92eacc37..40a60db9b9 100644 --- a/gradio/events.py +++ b/gradio/events.py @@ -1,14 +1,59 @@ from __future__ import annotations +import asyncio +import sys import warnings from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple -from gradio.blocks import Block +from gradio.blocks import Block, Context, update if TYPE_CHECKING: # Only import for type checking (is False at runtime). from gradio.components import Component, StatusTracker +def get_cancel_function( + dependencies: List[Dict[str, Any]] +) -> Tuple[Callable, List[int]]: + fn_to_comp = {} + for dep in dependencies: + fn_index = next( + i for i, d in enumerate(Context.root_block.dependencies) if d == dep + ) + fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]] + + async def cancel(session_hash: str) -> None: + if sys.version_info < (3, 8): + return None + + task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp]) + + matching_tasks = [ + task for task in asyncio.all_tasks() if task.get_name() in task_ids + ] + for task in matching_tasks: + task.cancel() + await asyncio.gather(*matching_tasks, return_exceptions=True) + + return ( + cancel, + list(fn_to_comp.keys()), + ) + + +def set_cancel_events(block: Block, event_name: str, cancels: List[Dict[str, Any]]): + if cancels: + cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels) + block.set_event_trigger( + event_name, + cancel_fn, + inputs=None, + outputs=None, + queue=False, + preprocess=False, + cancels=fn_indices_to_cancel, + ) + + class Changeable(Block): def change( self, @@ -22,6 +67,7 @@ class Changeable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -38,6 +84,7 @@ class Changeable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -45,7 +92,7 @@ class Changeable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "change", fn, inputs, @@ -58,6 +105,8 @@ class Changeable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "change", cancels) + return dep class Clickable(Block): @@ -73,6 +122,7 @@ class Clickable(Block): queue=None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -89,6 +139,7 @@ class Clickable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -96,7 +147,7 @@ class Clickable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "click", fn, inputs, @@ -109,6 +160,8 @@ class Clickable(Block): preprocess=preprocess, postprocess=postprocess, ) + set_cancel_events(self, "click", cancels) + return dep class Submittable(Block): @@ -124,6 +177,7 @@ class Submittable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -141,6 +195,7 @@ class Submittable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -148,7 +203,7 @@ class Submittable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "submit", fn, inputs, @@ -161,6 +216,8 @@ class Submittable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "submit", cancels) + return dep class Editable(Block): @@ -176,6 +233,7 @@ class Editable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -192,6 +250,7 @@ class Editable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -199,7 +258,7 @@ class Editable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "edit", fn, inputs, @@ -212,6 +271,8 @@ class Editable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "edit", cancels) + return dep class Clearable(Block): @@ -227,6 +288,7 @@ class Clearable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -243,6 +305,7 @@ class Clearable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -250,7 +313,7 @@ class Clearable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "submit", fn, inputs, @@ -263,6 +326,8 @@ class Clearable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "submit", cancels) + return dep class Playable(Block): @@ -278,6 +343,7 @@ class Playable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -294,6 +360,7 @@ class Playable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -301,7 +368,7 @@ class Playable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "play", fn, inputs, @@ -314,6 +381,8 @@ class Playable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "play", cancels) + return dep def pause( self, @@ -327,6 +396,7 @@ class Playable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -343,6 +413,7 @@ class Playable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -350,7 +421,7 @@ class Playable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "pause", fn, inputs, @@ -363,6 +434,8 @@ class Playable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "pause", cancels) + return dep def stop( self, @@ -376,6 +449,7 @@ class Playable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -392,6 +466,7 @@ class Playable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. if status_tracker: @@ -399,7 +474,7 @@ class Playable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "stop", fn, inputs, @@ -412,6 +487,8 @@ class Playable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "stop", cancels) + return dep class Streamable(Block): @@ -427,6 +504,7 @@ class Streamable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -443,6 +521,7 @@ class Streamable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. self.streaming = True @@ -452,7 +531,7 @@ class Streamable(Block): "The 'status_tracker' parameter has been deprecated and has no effect." ) - self.set_event_trigger( + dep = self.set_event_trigger( "stream", fn, inputs, @@ -465,6 +544,8 @@ class Streamable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "stream", cancels) + return dep class Blurrable(Block): @@ -480,6 +561,7 @@ class Blurrable(Block): queue: Optional[bool] = None, preprocess: bool = True, postprocess: bool = True, + cancels: List[Dict[str, Any]] | None = None, _js: Optional[str] = None, ): """ @@ -495,6 +577,7 @@ class Blurrable(Block): queue: If True, will place the request on the queue, if the queue exists preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component). postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser. + cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. """ # _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. @@ -511,3 +594,4 @@ class Blurrable(Block): postprocess=postprocess, queue=queue, ) + set_cancel_events(self, "blur", cancels) diff --git a/gradio/external.py b/gradio/external.py index 07ef5df90d..43461b63e2 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -9,6 +9,7 @@ import math import numbers import operator import re +import uuid import warnings from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple @@ -413,7 +414,7 @@ def get_spaces(model_name, api_key, alias, **kwargs): async def get_pred_from_ws( - websocket: websockets.WebSocketClientProtocol, data: str + websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str ) -> Dict[str, Any]: completed = False while not completed: @@ -421,6 +422,8 @@ async def get_pred_from_ws( resp = json.loads(msg) if resp["msg"] == "queue_full": raise exceptions.Error("Queue is full! Please try again.") + if resp["msg"] == "send_hash": + await websocket.send(hash_data) elif resp["msg"] == "send_data": await websocket.send(data) completed = resp["msg"] == "process_completed" @@ -428,9 +431,9 @@ async def get_pred_from_ws( def get_ws_fn(ws_url): - async def ws_fn(data): + async def ws_fn(data, hash_data): async with websockets.connect(ws_url, open_timeout=10) as websocket: - return await get_pred_from_ws(websocket, data) + return await get_pred_from_ws(websocket, data, hash_data) return ws_fn @@ -467,8 +470,11 @@ def get_spaces_blocks(model_name, config): def get_fn(outputs, fn_index, use_ws): def fn(*data): data = json.dumps({"data": data, "fn_index": fn_index}) + hash_data = json.dumps( + {"fn_index": fn_index, "session_hash": str(uuid.uuid4())} + ) if use_ws: - result = utils.synchronize_async(ws_fn, data) + result = utils.synchronize_async(ws_fn, data, hash_data) output = result["data"] else: response = requests.post(api_url, headers=headers, data=data) diff --git a/gradio/interface.py b/gradio/interface.py index e130f84595..dff7a487d9 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -20,7 +20,7 @@ from markdown_it import MarkdownIt from mdit_py_plugins.footnote import footnote_plugin from gradio import Examples, interpretation, utils -from gradio.blocks import Blocks +from gradio.blocks import Blocks, update from gradio.components import ( Button, Component, @@ -472,9 +472,19 @@ class Interface(Blocks): clear_btn = Button("Clear") if not self.live: submit_btn = Button("Submit", variant="primary") + # Stopping jobs only works if the queue is enabled + # We don't know if the queue is enabled when the interface + # is created. We use whether a generator function is provided + # as a proxy of whether the queue will be enabled. + # Using a generator function without the queue will raise an error. + if inspect.isgeneratorfunction(fn): + stop_btn = Button("Stop", variant="stop") + elif self.interface_type == self.InterfaceTypes.UNIFIED: clear_btn = Button("Clear") submit_btn = Button("Submit", variant="primary") + if inspect.isgeneratorfunction(fn) and not self.live: + stop_btn = Button("Stop", variant="stop") if self.allow_flagging == "manual": flag_btns = render_flag_btns(self.flagging_options) @@ -491,6 +501,13 @@ class Interface(Blocks): if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY: clear_btn = Button("Clear") submit_btn = Button("Generate", variant="primary") + if inspect.isgeneratorfunction(fn) and not self.live: + # Stopping jobs only works if the queue is enabled + # We don't know if the queue is enabled when the interface + # is created. We use whether a generator function is provided + # as a proxy of whether the queue will be enabled. + # Using a generator function without the queue will raise an error. + stop_btn = Button("Stop", variant="stop") if self.allow_flagging == "manual": flag_btns = render_flag_btns(self.flagging_options) if self.interpretation: @@ -535,7 +552,7 @@ class Interface(Blocks): postprocess=not (self.api_mode), ) else: - submit_btn.click( + pred = submit_btn.click( self.fn, self.input_components, self.output_components, @@ -544,6 +561,14 @@ class Interface(Blocks): preprocess=not (self.api_mode), postprocess=not (self.api_mode), ) + if inspect.isgeneratorfunction(fn): + stop_btn.click( + None, + inputs=None, + outputs=None, + cancels=[pred], + ) + clear_btn.click( None, [], diff --git a/gradio/queue.py b/gradio/queue.py index 74007b9ca2..dd523a40e4 100644 --- a/gradio/queue.py +++ b/gradio/queue.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -import json +import sys import time from typing import Dict, List, Optional @@ -84,7 +84,9 @@ class Queue: event = self.event_queue.pop(0) self.active_jobs[self.active_jobs.index(None)] = event - run_coro_in_background(self.process_event, event) + task = run_coro_in_background(self.process_event, event) + if sys.version_info >= (3, 8): + task.set_name(f"{event.session_hash}_{event.fn_index}") run_coro_in_background(self.broadcast_live_estimations) def push(self, event: Event) -> int | None: @@ -239,6 +241,13 @@ class Queue: ) elif response.json.get("is_generating", False): while response.json.get("is_generating", False): + # Python 3.7 doesn't have named tasks. + # In order to determine if a task was cancelled, we + # ping the websocket to see if it was closed mid-iteration. + if sys.version_info < (3, 8): + is_alive = await self.send_message(event, {"msg": "alive?"}) + if not is_alive: + return old_response = response await self.send_message( event, @@ -276,6 +285,18 @@ class Queue: pass finally: await self.clean_event(event) + # Always reset the state of the iterator + # If the job finished successfully, this has no effect + # If the job is cancelled, this will enable future runs + # to start "from scratch" + await Request( + method=Request.Method.POST, + url=f"{self.server_path}reset", + json={ + "session_hash": event.session_hash, + "fn_index": event.fn_index, + }, + ) async def send_message(self, event, data: Dict) -> bool: try: @@ -301,6 +322,8 @@ class Event: self.websocket = websocket self.data: PredictBody | None = None self.lost_connection_time: float | None = None + self.fn_index = 0 + self.session_hash = "foo" async def disconnect(self, code=1000): await self.websocket.close(code=code) diff --git a/gradio/routes.py b/gradio/routes.py index 771e2fdca3..fa9334ad9f 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -77,6 +77,11 @@ class PredictBody(BaseModel): fn_index: Optional[int] +class ResetBody(BaseModel): + session_hash: str + fn_index: int + + ########### # Auth ########### @@ -93,7 +98,7 @@ class App(FastAPI): self.blocks: Optional[gradio.Blocks] = None self.state_holder = {} self.iterators = defaultdict(dict) - + self.lock = asyncio.Lock() super().__init__(**kwargs) def configure_app(self, blocks: gradio.Blocks) -> None: @@ -254,6 +259,16 @@ class App(FastAPI): def file_deprecated(path: str): return file(path) + @app.post("/reset/") + @app.post("/reset") + async def reset_iterator(body: ResetBody): + if body.session_hash not in app.iterators: + return {"success": False} + async with app.lock: + app.iterators[body.session_hash][body.fn_index] = None + app.iterators[body.session_hash]["should_reset"].add(body.fn_index) + return {"success": True} + async def run_predict( body: PredictBody, username: str = Depends(get_current_user) ): @@ -266,6 +281,14 @@ class App(FastAPI): } session_state = app.state_holder[body.session_hash] iterators = app.iterators[body.session_hash] + # The should_reset set keeps track of the fn_indices + # that have been cancelled. When a job is cancelled, + # the /reset route will mark the jobs as having been reset. + # That way if the cancel job finishes BEFORE the job being cancelled + # the job being cancelled will not overwrite the state of the iterator. + # In all cases, should_reset will be the empty set the next time + # the fn_index is run. + app.iterators[body.session_hash]["should_reset"] = set([]) else: session_state = {} iterators = {} @@ -277,7 +300,10 @@ class App(FastAPI): ) iterator = output.pop("iterator", None) if hasattr(body, "session_hash"): - app.iterators[body.session_hash][fn_index] = iterator + if fn_index in app.iterators[body.session_hash]["should_reset"]: + app.iterators[body.session_hash][fn_index] = None + else: + app.iterators[body.session_hash][fn_index] = iterator if isinstance(output, Error): raise output except BaseException as error: @@ -306,7 +332,12 @@ class App(FastAPI): }, status_code=500, ) - return await run_predict(body=body, username=username) + # If this fn_index cancels jobs, then the only input we need is the + # current session hash + if app.blocks.dependencies[body.fn_index]["cancels"]: + body.data = [body.session_hash] + result = await run_predict(body=body, username=username) + return result @app.websocket("/queue/join") async def join_queue(websocket: WebSocket): @@ -315,10 +346,18 @@ class App(FastAPI): app_url = get_server_url_from_ws_url(str(websocket.url)) print(f"Server URL: {app_url}") app.blocks._queue.set_url(app_url) - await websocket.accept() event = Event(websocket) + + # In order to cancel jobs, we need the session_hash and fn_index + # to create a unique id for each job + await websocket.send_json({"msg": "send_hash"}) + session_hash = await websocket.receive_json() + event.session_hash = session_hash["session_hash"] + event.fn_index = session_hash["fn_index"] + rank = app.blocks._queue.push(event) + if rank is None: await app.blocks._queue.send_message(event, {"msg": "queue_full"}) await event.disconnect() diff --git a/gradio/test_data/blocks_configs.py b/gradio/test_data/blocks_configs.py index 6f34323abe..9c347682e3 100644 --- a/gradio/test_data/blocks_configs.py +++ b/gradio/test_data/blocks_configs.py @@ -200,6 +200,7 @@ XRAY_CONFIG = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, { "targets": [35], @@ -212,6 +213,7 @@ XRAY_CONFIG = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, { "targets": [], @@ -224,6 +226,7 @@ XRAY_CONFIG = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, ], } @@ -439,6 +442,7 @@ XRAY_CONFIG_DIFF_IDS = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, { "targets": [13], @@ -451,6 +455,7 @@ XRAY_CONFIG_DIFF_IDS = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, { "targets": [], @@ -463,6 +468,7 @@ XRAY_CONFIG_DIFF_IDS = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, ], } @@ -639,6 +645,7 @@ XRAY_CONFIG_WITH_MISTAKE = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, { "targets": [13], @@ -648,6 +655,7 @@ XRAY_CONFIG_WITH_MISTAKE = { "api_name": None, "scroll_to_output": False, "show_progress": True, + "cancels": [], }, ], } diff --git a/scripts/copy_demos.py b/scripts/copy_demos.py index 8ec665381d..2dd5870372 100644 --- a/scripts/copy_demos.py +++ b/scripts/copy_demos.py @@ -15,6 +15,7 @@ def copy_all_demos(source_dir: str, dest_dir: str): "blocks_multiple_event_triggers", "blocks_update", "calculator", + "cancel_events", "fake_gan", "fake_diffusion_with_gif", "gender_sentence_default_interpretation", diff --git a/test/test_blocks.py b/test/test_blocks.py index 62fe60588c..e101d9e3e2 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -16,6 +16,8 @@ import pytest import wandb import gradio as gr +import gradio.events +from gradio.blocks import Block from gradio.exceptions import DuplicateBlockError from gradio.routes import PredictBody from gradio.test_data.blocks_configs import XRAY_CONFIG @@ -539,5 +541,58 @@ class TestDuplicateBlockError: io2.render() +@pytest.mark.skipif( + sys.version_info < (3, 8), + reason="Tasks dont have names in 3.7", +) +@pytest.mark.asyncio +async def test_cancel_function(capsys): + async def long_job(): + await asyncio.sleep(10) + print("HELLO FROM LONG JOB") + + with gr.Blocks(): + button = gr.Button(value="Start") + click = button.click(long_job, None, None) + cancel = gr.Button(value="Cancel") + cancel.click(None, None, None, cancels=[click]) + cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click]) + + task = asyncio.create_task(long_job()) + task.set_name("foo_0") + # If cancel_fun didn't cancel long_job the message would be printed to the console + # The test would also take 10 seconds + await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True) + captured = capsys.readouterr() + assert "HELLO FROM LONG JOB" not in captured.out + + +def test_raise_exception_if_cancelling_an_event_thats_not_queued(): + def iteration(a): + yield a + + msg = "In order to cancel an event, the queue for that event must be enabled!" + with pytest.raises(ValueError, match=msg): + gr.Interface(iteration, inputs=gr.Number(), outputs=gr.Number()).launch( + prevent_thread_lock=True + ) + + with pytest.raises(ValueError, match=msg): + with gr.Blocks() as demo: + button = gr.Button(value="Predict") + click = button.click(None, None, None) + cancel = gr.Button(value="Cancel") + cancel.click(None, None, None, cancels=[click]) + demo.launch(prevent_thread_lock=True) + + with pytest.raises(ValueError, match=msg): + with gr.Blocks() as demo: + button = gr.Button(value="Predict") + click = button.click(None, None, None, queue=False) + cancel = gr.Button(value="Cancel") + cancel.click(None, None, None, cancels=[click]) + demo.queue().launch(prevent_thread_lock=True) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_external.py b/test/test_external.py index 790765d3e0..dcffeb5664 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -384,7 +384,8 @@ async def test_get_pred_from_ws(): ] mock_ws.recv.side_effect = messages data = json.dumps({"data": ["foo"], "fn_index": "foo"}) - output = await get_pred_from_ws(mock_ws, data) + hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"}) + output = await get_pred_from_ws(mock_ws, data, hash_data) assert output == {"data": ["result!"]} mock_ws.send.assert_called_once_with(data) @@ -395,8 +396,9 @@ async def test_get_pred_from_ws_raises_if_queue_full(): messages = [json.dumps({"msg": "queue_full"})] mock_ws.recv.side_effect = messages data = json.dumps({"data": ["foo"], "fn_index": "foo"}) + hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"}) with pytest.raises(gradio.Error, match="Queue is full!"): - await get_pred_from_ws(mock_ws, data) + await get_pred_from_ws(mock_ws, data, hash_data) @pytest.mark.skipif( diff --git a/test/test_interfaces.py b/test/test_interfaces.py index 754d2a0915..073da1f63e 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -12,6 +12,7 @@ import requests import wandb from fastapi.testclient import TestClient +import gradio from gradio.blocks import Blocks from gradio.interface import Interface, TabbedInterface, close_all, os from gradio.layouts import TabItem, Tabs @@ -268,5 +269,50 @@ class TestInterfaceInterpretation(unittest.TestCase): close_all() +@pytest.mark.parametrize( + "interface_type", ["standard", "input_only", "output_only", "unified"] +) +@pytest.mark.parametrize("live", [True, False]) +@pytest.mark.parametrize("use_generator", [True, False]) +def test_interface_adds_stop_button(interface_type, live, use_generator): + def gen_func(inp): + yield inp + + def func(inp): + return inp + + if interface_type == "standard": + interface = gradio.Interface( + gen_func if use_generator else func, "number", "number", live=live + ) + elif interface_type == "input_only": + interface = gradio.Interface( + gen_func if use_generator else func, "number", None, live=live + ) + elif interface_type == "output_only": + interface = gradio.Interface( + gen_func if use_generator else func, None, "number", live=live + ) + else: + num = gradio.Number() + interface = gradio.Interface( + gen_func if use_generator else func, num, num, live=live + ) + has_stop = ( + len( + [ + c + for c in interface.config["components"] + if c["props"].get("variant", "") == "stop" + ] + ) + == 1 + ) + if use_generator and not live: + assert has_stop + else: + assert not has_stop + + if __name__ == "__main__": unittest.main() diff --git a/test/test_queue.py b/test/test_queue.py index fe04218b28..0eefa02f48 100644 --- a/test/test_queue.py +++ b/test/test_queue.py @@ -1,9 +1,11 @@ import os -from unittest.mock import MagicMock +import sys +from unittest.mock import MagicMock, patch import pytest from gradio.queue import Event, Queue +from gradio.utils import Request os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -161,8 +163,13 @@ class TestQueueEstimation: class TestQueueProcessEvents: + @pytest.mark.skipif( + sys.version_info < (3, 8), + reason="Mocks of async context manager don't work for 3.7", + ) @pytest.mark.asyncio - async def test_process_event(self, queue: Queue, mock_event: Event): + @patch("gradio.queue.Request", new_callable=AsyncMock) + async def test_process_event(self, mock_request, queue: Queue, mock_event: Event): queue.gather_event_data = AsyncMock() queue.gather_event_data.return_value = True queue.send_message = AsyncMock() @@ -178,6 +185,14 @@ class TestQueueProcessEvents: queue.call_prediction.assert_called_once() mock_event.disconnect.assert_called_once() queue.clean_event.assert_called_once() + mock_request.assert_called_with( + method=Request.Method.POST, + url=f"{queue.server_path}reset", + json={ + "session_hash": mock_event.session_hash, + "fn_index": mock_event.fn_index, + }, + ) @pytest.mark.asyncio async def test_process_event_handles_error_when_gathering_data( @@ -246,9 +261,14 @@ class TestQueueProcessEvents: mock_event.disconnect.assert_called_once() assert queue.clean_event.call_count >= 1 + @pytest.mark.skipif( + sys.version_info < (3, 8), + reason="Mocks of async context manager don't work for 3.7", + ) @pytest.mark.asyncio + @patch("gradio.queue.Request", new_callable=AsyncMock) async def test_process_event_handles_exception_during_disconnect( - self, queue: Queue, mock_event: Event + self, mock_request, queue: Queue, mock_event: Event ): mock_event.websocket.send_json = AsyncMock() queue.call_prediction = AsyncMock( @@ -259,3 +279,11 @@ class TestQueueProcessEvents: queue.clean_event = AsyncMock() mock_event.data = None await queue.process_event(mock_event) + mock_request.assert_called_with( + method=Request.Method.POST, + url=f"{queue.server_path}reset", + json={ + "session_hash": mock_event.session_hash, + "fn_index": mock_event.fn_index, + }, + ) diff --git a/test/test_routes.py b/test/test_routes.py index 8a46de63c7..585465bb4b 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -236,6 +236,8 @@ async def test_queue_join_routes_sets_url_if_none_set(mock_get_url): msg = json.loads(await ws.recv()) if msg["msg"] == "send_data": await ws.send(json.dumps({"data": ["foo"], "fn_index": 0})) + if msg["msg"] == "send_hash": + await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"})) completed = msg["msg"] == "process_completed" assert io._queue.server_path == "foo_url" diff --git a/ui/packages/app/src/Blocks.svelte b/ui/packages/app/src/Blocks.svelte index c184a6c0ce..7e8a2dc8c5 100644 --- a/ui/packages/app/src/Blocks.svelte +++ b/ui/packages/app/src/Blocks.svelte @@ -222,6 +222,7 @@ queue, backend_fn, frontend_fn, + cancels, ...rest }, i @@ -250,7 +251,8 @@ }, queue: queue === null ? enable_queue : queue, queue_callback: handle_update, - loading_status: loading_status + loading_status: loading_status, + cancels }); function handle_update(output: any) { @@ -303,7 +305,8 @@ output_data: outputs.map((id) => instance_map[id].props.value), queue: queue === null ? enable_queue : queue, queue_callback: handle_update, - loading_status: loading_status + loading_status: loading_status, + cancels }); if (!(queue === null ? enable_queue : queue)) { diff --git a/ui/packages/app/src/api.ts b/ui/packages/app/src/api.ts index f68a787ba5..0323d40b77 100644 --- a/ui/packages/app/src/api.ts +++ b/ui/packages/app/src/api.ts @@ -77,7 +77,8 @@ export const fn = frontend_fn, output_data, queue_callback, - loading_status + loading_status, + cancels }: { action: string; payload: Payload; @@ -87,6 +88,7 @@ export const fn = output_data?: Output["data"]; queue_callback: Function; loading_status: LoadingStatusType; + cancels: Array; }): Promise => { const fn_index = payload.fn_index; @@ -153,6 +155,14 @@ export const fn = case "send_data": send_message(fn_index, payload); break; + case "send_hash": + ws_map.get(fn_index)?.send( + JSON.stringify({ + session_hash: session_hash, + fn_index: fn_index + }) + ); + break; case "queue_full": loading_status.update( fn_index, @@ -243,6 +253,21 @@ export const fn = output.average_duration as number, null ); + // Cancelled jobs are set to complete + if (cancels.length > 0) { + cancels.forEach((fn_index) => { + loading_status.update( + fn_index, + "complete", + queue, + null, + null, + null, + null + ); + ws_map.get(fn_index)?.close(); + }); + } } else { loading_status.update( fn_index, diff --git a/ui/packages/app/src/components/types.ts b/ui/packages/app/src/components/types.ts index 8eb4f7cefe..68b2b63abf 100644 --- a/ui/packages/app/src/components/types.ts +++ b/ui/packages/app/src/components/types.ts @@ -30,6 +30,7 @@ export interface Dependency { queue: boolean | null; api_name: string | null; documentation?: Array>>; + cancels: Array; } export interface LayoutNode { diff --git a/ui/packages/button/src/Button.svelte b/ui/packages/button/src/Button.svelte index f71f89f127..1263bce74e 100644 --- a/ui/packages/button/src/Button.svelte +++ b/ui/packages/button/src/Button.svelte @@ -5,7 +5,7 @@ export let style: Styles = {}; export let elem_id: string = ""; export let visible: boolean = true; - export let variant: "primary" | "secondary" = "secondary"; + export let variant: "primary" | "secondary" | "stop" = "secondary"; export let size: "sm" | "lg" = "lg"; $: ({ classes } = get_styles(style, ["full_width"])); diff --git a/ui/packages/theme/src/tokens.css b/ui/packages/theme/src/tokens.css index f97f4bd5e1..ca94dafa65 100644 --- a/ui/packages/theme/src/tokens.css +++ b/ui/packages/theme/src/tokens.css @@ -86,6 +86,11 @@ dark:from-gray-600 dark:to-gray-700 dark:hover:to-gray-600 dark:text-white dark:border-gray-600; } +.gr-button-stop { + @apply from-red-200/70 to-red-300/80 hover:to-red-200/90 text-red-600 border-red-200 + dark:from-red-700 dark:to-red-700 dark:hover:to-red-500 dark:text-white dark:border-red-600; +} + .gr-button-sm { @apply px-3 py-1 text-sm rounded-md; }