Cancel events from other events (#2433)

* WIP

* Use async iteration

* Format + comment

* Very hacky WIP

* Fix synchronization

* Add comments + tidy up implementation

* Remove print

* Fix rebase

* Lint

* Disconnect queue when cancelled

* Add stop button for interface automaticallY

* Unit tests + interface fixes

* Skip some tests on 3.7

* Skip in 3.7

* Fix skip message

* Fix for python 3.7

* Add stop variant to button variant type union

* CHANGELOG

* Add demos/gifs to the changelog
This commit is contained in:
Freddy Boulton 2022-10-14 18:43:24 -04:00 committed by GitHub
parent 07c77ece36
commit 831ae1405f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 538 additions and 35 deletions

View File

@ -1,4 +1,90 @@
# Upcoming Release
# Upcoming Release
## New Features:
### Cancelling Running Events
Running events can be cancelled when other events are triggered! To test this feature, pass the `cancels` parameter to the event listener.
For this feature to work, the queue must be enabled.
![cancel_on_change_rl](https://user-images.githubusercontent.com/41651716/195952623-61a606bd-e82b-4e1a-802e-223154cb8727.gif)
Code:
```python
import time
import gradio as gr
def fake_diffusion(steps):
for i in range(steps):
time.sleep(1)
yield str(i)
def long_prediction(*args, **kwargs):
time.sleep(10)
return 42
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
n = gr.Slider(1, 10, value=9, step=1, label="Number Steps")
run = gr.Button()
output = gr.Textbox(label="Iterative Output")
stop = gr.Button(value="Stop Iterating")
with gr.Column():
prediction = gr.Number(label="Expensive Calculation")
run_pred = gr.Button(value="Run Expensive Calculation")
with gr.Column():
cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change")
click_event = run.click(fake_diffusion, n, output)
stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
pred_event = run_pred.click(fn=long_prediction, inputs=None, outputs=prediction)
cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])
demo.queue(concurrency_count=1, max_size=20).launch()
```
For interfaces, a stop button will be added automatically if the function uses a `yield` statement.
```python
import gradio as gr
import time
def iteration(steps):
for i in range(steps):
time.sleep(0.5)
yield i
gr.Interface(iteration,
inputs=gr.Slider(minimum=1, maximum=10, step=1, value=5),
outputs=gr.Number()).queue().launch()
```
![stop_interface_rl](https://user-images.githubusercontent.com/41651716/195952883-e7ca4235-aae3-4852-8f28-96d01d0c5822.gif)
## Bug Fixes:
No changes to highlight.
## Documentation Changes:
No changes to highlight.
## Testing and Infrastructure Changes:
No changes to highlight.
## Breaking Changes:
No changes to highlight.
## Full Changelog:
* Enable running events to be cancelled from other events by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2433](https://github.com/gradio-app/gradio/pull/2433)
## Contributors Shoutout:
No changes to highlight.
# Version 3.5
## Bug Fixes:
@ -42,8 +128,6 @@ No changes to highlight.
* Fix embedded interfaces on touch screen devices by [@aliabd](https://github.com/aliabd) in [PR 2457](https://github.com/gradio-app/gradio/pull/2457)
* Upload all demos to spaces by [@aliabd](https://github.com/aliabd) in [PR 2281](https://github.com/gradio-app/gradio/pull/2281)
## Contributors Shoutout:
No changes to highlight.

49
demo/cancel_events/run.py Normal file
View File

@ -0,0 +1,49 @@
import time
import gradio as gr
def fake_diffusion(steps):
for i in range(steps):
print(f"Current step: {i}")
time.sleep(1)
yield str(i)
def long_prediction(*args, **kwargs):
time.sleep(10)
return 42
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
n = gr.Slider(1, 10, value=9, step=1, label="Number Steps")
run = gr.Button()
output = gr.Textbox(label="Iterative Output")
stop = gr.Button(value="Stop Iterating")
with gr.Column():
textbox = gr.Textbox(label="Prompt")
prediction = gr.Number(label="Expensive Calculation")
run_pred = gr.Button(value="Run Expensive Calculation")
with gr.Column():
cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change")
cancel_on_submit = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Submit")
echo = gr.Textbox(label="Echo")
with gr.Row():
with gr.Column():
image = gr.Image(source="webcam", tool="editor", label="Cancel on edit", interactive=True)
with gr.Column():
video = gr.Video(source="webcam", label="Cancel on play", interactive=True)
click_event = run.click(fake_diffusion, n, output)
stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction)
cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])
cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event])
image.edit(None, None, None, cancels=[click_event, pred_event])
video.play(None, None, None, cancels=[click_event, pred_event])
if __name__ == "__main__":
demo.queue(concurrency_count=2, max_size=20).launch()

View File

@ -119,7 +119,8 @@ class Block:
js: Optional[str] = None,
no_target: bool = False,
queue: Optional[bool] = None,
) -> None:
cancels: List[int] | None = None,
) -> Dict[str, Any]:
"""
Adds an event to the component's dependencies.
Parameters:
@ -166,6 +167,7 @@ class Block:
"api_name": api_name,
"scroll_to_output": scroll_to_output,
"show_progress": show_progress,
"cancels": cancels if cancels else [],
}
if api_name is not None:
dependency["documentation"] = [
@ -179,6 +181,7 @@ class Block:
],
]
Context.root_block.dependencies.append(dependency)
return dependency
def get_config(self):
return {
@ -1062,6 +1065,20 @@ class Blocks(BlockContext):
if self.enable_queue and not hasattr(self, "_queue"):
self.queue()
for dep in self.dependencies:
for i in dep["cancels"]:
queue_status = self.dependencies[i]["queue"]
if queue_status is False or (
queue_status is None and not self.enable_queue
):
raise ValueError(
"In order to cancel an event, the queue for that event must be enabled! "
"You may get this error by either 1) passing a function that uses the yield keyword "
"into an interface without enabling the queue or 2) defining an event that cancels "
"another event without enabling the queue. Both can be solved by calling .queue() "
"before .launch()"
)
self.config = self.get_config_file()
self.share = share
self.encrypt = encrypt

View File

@ -1,14 +1,59 @@
from __future__ import annotations
import asyncio
import sys
import warnings
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
from gradio.blocks import Block
from gradio.blocks import Block, Context, update
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.components import Component, StatusTracker
def get_cancel_function(
dependencies: List[Dict[str, Any]]
) -> Tuple[Callable, List[int]]:
fn_to_comp = {}
for dep in dependencies:
fn_index = next(
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
)
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
async def cancel(session_hash: str) -> None:
if sys.version_info < (3, 8):
return None
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
matching_tasks = [
task for task in asyncio.all_tasks() if task.get_name() in task_ids
]
for task in matching_tasks:
task.cancel()
await asyncio.gather(*matching_tasks, return_exceptions=True)
return (
cancel,
list(fn_to_comp.keys()),
)
def set_cancel_events(block: Block, event_name: str, cancels: List[Dict[str, Any]]):
if cancels:
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
block.set_event_trigger(
event_name,
cancel_fn,
inputs=None,
outputs=None,
queue=False,
preprocess=False,
cancels=fn_indices_to_cancel,
)
class Changeable(Block):
def change(
self,
@ -22,6 +67,7 @@ class Changeable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -38,6 +84,7 @@ class Changeable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -45,7 +92,7 @@ class Changeable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"change",
fn,
inputs,
@ -58,6 +105,8 @@ class Changeable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "change", cancels)
return dep
class Clickable(Block):
@ -73,6 +122,7 @@ class Clickable(Block):
queue=None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -89,6 +139,7 @@ class Clickable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -96,7 +147,7 @@ class Clickable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"click",
fn,
inputs,
@ -109,6 +160,8 @@ class Clickable(Block):
preprocess=preprocess,
postprocess=postprocess,
)
set_cancel_events(self, "click", cancels)
return dep
class Submittable(Block):
@ -124,6 +177,7 @@ class Submittable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -141,6 +195,7 @@ class Submittable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -148,7 +203,7 @@ class Submittable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"submit",
fn,
inputs,
@ -161,6 +216,8 @@ class Submittable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "submit", cancels)
return dep
class Editable(Block):
@ -176,6 +233,7 @@ class Editable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -192,6 +250,7 @@ class Editable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -199,7 +258,7 @@ class Editable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"edit",
fn,
inputs,
@ -212,6 +271,8 @@ class Editable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "edit", cancels)
return dep
class Clearable(Block):
@ -227,6 +288,7 @@ class Clearable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -243,6 +305,7 @@ class Clearable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -250,7 +313,7 @@ class Clearable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"submit",
fn,
inputs,
@ -263,6 +326,8 @@ class Clearable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "submit", cancels)
return dep
class Playable(Block):
@ -278,6 +343,7 @@ class Playable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -294,6 +360,7 @@ class Playable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -301,7 +368,7 @@ class Playable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"play",
fn,
inputs,
@ -314,6 +381,8 @@ class Playable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "play", cancels)
return dep
def pause(
self,
@ -327,6 +396,7 @@ class Playable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -343,6 +413,7 @@ class Playable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -350,7 +421,7 @@ class Playable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"pause",
fn,
inputs,
@ -363,6 +434,8 @@ class Playable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "pause", cancels)
return dep
def stop(
self,
@ -376,6 +449,7 @@ class Playable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -392,6 +466,7 @@ class Playable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
if status_tracker:
@ -399,7 +474,7 @@ class Playable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"stop",
fn,
inputs,
@ -412,6 +487,8 @@ class Playable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "stop", cancels)
return dep
class Streamable(Block):
@ -427,6 +504,7 @@ class Streamable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -443,6 +521,7 @@ class Streamable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
self.streaming = True
@ -452,7 +531,7 @@ class Streamable(Block):
"The 'status_tracker' parameter has been deprecated and has no effect."
)
self.set_event_trigger(
dep = self.set_event_trigger(
"stream",
fn,
inputs,
@ -465,6 +544,8 @@ class Streamable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "stream", cancels)
return dep
class Blurrable(Block):
@ -480,6 +561,7 @@ class Blurrable(Block):
queue: Optional[bool] = None,
preprocess: bool = True,
postprocess: bool = True,
cancels: List[Dict[str, Any]] | None = None,
_js: Optional[str] = None,
):
"""
@ -495,6 +577,7 @@ class Blurrable(Block):
queue: If True, will place the request on the queue, if the queue exists
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
"""
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
@ -511,3 +594,4 @@ class Blurrable(Block):
postprocess=postprocess,
queue=queue,
)
set_cancel_events(self, "blur", cancels)

View File

@ -9,6 +9,7 @@ import math
import numbers
import operator
import re
import uuid
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
@ -413,7 +414,7 @@ def get_spaces(model_name, api_key, alias, **kwargs):
async def get_pred_from_ws(
websocket: websockets.WebSocketClientProtocol, data: str
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
) -> Dict[str, Any]:
completed = False
while not completed:
@ -421,6 +422,8 @@ async def get_pred_from_ws(
resp = json.loads(msg)
if resp["msg"] == "queue_full":
raise exceptions.Error("Queue is full! Please try again.")
if resp["msg"] == "send_hash":
await websocket.send(hash_data)
elif resp["msg"] == "send_data":
await websocket.send(data)
completed = resp["msg"] == "process_completed"
@ -428,9 +431,9 @@ async def get_pred_from_ws(
def get_ws_fn(ws_url):
async def ws_fn(data):
async def ws_fn(data, hash_data):
async with websockets.connect(ws_url, open_timeout=10) as websocket:
return await get_pred_from_ws(websocket, data)
return await get_pred_from_ws(websocket, data, hash_data)
return ws_fn
@ -467,8 +470,11 @@ def get_spaces_blocks(model_name, config):
def get_fn(outputs, fn_index, use_ws):
def fn(*data):
data = json.dumps({"data": data, "fn_index": fn_index})
hash_data = json.dumps(
{"fn_index": fn_index, "session_hash": str(uuid.uuid4())}
)
if use_ws:
result = utils.synchronize_async(ws_fn, data)
result = utils.synchronize_async(ws_fn, data, hash_data)
output = result["data"]
else:
response = requests.post(api_url, headers=headers, data=data)

View File

@ -20,7 +20,7 @@ from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin
from gradio import Examples, interpretation, utils
from gradio.blocks import Blocks
from gradio.blocks import Blocks, update
from gradio.components import (
Button,
Component,
@ -472,9 +472,19 @@ class Interface(Blocks):
clear_btn = Button("Clear")
if not self.live:
submit_btn = Button("Submit", variant="primary")
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
if inspect.isgeneratorfunction(fn):
stop_btn = Button("Stop", variant="stop")
elif self.interface_type == self.InterfaceTypes.UNIFIED:
clear_btn = Button("Clear")
submit_btn = Button("Submit", variant="primary")
if inspect.isgeneratorfunction(fn) and not self.live:
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = render_flag_btns(self.flagging_options)
@ -491,6 +501,13 @@ class Interface(Blocks):
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
clear_btn = Button("Clear")
submit_btn = Button("Generate", variant="primary")
if inspect.isgeneratorfunction(fn) and not self.live:
# Stopping jobs only works if the queue is enabled
# We don't know if the queue is enabled when the interface
# is created. We use whether a generator function is provided
# as a proxy of whether the queue will be enabled.
# Using a generator function without the queue will raise an error.
stop_btn = Button("Stop", variant="stop")
if self.allow_flagging == "manual":
flag_btns = render_flag_btns(self.flagging_options)
if self.interpretation:
@ -535,7 +552,7 @@ class Interface(Blocks):
postprocess=not (self.api_mode),
)
else:
submit_btn.click(
pred = submit_btn.click(
self.fn,
self.input_components,
self.output_components,
@ -544,6 +561,14 @@ class Interface(Blocks):
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
)
if inspect.isgeneratorfunction(fn):
stop_btn.click(
None,
inputs=None,
outputs=None,
cancels=[pred],
)
clear_btn.click(
None,
[],

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
import json
import sys
import time
from typing import Dict, List, Optional
@ -84,7 +84,9 @@ class Queue:
event = self.event_queue.pop(0)
self.active_jobs[self.active_jobs.index(None)] = event
run_coro_in_background(self.process_event, event)
task = run_coro_in_background(self.process_event, event)
if sys.version_info >= (3, 8):
task.set_name(f"{event.session_hash}_{event.fn_index}")
run_coro_in_background(self.broadcast_live_estimations)
def push(self, event: Event) -> int | None:
@ -239,6 +241,13 @@ class Queue:
)
elif response.json.get("is_generating", False):
while response.json.get("is_generating", False):
# Python 3.7 doesn't have named tasks.
# In order to determine if a task was cancelled, we
# ping the websocket to see if it was closed mid-iteration.
if sys.version_info < (3, 8):
is_alive = await self.send_message(event, {"msg": "alive?"})
if not is_alive:
return
old_response = response
await self.send_message(
event,
@ -276,6 +285,18 @@ class Queue:
pass
finally:
await self.clean_event(event)
# Always reset the state of the iterator
# If the job finished successfully, this has no effect
# If the job is cancelled, this will enable future runs
# to start "from scratch"
await Request(
method=Request.Method.POST,
url=f"{self.server_path}reset",
json={
"session_hash": event.session_hash,
"fn_index": event.fn_index,
},
)
async def send_message(self, event, data: Dict) -> bool:
try:
@ -301,6 +322,8 @@ class Event:
self.websocket = websocket
self.data: PredictBody | None = None
self.lost_connection_time: float | None = None
self.fn_index = 0
self.session_hash = "foo"
async def disconnect(self, code=1000):
await self.websocket.close(code=code)

View File

@ -77,6 +77,11 @@ class PredictBody(BaseModel):
fn_index: Optional[int]
class ResetBody(BaseModel):
session_hash: str
fn_index: int
###########
# Auth
###########
@ -93,7 +98,7 @@ class App(FastAPI):
self.blocks: Optional[gradio.Blocks] = None
self.state_holder = {}
self.iterators = defaultdict(dict)
self.lock = asyncio.Lock()
super().__init__(**kwargs)
def configure_app(self, blocks: gradio.Blocks) -> None:
@ -254,6 +259,16 @@ class App(FastAPI):
def file_deprecated(path: str):
return file(path)
@app.post("/reset/")
@app.post("/reset")
async def reset_iterator(body: ResetBody):
if body.session_hash not in app.iterators:
return {"success": False}
async with app.lock:
app.iterators[body.session_hash][body.fn_index] = None
app.iterators[body.session_hash]["should_reset"].add(body.fn_index)
return {"success": True}
async def run_predict(
body: PredictBody, username: str = Depends(get_current_user)
):
@ -266,6 +281,14 @@ class App(FastAPI):
}
session_state = app.state_holder[body.session_hash]
iterators = app.iterators[body.session_hash]
# The should_reset set keeps track of the fn_indices
# that have been cancelled. When a job is cancelled,
# the /reset route will mark the jobs as having been reset.
# That way if the cancel job finishes BEFORE the job being cancelled
# the job being cancelled will not overwrite the state of the iterator.
# In all cases, should_reset will be the empty set the next time
# the fn_index is run.
app.iterators[body.session_hash]["should_reset"] = set([])
else:
session_state = {}
iterators = {}
@ -277,7 +300,10 @@ class App(FastAPI):
)
iterator = output.pop("iterator", None)
if hasattr(body, "session_hash"):
app.iterators[body.session_hash][fn_index] = iterator
if fn_index in app.iterators[body.session_hash]["should_reset"]:
app.iterators[body.session_hash][fn_index] = None
else:
app.iterators[body.session_hash][fn_index] = iterator
if isinstance(output, Error):
raise output
except BaseException as error:
@ -306,7 +332,12 @@ class App(FastAPI):
},
status_code=500,
)
return await run_predict(body=body, username=username)
# If this fn_index cancels jobs, then the only input we need is the
# current session hash
if app.blocks.dependencies[body.fn_index]["cancels"]:
body.data = [body.session_hash]
result = await run_predict(body=body, username=username)
return result
@app.websocket("/queue/join")
async def join_queue(websocket: WebSocket):
@ -315,10 +346,18 @@ class App(FastAPI):
app_url = get_server_url_from_ws_url(str(websocket.url))
print(f"Server URL: {app_url}")
app.blocks._queue.set_url(app_url)
await websocket.accept()
event = Event(websocket)
# In order to cancel jobs, we need the session_hash and fn_index
# to create a unique id for each job
await websocket.send_json({"msg": "send_hash"})
session_hash = await websocket.receive_json()
event.session_hash = session_hash["session_hash"]
event.fn_index = session_hash["fn_index"]
rank = app.blocks._queue.push(event)
if rank is None:
await app.blocks._queue.send_message(event, {"msg": "queue_full"})
await event.disconnect()

View File

@ -200,6 +200,7 @@ XRAY_CONFIG = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
{
"targets": [35],
@ -212,6 +213,7 @@ XRAY_CONFIG = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
{
"targets": [],
@ -224,6 +226,7 @@ XRAY_CONFIG = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
],
}
@ -439,6 +442,7 @@ XRAY_CONFIG_DIFF_IDS = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
{
"targets": [13],
@ -451,6 +455,7 @@ XRAY_CONFIG_DIFF_IDS = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
{
"targets": [],
@ -463,6 +468,7 @@ XRAY_CONFIG_DIFF_IDS = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
],
}
@ -639,6 +645,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
{
"targets": [13],
@ -648,6 +655,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
"api_name": None,
"scroll_to_output": False,
"show_progress": True,
"cancels": [],
},
],
}

View File

@ -15,6 +15,7 @@ def copy_all_demos(source_dir: str, dest_dir: str):
"blocks_multiple_event_triggers",
"blocks_update",
"calculator",
"cancel_events",
"fake_gan",
"fake_diffusion_with_gif",
"gender_sentence_default_interpretation",

View File

@ -16,6 +16,8 @@ import pytest
import wandb
import gradio as gr
import gradio.events
from gradio.blocks import Block
from gradio.exceptions import DuplicateBlockError
from gradio.routes import PredictBody
from gradio.test_data.blocks_configs import XRAY_CONFIG
@ -539,5 +541,58 @@ class TestDuplicateBlockError:
io2.render()
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Tasks dont have names in 3.7",
)
@pytest.mark.asyncio
async def test_cancel_function(capsys):
async def long_job():
await asyncio.sleep(10)
print("HELLO FROM LONG JOB")
with gr.Blocks():
button = gr.Button(value="Start")
click = button.click(long_job, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click])
task = asyncio.create_task(long_job())
task.set_name("foo_0")
# If cancel_fun didn't cancel long_job the message would be printed to the console
# The test would also take 10 seconds
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out
def test_raise_exception_if_cancelling_an_event_thats_not_queued():
def iteration(a):
yield a
msg = "In order to cancel an event, the queue for that event must be enabled!"
with pytest.raises(ValueError, match=msg):
gr.Interface(iteration, inputs=gr.Number(), outputs=gr.Number()).launch(
prevent_thread_lock=True
)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
button = gr.Button(value="Predict")
click = button.click(None, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
demo.launch(prevent_thread_lock=True)
with pytest.raises(ValueError, match=msg):
with gr.Blocks() as demo:
button = gr.Button(value="Predict")
click = button.click(None, None, None, queue=False)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
demo.queue().launch(prevent_thread_lock=True)
if __name__ == "__main__":
unittest.main()

View File

@ -384,7 +384,8 @@ async def test_get_pred_from_ws():
]
mock_ws.recv.side_effect = messages
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
output = await get_pred_from_ws(mock_ws, data)
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
output = await get_pred_from_ws(mock_ws, data, hash_data)
assert output == {"data": ["result!"]}
mock_ws.send.assert_called_once_with(data)
@ -395,8 +396,9 @@ async def test_get_pred_from_ws_raises_if_queue_full():
messages = [json.dumps({"msg": "queue_full"})]
mock_ws.recv.side_effect = messages
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
with pytest.raises(gradio.Error, match="Queue is full!"):
await get_pred_from_ws(mock_ws, data)
await get_pred_from_ws(mock_ws, data, hash_data)
@pytest.mark.skipif(

View File

@ -12,6 +12,7 @@ import requests
import wandb
from fastapi.testclient import TestClient
import gradio
from gradio.blocks import Blocks
from gradio.interface import Interface, TabbedInterface, close_all, os
from gradio.layouts import TabItem, Tabs
@ -268,5 +269,50 @@ class TestInterfaceInterpretation(unittest.TestCase):
close_all()
@pytest.mark.parametrize(
"interface_type", ["standard", "input_only", "output_only", "unified"]
)
@pytest.mark.parametrize("live", [True, False])
@pytest.mark.parametrize("use_generator", [True, False])
def test_interface_adds_stop_button(interface_type, live, use_generator):
def gen_func(inp):
yield inp
def func(inp):
return inp
if interface_type == "standard":
interface = gradio.Interface(
gen_func if use_generator else func, "number", "number", live=live
)
elif interface_type == "input_only":
interface = gradio.Interface(
gen_func if use_generator else func, "number", None, live=live
)
elif interface_type == "output_only":
interface = gradio.Interface(
gen_func if use_generator else func, None, "number", live=live
)
else:
num = gradio.Number()
interface = gradio.Interface(
gen_func if use_generator else func, num, num, live=live
)
has_stop = (
len(
[
c
for c in interface.config["components"]
if c["props"].get("variant", "") == "stop"
]
)
== 1
)
if use_generator and not live:
assert has_stop
else:
assert not has_stop
if __name__ == "__main__":
unittest.main()

View File

@ -1,9 +1,11 @@
import os
from unittest.mock import MagicMock
import sys
from unittest.mock import MagicMock, patch
import pytest
from gradio.queue import Event, Queue
from gradio.utils import Request
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -161,8 +163,13 @@ class TestQueueEstimation:
class TestQueueProcessEvents:
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
async def test_process_event(self, queue: Queue, mock_event: Event):
@patch("gradio.queue.Request", new_callable=AsyncMock)
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
queue.gather_event_data = AsyncMock()
queue.gather_event_data.return_value = True
queue.send_message = AsyncMock()
@ -178,6 +185,14 @@ class TestQueueProcessEvents:
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
queue.clean_event.assert_called_once()
mock_request.assert_called_with(
method=Request.Method.POST,
url=f"{queue.server_path}reset",
json={
"session_hash": mock_event.session_hash,
"fn_index": mock_event.fn_index,
},
)
@pytest.mark.asyncio
async def test_process_event_handles_error_when_gathering_data(
@ -246,9 +261,14 @@ class TestQueueProcessEvents:
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queue.Request", new_callable=AsyncMock)
async def test_process_event_handles_exception_during_disconnect(
self, queue: Queue, mock_event: Event
self, mock_request, queue: Queue, mock_event: Event
):
mock_event.websocket.send_json = AsyncMock()
queue.call_prediction = AsyncMock(
@ -259,3 +279,11 @@ class TestQueueProcessEvents:
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
mock_request.assert_called_with(
method=Request.Method.POST,
url=f"{queue.server_path}reset",
json={
"session_hash": mock_event.session_hash,
"fn_index": mock_event.fn_index,
},
)

View File

@ -236,6 +236,8 @@ async def test_queue_join_routes_sets_url_if_none_set(mock_get_url):
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
if msg["msg"] == "send_hash":
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
completed = msg["msg"] == "process_completed"
assert io._queue.server_path == "foo_url"

View File

@ -222,6 +222,7 @@
queue,
backend_fn,
frontend_fn,
cancels,
...rest
},
i
@ -250,7 +251,8 @@
},
queue: queue === null ? enable_queue : queue,
queue_callback: handle_update,
loading_status: loading_status
loading_status: loading_status,
cancels
});
function handle_update(output: any) {
@ -303,7 +305,8 @@
output_data: outputs.map((id) => instance_map[id].props.value),
queue: queue === null ? enable_queue : queue,
queue_callback: handle_update,
loading_status: loading_status
loading_status: loading_status,
cancels
});
if (!(queue === null ? enable_queue : queue)) {

View File

@ -77,7 +77,8 @@ export const fn =
frontend_fn,
output_data,
queue_callback,
loading_status
loading_status,
cancels
}: {
action: string;
payload: Payload;
@ -87,6 +88,7 @@ export const fn =
output_data?: Output["data"];
queue_callback: Function;
loading_status: LoadingStatusType;
cancels: Array<number>;
}): Promise<unknown> => {
const fn_index = payload.fn_index;
@ -153,6 +155,14 @@ export const fn =
case "send_data":
send_message(fn_index, payload);
break;
case "send_hash":
ws_map.get(fn_index)?.send(
JSON.stringify({
session_hash: session_hash,
fn_index: fn_index
})
);
break;
case "queue_full":
loading_status.update(
fn_index,
@ -243,6 +253,21 @@ export const fn =
output.average_duration as number,
null
);
// Cancelled jobs are set to complete
if (cancels.length > 0) {
cancels.forEach((fn_index) => {
loading_status.update(
fn_index,
"complete",
queue,
null,
null,
null,
null
);
ws_map.get(fn_index)?.close();
});
}
} else {
loading_status.update(
fn_index,

View File

@ -30,6 +30,7 @@ export interface Dependency {
queue: boolean | null;
api_name: string | null;
documentation?: Array<Array<Array<string>>>;
cancels: Array<number>;
}
export interface LayoutNode {

View File

@ -5,7 +5,7 @@
export let style: Styles = {};
export let elem_id: string = "";
export let visible: boolean = true;
export let variant: "primary" | "secondary" = "secondary";
export let variant: "primary" | "secondary" | "stop" = "secondary";
export let size: "sm" | "lg" = "lg";
$: ({ classes } = get_styles(style, ["full_width"]));

View File

@ -86,6 +86,11 @@
dark:from-gray-600 dark:to-gray-700 dark:hover:to-gray-600 dark:text-white dark:border-gray-600;
}
.gr-button-stop {
@apply from-red-200/70 to-red-300/80 hover:to-red-200/90 text-red-600 border-red-200
dark:from-red-700 dark:to-red-700 dark:hover:to-red-500 dark:text-white dark:border-red-600;
}
.gr-button-sm {
@apply px-3 py-1 text-sm rounded-md;
}