mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Cancel events from other events (#2433)
* WIP * Use async iteration * Format + comment * Very hacky WIP * Fix synchronization * Add comments + tidy up implementation * Remove print * Fix rebase * Lint * Disconnect queue when cancelled * Add stop button for interface automaticallY * Unit tests + interface fixes * Skip some tests on 3.7 * Skip in 3.7 * Fix skip message * Fix for python 3.7 * Add stop variant to button variant type union * CHANGELOG * Add demos/gifs to the changelog
This commit is contained in:
parent
07c77ece36
commit
831ae1405f
90
CHANGELOG.md
90
CHANGELOG.md
@ -1,4 +1,90 @@
|
||||
# Upcoming Release
|
||||
# Upcoming Release
|
||||
|
||||
## New Features:
|
||||
|
||||
### Cancelling Running Events
|
||||
Running events can be cancelled when other events are triggered! To test this feature, pass the `cancels` parameter to the event listener.
|
||||
For this feature to work, the queue must be enabled.
|
||||
|
||||

|
||||
|
||||
Code:
|
||||
```python
|
||||
import time
|
||||
import gradio as gr
|
||||
|
||||
def fake_diffusion(steps):
|
||||
for i in range(steps):
|
||||
time.sleep(1)
|
||||
yield str(i)
|
||||
|
||||
def long_prediction(*args, **kwargs):
|
||||
time.sleep(10)
|
||||
return 42
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
n = gr.Slider(1, 10, value=9, step=1, label="Number Steps")
|
||||
run = gr.Button()
|
||||
output = gr.Textbox(label="Iterative Output")
|
||||
stop = gr.Button(value="Stop Iterating")
|
||||
with gr.Column():
|
||||
prediction = gr.Number(label="Expensive Calculation")
|
||||
run_pred = gr.Button(value="Run Expensive Calculation")
|
||||
with gr.Column():
|
||||
cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change")
|
||||
|
||||
click_event = run.click(fake_diffusion, n, output)
|
||||
stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
|
||||
pred_event = run_pred.click(fn=long_prediction, inputs=None, outputs=prediction)
|
||||
|
||||
cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])
|
||||
|
||||
|
||||
demo.queue(concurrency_count=1, max_size=20).launch()
|
||||
```
|
||||
|
||||
For interfaces, a stop button will be added automatically if the function uses a `yield` statement.
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import time
|
||||
|
||||
def iteration(steps):
|
||||
for i in range(steps):
|
||||
time.sleep(0.5)
|
||||
yield i
|
||||
|
||||
gr.Interface(iteration,
|
||||
inputs=gr.Slider(minimum=1, maximum=10, step=1, value=5),
|
||||
outputs=gr.Number()).queue().launch()
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
## Bug Fixes:
|
||||
No changes to highlight.
|
||||
|
||||
## Documentation Changes:
|
||||
No changes to highlight.
|
||||
|
||||
## Testing and Infrastructure Changes:
|
||||
No changes to highlight.
|
||||
|
||||
## Breaking Changes:
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
* Enable running events to be cancelled from other events by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2433](https://github.com/gradio-app/gradio/pull/2433)
|
||||
|
||||
## Contributors Shoutout:
|
||||
No changes to highlight.
|
||||
|
||||
|
||||
# Version 3.5
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
@ -42,8 +128,6 @@ No changes to highlight.
|
||||
* Fix embedded interfaces on touch screen devices by [@aliabd](https://github.com/aliabd) in [PR 2457](https://github.com/gradio-app/gradio/pull/2457)
|
||||
* Upload all demos to spaces by [@aliabd](https://github.com/aliabd) in [PR 2281](https://github.com/gradio-app/gradio/pull/2281)
|
||||
|
||||
|
||||
|
||||
## Contributors Shoutout:
|
||||
No changes to highlight.
|
||||
|
||||
|
49
demo/cancel_events/run.py
Normal file
49
demo/cancel_events/run.py
Normal file
@ -0,0 +1,49 @@
|
||||
import time
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def fake_diffusion(steps):
|
||||
for i in range(steps):
|
||||
print(f"Current step: {i}")
|
||||
time.sleep(1)
|
||||
yield str(i)
|
||||
|
||||
|
||||
def long_prediction(*args, **kwargs):
|
||||
time.sleep(10)
|
||||
return 42
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
n = gr.Slider(1, 10, value=9, step=1, label="Number Steps")
|
||||
run = gr.Button()
|
||||
output = gr.Textbox(label="Iterative Output")
|
||||
stop = gr.Button(value="Stop Iterating")
|
||||
with gr.Column():
|
||||
textbox = gr.Textbox(label="Prompt")
|
||||
prediction = gr.Number(label="Expensive Calculation")
|
||||
run_pred = gr.Button(value="Run Expensive Calculation")
|
||||
with gr.Column():
|
||||
cancel_on_change = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Change")
|
||||
cancel_on_submit = gr.Textbox(label="Cancel Iteration and Expensive Calculation on Submit")
|
||||
echo = gr.Textbox(label="Echo")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
image = gr.Image(source="webcam", tool="editor", label="Cancel on edit", interactive=True)
|
||||
with gr.Column():
|
||||
video = gr.Video(source="webcam", label="Cancel on play", interactive=True)
|
||||
|
||||
click_event = run.click(fake_diffusion, n, output)
|
||||
stop.click(fn=None, inputs=None, outputs=None, cancels=[click_event])
|
||||
pred_event = run_pred.click(fn=long_prediction, inputs=[textbox], outputs=prediction)
|
||||
|
||||
cancel_on_change.change(None, None, None, cancels=[click_event, pred_event])
|
||||
cancel_on_submit.submit(lambda s: s, cancel_on_submit, echo, cancels=[click_event, pred_event])
|
||||
image.edit(None, None, None, cancels=[click_event, pred_event])
|
||||
video.play(None, None, None, cancels=[click_event, pred_event])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue(concurrency_count=2, max_size=20).launch()
|
@ -119,7 +119,8 @@ class Block:
|
||||
js: Optional[str] = None,
|
||||
no_target: bool = False,
|
||||
queue: Optional[bool] = None,
|
||||
) -> None:
|
||||
cancels: List[int] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Adds an event to the component's dependencies.
|
||||
Parameters:
|
||||
@ -166,6 +167,7 @@ class Block:
|
||||
"api_name": api_name,
|
||||
"scroll_to_output": scroll_to_output,
|
||||
"show_progress": show_progress,
|
||||
"cancels": cancels if cancels else [],
|
||||
}
|
||||
if api_name is not None:
|
||||
dependency["documentation"] = [
|
||||
@ -179,6 +181,7 @@ class Block:
|
||||
],
|
||||
]
|
||||
Context.root_block.dependencies.append(dependency)
|
||||
return dependency
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
@ -1062,6 +1065,20 @@ class Blocks(BlockContext):
|
||||
if self.enable_queue and not hasattr(self, "_queue"):
|
||||
self.queue()
|
||||
|
||||
for dep in self.dependencies:
|
||||
for i in dep["cancels"]:
|
||||
queue_status = self.dependencies[i]["queue"]
|
||||
if queue_status is False or (
|
||||
queue_status is None and not self.enable_queue
|
||||
):
|
||||
raise ValueError(
|
||||
"In order to cancel an event, the queue for that event must be enabled! "
|
||||
"You may get this error by either 1) passing a function that uses the yield keyword "
|
||||
"into an interface without enabling the queue or 2) defining an event that cancels "
|
||||
"another event without enabling the queue. Both can be solved by calling .queue() "
|
||||
"before .launch()"
|
||||
)
|
||||
|
||||
self.config = self.get_config_file()
|
||||
self.share = share
|
||||
self.encrypt = encrypt
|
||||
|
104
gradio/events.py
104
gradio/events.py
@ -1,14 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from gradio.blocks import Block
|
||||
from gradio.blocks import Block, Context, update
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio.components import Component, StatusTracker
|
||||
|
||||
|
||||
def get_cancel_function(
|
||||
dependencies: List[Dict[str, Any]]
|
||||
) -> Tuple[Callable, List[int]]:
|
||||
fn_to_comp = {}
|
||||
for dep in dependencies:
|
||||
fn_index = next(
|
||||
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
|
||||
)
|
||||
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
|
||||
|
||||
async def cancel(session_hash: str) -> None:
|
||||
if sys.version_info < (3, 8):
|
||||
return None
|
||||
|
||||
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
|
||||
|
||||
matching_tasks = [
|
||||
task for task in asyncio.all_tasks() if task.get_name() in task_ids
|
||||
]
|
||||
for task in matching_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*matching_tasks, return_exceptions=True)
|
||||
|
||||
return (
|
||||
cancel,
|
||||
list(fn_to_comp.keys()),
|
||||
)
|
||||
|
||||
|
||||
def set_cancel_events(block: Block, event_name: str, cancels: List[Dict[str, Any]]):
|
||||
if cancels:
|
||||
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
|
||||
block.set_event_trigger(
|
||||
event_name,
|
||||
cancel_fn,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
queue=False,
|
||||
preprocess=False,
|
||||
cancels=fn_indices_to_cancel,
|
||||
)
|
||||
|
||||
|
||||
class Changeable(Block):
|
||||
def change(
|
||||
self,
|
||||
@ -22,6 +67,7 @@ class Changeable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -38,6 +84,7 @@ class Changeable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -45,7 +92,7 @@ class Changeable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"change",
|
||||
fn,
|
||||
inputs,
|
||||
@ -58,6 +105,8 @@ class Changeable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "change", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Clickable(Block):
|
||||
@ -73,6 +122,7 @@ class Clickable(Block):
|
||||
queue=None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -89,6 +139,7 @@ class Clickable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -96,7 +147,7 @@ class Clickable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"click",
|
||||
fn,
|
||||
inputs,
|
||||
@ -109,6 +160,8 @@ class Clickable(Block):
|
||||
preprocess=preprocess,
|
||||
postprocess=postprocess,
|
||||
)
|
||||
set_cancel_events(self, "click", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Submittable(Block):
|
||||
@ -124,6 +177,7 @@ class Submittable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -141,6 +195,7 @@ class Submittable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -148,7 +203,7 @@ class Submittable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"submit",
|
||||
fn,
|
||||
inputs,
|
||||
@ -161,6 +216,8 @@ class Submittable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "submit", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Editable(Block):
|
||||
@ -176,6 +233,7 @@ class Editable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -192,6 +250,7 @@ class Editable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -199,7 +258,7 @@ class Editable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"edit",
|
||||
fn,
|
||||
inputs,
|
||||
@ -212,6 +271,8 @@ class Editable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "edit", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Clearable(Block):
|
||||
@ -227,6 +288,7 @@ class Clearable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -243,6 +305,7 @@ class Clearable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -250,7 +313,7 @@ class Clearable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"submit",
|
||||
fn,
|
||||
inputs,
|
||||
@ -263,6 +326,8 @@ class Clearable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "submit", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Playable(Block):
|
||||
@ -278,6 +343,7 @@ class Playable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -294,6 +360,7 @@ class Playable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -301,7 +368,7 @@ class Playable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"play",
|
||||
fn,
|
||||
inputs,
|
||||
@ -314,6 +381,8 @@ class Playable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "play", cancels)
|
||||
return dep
|
||||
|
||||
def pause(
|
||||
self,
|
||||
@ -327,6 +396,7 @@ class Playable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -343,6 +413,7 @@ class Playable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -350,7 +421,7 @@ class Playable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"pause",
|
||||
fn,
|
||||
inputs,
|
||||
@ -363,6 +434,8 @@ class Playable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "pause", cancels)
|
||||
return dep
|
||||
|
||||
def stop(
|
||||
self,
|
||||
@ -376,6 +449,7 @@ class Playable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -392,6 +466,7 @@ class Playable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
if status_tracker:
|
||||
@ -399,7 +474,7 @@ class Playable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"stop",
|
||||
fn,
|
||||
inputs,
|
||||
@ -412,6 +487,8 @@ class Playable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "stop", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Streamable(Block):
|
||||
@ -427,6 +504,7 @@ class Streamable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -443,6 +521,7 @@ class Streamable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
self.streaming = True
|
||||
@ -452,7 +531,7 @@ class Streamable(Block):
|
||||
"The 'status_tracker' parameter has been deprecated and has no effect."
|
||||
)
|
||||
|
||||
self.set_event_trigger(
|
||||
dep = self.set_event_trigger(
|
||||
"stream",
|
||||
fn,
|
||||
inputs,
|
||||
@ -465,6 +544,8 @@ class Streamable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "stream", cancels)
|
||||
return dep
|
||||
|
||||
|
||||
class Blurrable(Block):
|
||||
@ -480,6 +561,7 @@ class Blurrable(Block):
|
||||
queue: Optional[bool] = None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: List[Dict[str, Any]] | None = None,
|
||||
_js: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
@ -495,6 +577,7 @@ class Blurrable(Block):
|
||||
queue: If True, will place the request on the queue, if the queue exists
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
"""
|
||||
# _js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
|
||||
|
||||
@ -511,3 +594,4 @@ class Blurrable(Block):
|
||||
postprocess=postprocess,
|
||||
queue=queue,
|
||||
)
|
||||
set_cancel_events(self, "blur", cancels)
|
||||
|
@ -9,6 +9,7 @@ import math
|
||||
import numbers
|
||||
import operator
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
|
||||
@ -413,7 +414,7 @@ def get_spaces(model_name, api_key, alias, **kwargs):
|
||||
|
||||
|
||||
async def get_pred_from_ws(
|
||||
websocket: websockets.WebSocketClientProtocol, data: str
|
||||
websocket: websockets.WebSocketClientProtocol, data: str, hash_data: str
|
||||
) -> Dict[str, Any]:
|
||||
completed = False
|
||||
while not completed:
|
||||
@ -421,6 +422,8 @@ async def get_pred_from_ws(
|
||||
resp = json.loads(msg)
|
||||
if resp["msg"] == "queue_full":
|
||||
raise exceptions.Error("Queue is full! Please try again.")
|
||||
if resp["msg"] == "send_hash":
|
||||
await websocket.send(hash_data)
|
||||
elif resp["msg"] == "send_data":
|
||||
await websocket.send(data)
|
||||
completed = resp["msg"] == "process_completed"
|
||||
@ -428,9 +431,9 @@ async def get_pred_from_ws(
|
||||
|
||||
|
||||
def get_ws_fn(ws_url):
|
||||
async def ws_fn(data):
|
||||
async def ws_fn(data, hash_data):
|
||||
async with websockets.connect(ws_url, open_timeout=10) as websocket:
|
||||
return await get_pred_from_ws(websocket, data)
|
||||
return await get_pred_from_ws(websocket, data, hash_data)
|
||||
|
||||
return ws_fn
|
||||
|
||||
@ -467,8 +470,11 @@ def get_spaces_blocks(model_name, config):
|
||||
def get_fn(outputs, fn_index, use_ws):
|
||||
def fn(*data):
|
||||
data = json.dumps({"data": data, "fn_index": fn_index})
|
||||
hash_data = json.dumps(
|
||||
{"fn_index": fn_index, "session_hash": str(uuid.uuid4())}
|
||||
)
|
||||
if use_ws:
|
||||
result = utils.synchronize_async(ws_fn, data)
|
||||
result = utils.synchronize_async(ws_fn, data, hash_data)
|
||||
output = result["data"]
|
||||
else:
|
||||
response = requests.post(api_url, headers=headers, data=data)
|
||||
|
@ -20,7 +20,7 @@ from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.footnote import footnote_plugin
|
||||
|
||||
from gradio import Examples, interpretation, utils
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.blocks import Blocks, update
|
||||
from gradio.components import (
|
||||
Button,
|
||||
Component,
|
||||
@ -472,9 +472,19 @@ class Interface(Blocks):
|
||||
clear_btn = Button("Clear")
|
||||
if not self.live:
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
if inspect.isgeneratorfunction(fn):
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
|
||||
elif self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Submit", variant="primary")
|
||||
if inspect.isgeneratorfunction(fn) and not self.live:
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = render_flag_btns(self.flagging_options)
|
||||
|
||||
@ -491,6 +501,13 @@ class Interface(Blocks):
|
||||
if self.interface_type == self.InterfaceTypes.OUTPUT_ONLY:
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Generate", variant="primary")
|
||||
if inspect.isgeneratorfunction(fn) and not self.live:
|
||||
# Stopping jobs only works if the queue is enabled
|
||||
# We don't know if the queue is enabled when the interface
|
||||
# is created. We use whether a generator function is provided
|
||||
# as a proxy of whether the queue will be enabled.
|
||||
# Using a generator function without the queue will raise an error.
|
||||
stop_btn = Button("Stop", variant="stop")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btns = render_flag_btns(self.flagging_options)
|
||||
if self.interpretation:
|
||||
@ -535,7 +552,7 @@ class Interface(Blocks):
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
else:
|
||||
submit_btn.click(
|
||||
pred = submit_btn.click(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
@ -544,6 +561,14 @@ class Interface(Blocks):
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
if inspect.isgeneratorfunction(fn):
|
||||
stop_btn.click(
|
||||
None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[pred],
|
||||
)
|
||||
|
||||
clear_btn.click(
|
||||
None,
|
||||
[],
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -84,7 +84,9 @@ class Queue:
|
||||
event = self.event_queue.pop(0)
|
||||
|
||||
self.active_jobs[self.active_jobs.index(None)] = event
|
||||
run_coro_in_background(self.process_event, event)
|
||||
task = run_coro_in_background(self.process_event, event)
|
||||
if sys.version_info >= (3, 8):
|
||||
task.set_name(f"{event.session_hash}_{event.fn_index}")
|
||||
run_coro_in_background(self.broadcast_live_estimations)
|
||||
|
||||
def push(self, event: Event) -> int | None:
|
||||
@ -239,6 +241,13 @@ class Queue:
|
||||
)
|
||||
elif response.json.get("is_generating", False):
|
||||
while response.json.get("is_generating", False):
|
||||
# Python 3.7 doesn't have named tasks.
|
||||
# In order to determine if a task was cancelled, we
|
||||
# ping the websocket to see if it was closed mid-iteration.
|
||||
if sys.version_info < (3, 8):
|
||||
is_alive = await self.send_message(event, {"msg": "alive?"})
|
||||
if not is_alive:
|
||||
return
|
||||
old_response = response
|
||||
await self.send_message(
|
||||
event,
|
||||
@ -276,6 +285,18 @@ class Queue:
|
||||
pass
|
||||
finally:
|
||||
await self.clean_event(event)
|
||||
# Always reset the state of the iterator
|
||||
# If the job finished successfully, this has no effect
|
||||
# If the job is cancelled, this will enable future runs
|
||||
# to start "from scratch"
|
||||
await Request(
|
||||
method=Request.Method.POST,
|
||||
url=f"{self.server_path}reset",
|
||||
json={
|
||||
"session_hash": event.session_hash,
|
||||
"fn_index": event.fn_index,
|
||||
},
|
||||
)
|
||||
|
||||
async def send_message(self, event, data: Dict) -> bool:
|
||||
try:
|
||||
@ -301,6 +322,8 @@ class Event:
|
||||
self.websocket = websocket
|
||||
self.data: PredictBody | None = None
|
||||
self.lost_connection_time: float | None = None
|
||||
self.fn_index = 0
|
||||
self.session_hash = "foo"
|
||||
|
||||
async def disconnect(self, code=1000):
|
||||
await self.websocket.close(code=code)
|
||||
|
@ -77,6 +77,11 @@ class PredictBody(BaseModel):
|
||||
fn_index: Optional[int]
|
||||
|
||||
|
||||
class ResetBody(BaseModel):
|
||||
session_hash: str
|
||||
fn_index: int
|
||||
|
||||
|
||||
###########
|
||||
# Auth
|
||||
###########
|
||||
@ -93,7 +98,7 @@ class App(FastAPI):
|
||||
self.blocks: Optional[gradio.Blocks] = None
|
||||
self.state_holder = {}
|
||||
self.iterators = defaultdict(dict)
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def configure_app(self, blocks: gradio.Blocks) -> None:
|
||||
@ -254,6 +259,16 @@ class App(FastAPI):
|
||||
def file_deprecated(path: str):
|
||||
return file(path)
|
||||
|
||||
@app.post("/reset/")
|
||||
@app.post("/reset")
|
||||
async def reset_iterator(body: ResetBody):
|
||||
if body.session_hash not in app.iterators:
|
||||
return {"success": False}
|
||||
async with app.lock:
|
||||
app.iterators[body.session_hash][body.fn_index] = None
|
||||
app.iterators[body.session_hash]["should_reset"].add(body.fn_index)
|
||||
return {"success": True}
|
||||
|
||||
async def run_predict(
|
||||
body: PredictBody, username: str = Depends(get_current_user)
|
||||
):
|
||||
@ -266,6 +281,14 @@ class App(FastAPI):
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
iterators = app.iterators[body.session_hash]
|
||||
# The should_reset set keeps track of the fn_indices
|
||||
# that have been cancelled. When a job is cancelled,
|
||||
# the /reset route will mark the jobs as having been reset.
|
||||
# That way if the cancel job finishes BEFORE the job being cancelled
|
||||
# the job being cancelled will not overwrite the state of the iterator.
|
||||
# In all cases, should_reset will be the empty set the next time
|
||||
# the fn_index is run.
|
||||
app.iterators[body.session_hash]["should_reset"] = set([])
|
||||
else:
|
||||
session_state = {}
|
||||
iterators = {}
|
||||
@ -277,7 +300,10 @@ class App(FastAPI):
|
||||
)
|
||||
iterator = output.pop("iterator", None)
|
||||
if hasattr(body, "session_hash"):
|
||||
app.iterators[body.session_hash][fn_index] = iterator
|
||||
if fn_index in app.iterators[body.session_hash]["should_reset"]:
|
||||
app.iterators[body.session_hash][fn_index] = None
|
||||
else:
|
||||
app.iterators[body.session_hash][fn_index] = iterator
|
||||
if isinstance(output, Error):
|
||||
raise output
|
||||
except BaseException as error:
|
||||
@ -306,7 +332,12 @@ class App(FastAPI):
|
||||
},
|
||||
status_code=500,
|
||||
)
|
||||
return await run_predict(body=body, username=username)
|
||||
# If this fn_index cancels jobs, then the only input we need is the
|
||||
# current session hash
|
||||
if app.blocks.dependencies[body.fn_index]["cancels"]:
|
||||
body.data = [body.session_hash]
|
||||
result = await run_predict(body=body, username=username)
|
||||
return result
|
||||
|
||||
@app.websocket("/queue/join")
|
||||
async def join_queue(websocket: WebSocket):
|
||||
@ -315,10 +346,18 @@ class App(FastAPI):
|
||||
app_url = get_server_url_from_ws_url(str(websocket.url))
|
||||
print(f"Server URL: {app_url}")
|
||||
app.blocks._queue.set_url(app_url)
|
||||
|
||||
await websocket.accept()
|
||||
event = Event(websocket)
|
||||
|
||||
# In order to cancel jobs, we need the session_hash and fn_index
|
||||
# to create a unique id for each job
|
||||
await websocket.send_json({"msg": "send_hash"})
|
||||
session_hash = await websocket.receive_json()
|
||||
event.session_hash = session_hash["session_hash"]
|
||||
event.fn_index = session_hash["fn_index"]
|
||||
|
||||
rank = app.blocks._queue.push(event)
|
||||
|
||||
if rank is None:
|
||||
await app.blocks._queue.send_message(event, {"msg": "queue_full"})
|
||||
await event.disconnect()
|
||||
|
@ -200,6 +200,7 @@ XRAY_CONFIG = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
{
|
||||
"targets": [35],
|
||||
@ -212,6 +213,7 @@ XRAY_CONFIG = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
{
|
||||
"targets": [],
|
||||
@ -224,6 +226,7 @@ XRAY_CONFIG = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
@ -439,6 +442,7 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
@ -451,6 +455,7 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
{
|
||||
"targets": [],
|
||||
@ -463,6 +468,7 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
@ -639,6 +645,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
@ -648,6 +655,7 @@ XRAY_CONFIG_WITH_MISTAKE = {
|
||||
"api_name": None,
|
||||
"scroll_to_output": False,
|
||||
"show_progress": True,
|
||||
"cancels": [],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ def copy_all_demos(source_dir: str, dest_dir: str):
|
||||
"blocks_multiple_event_triggers",
|
||||
"blocks_update",
|
||||
"calculator",
|
||||
"cancel_events",
|
||||
"fake_gan",
|
||||
"fake_diffusion_with_gif",
|
||||
"gender_sentence_default_interpretation",
|
||||
|
@ -16,6 +16,8 @@ import pytest
|
||||
import wandb
|
||||
|
||||
import gradio as gr
|
||||
import gradio.events
|
||||
from gradio.blocks import Block
|
||||
from gradio.exceptions import DuplicateBlockError
|
||||
from gradio.routes import PredictBody
|
||||
from gradio.test_data.blocks_configs import XRAY_CONFIG
|
||||
@ -539,5 +541,58 @@ class TestDuplicateBlockError:
|
||||
io2.render()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Tasks dont have names in 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_function(capsys):
|
||||
async def long_job():
|
||||
await asyncio.sleep(10)
|
||||
print("HELLO FROM LONG JOB")
|
||||
|
||||
with gr.Blocks():
|
||||
button = gr.Button(value="Start")
|
||||
click = button.click(long_job, None, None)
|
||||
cancel = gr.Button(value="Cancel")
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click])
|
||||
|
||||
task = asyncio.create_task(long_job())
|
||||
task.set_name("foo_0")
|
||||
# If cancel_fun didn't cancel long_job the message would be printed to the console
|
||||
# The test would also take 10 seconds
|
||||
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "HELLO FROM LONG JOB" not in captured.out
|
||||
|
||||
|
||||
def test_raise_exception_if_cancelling_an_event_thats_not_queued():
|
||||
def iteration(a):
|
||||
yield a
|
||||
|
||||
msg = "In order to cancel an event, the queue for that event must be enabled!"
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
gr.Interface(iteration, inputs=gr.Number(), outputs=gr.Number()).launch(
|
||||
prevent_thread_lock=True
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
with gr.Blocks() as demo:
|
||||
button = gr.Button(value="Predict")
|
||||
click = button.click(None, None, None)
|
||||
cancel = gr.Button(value="Cancel")
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
with gr.Blocks() as demo:
|
||||
button = gr.Button(value="Predict")
|
||||
click = button.click(None, None, None, queue=False)
|
||||
cancel = gr.Button(value="Cancel")
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
demo.queue().launch(prevent_thread_lock=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -384,7 +384,8 @@ async def test_get_pred_from_ws():
|
||||
]
|
||||
mock_ws.recv.side_effect = messages
|
||||
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
|
||||
output = await get_pred_from_ws(mock_ws, data)
|
||||
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
|
||||
output = await get_pred_from_ws(mock_ws, data, hash_data)
|
||||
assert output == {"data": ["result!"]}
|
||||
mock_ws.send.assert_called_once_with(data)
|
||||
|
||||
@ -395,8 +396,9 @@ async def test_get_pred_from_ws_raises_if_queue_full():
|
||||
messages = [json.dumps({"msg": "queue_full"})]
|
||||
mock_ws.recv.side_effect = messages
|
||||
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
|
||||
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
|
||||
with pytest.raises(gradio.Error, match="Queue is full!"):
|
||||
await get_pred_from_ws(mock_ws, data)
|
||||
await get_pred_from_ws(mock_ws, data, hash_data)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -12,6 +12,7 @@ import requests
|
||||
import wandb
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import gradio
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.interface import Interface, TabbedInterface, close_all, os
|
||||
from gradio.layouts import TabItem, Tabs
|
||||
@ -268,5 +269,50 @@ class TestInterfaceInterpretation(unittest.TestCase):
|
||||
close_all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"interface_type", ["standard", "input_only", "output_only", "unified"]
|
||||
)
|
||||
@pytest.mark.parametrize("live", [True, False])
|
||||
@pytest.mark.parametrize("use_generator", [True, False])
|
||||
def test_interface_adds_stop_button(interface_type, live, use_generator):
|
||||
def gen_func(inp):
|
||||
yield inp
|
||||
|
||||
def func(inp):
|
||||
return inp
|
||||
|
||||
if interface_type == "standard":
|
||||
interface = gradio.Interface(
|
||||
gen_func if use_generator else func, "number", "number", live=live
|
||||
)
|
||||
elif interface_type == "input_only":
|
||||
interface = gradio.Interface(
|
||||
gen_func if use_generator else func, "number", None, live=live
|
||||
)
|
||||
elif interface_type == "output_only":
|
||||
interface = gradio.Interface(
|
||||
gen_func if use_generator else func, None, "number", live=live
|
||||
)
|
||||
else:
|
||||
num = gradio.Number()
|
||||
interface = gradio.Interface(
|
||||
gen_func if use_generator else func, num, num, live=live
|
||||
)
|
||||
has_stop = (
|
||||
len(
|
||||
[
|
||||
c
|
||||
for c in interface.config["components"]
|
||||
if c["props"].get("variant", "") == "stop"
|
||||
]
|
||||
)
|
||||
== 1
|
||||
)
|
||||
if use_generator and not live:
|
||||
assert has_stop
|
||||
else:
|
||||
assert not has_stop
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,9 +1,11 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gradio.queue import Event, Queue
|
||||
from gradio.utils import Request
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -161,8 +163,13 @@ class TestQueueEstimation:
|
||||
|
||||
|
||||
class TestQueueProcessEvents:
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event(self, queue: Queue, mock_event: Event):
|
||||
@patch("gradio.queue.Request", new_callable=AsyncMock)
|
||||
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
|
||||
queue.gather_event_data = AsyncMock()
|
||||
queue.gather_event_data.return_value = True
|
||||
queue.send_message = AsyncMock()
|
||||
@ -178,6 +185,14 @@ class TestQueueProcessEvents:
|
||||
queue.call_prediction.assert_called_once()
|
||||
mock_event.disconnect.assert_called_once()
|
||||
queue.clean_event.assert_called_once()
|
||||
mock_request.assert_called_with(
|
||||
method=Request.Method.POST,
|
||||
url=f"{queue.server_path}reset",
|
||||
json={
|
||||
"session_hash": mock_event.session_hash,
|
||||
"fn_index": mock_event.fn_index,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_handles_error_when_gathering_data(
|
||||
@ -246,9 +261,14 @@ class TestQueueProcessEvents:
|
||||
mock_event.disconnect.assert_called_once()
|
||||
assert queue.clean_event.call_count >= 1
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@patch("gradio.queue.Request", new_callable=AsyncMock)
|
||||
async def test_process_event_handles_exception_during_disconnect(
|
||||
self, queue: Queue, mock_event: Event
|
||||
self, mock_request, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
queue.call_prediction = AsyncMock(
|
||||
@ -259,3 +279,11 @@ class TestQueueProcessEvents:
|
||||
queue.clean_event = AsyncMock()
|
||||
mock_event.data = None
|
||||
await queue.process_event(mock_event)
|
||||
mock_request.assert_called_with(
|
||||
method=Request.Method.POST,
|
||||
url=f"{queue.server_path}reset",
|
||||
json={
|
||||
"session_hash": mock_event.session_hash,
|
||||
"fn_index": mock_event.fn_index,
|
||||
},
|
||||
)
|
||||
|
@ -236,6 +236,8 @@ async def test_queue_join_routes_sets_url_if_none_set(mock_get_url):
|
||||
msg = json.loads(await ws.recv())
|
||||
if msg["msg"] == "send_data":
|
||||
await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
|
||||
if msg["msg"] == "send_hash":
|
||||
await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
|
||||
completed = msg["msg"] == "process_completed"
|
||||
assert io._queue.server_path == "foo_url"
|
||||
|
||||
|
@ -222,6 +222,7 @@
|
||||
queue,
|
||||
backend_fn,
|
||||
frontend_fn,
|
||||
cancels,
|
||||
...rest
|
||||
},
|
||||
i
|
||||
@ -250,7 +251,8 @@
|
||||
},
|
||||
queue: queue === null ? enable_queue : queue,
|
||||
queue_callback: handle_update,
|
||||
loading_status: loading_status
|
||||
loading_status: loading_status,
|
||||
cancels
|
||||
});
|
||||
|
||||
function handle_update(output: any) {
|
||||
@ -303,7 +305,8 @@
|
||||
output_data: outputs.map((id) => instance_map[id].props.value),
|
||||
queue: queue === null ? enable_queue : queue,
|
||||
queue_callback: handle_update,
|
||||
loading_status: loading_status
|
||||
loading_status: loading_status,
|
||||
cancels
|
||||
});
|
||||
|
||||
if (!(queue === null ? enable_queue : queue)) {
|
||||
|
@ -77,7 +77,8 @@ export const fn =
|
||||
frontend_fn,
|
||||
output_data,
|
||||
queue_callback,
|
||||
loading_status
|
||||
loading_status,
|
||||
cancels
|
||||
}: {
|
||||
action: string;
|
||||
payload: Payload;
|
||||
@ -87,6 +88,7 @@ export const fn =
|
||||
output_data?: Output["data"];
|
||||
queue_callback: Function;
|
||||
loading_status: LoadingStatusType;
|
||||
cancels: Array<number>;
|
||||
}): Promise<unknown> => {
|
||||
const fn_index = payload.fn_index;
|
||||
|
||||
@ -153,6 +155,14 @@ export const fn =
|
||||
case "send_data":
|
||||
send_message(fn_index, payload);
|
||||
break;
|
||||
case "send_hash":
|
||||
ws_map.get(fn_index)?.send(
|
||||
JSON.stringify({
|
||||
session_hash: session_hash,
|
||||
fn_index: fn_index
|
||||
})
|
||||
);
|
||||
break;
|
||||
case "queue_full":
|
||||
loading_status.update(
|
||||
fn_index,
|
||||
@ -243,6 +253,21 @@ export const fn =
|
||||
output.average_duration as number,
|
||||
null
|
||||
);
|
||||
// Cancelled jobs are set to complete
|
||||
if (cancels.length > 0) {
|
||||
cancels.forEach((fn_index) => {
|
||||
loading_status.update(
|
||||
fn_index,
|
||||
"complete",
|
||||
queue,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
ws_map.get(fn_index)?.close();
|
||||
});
|
||||
}
|
||||
} else {
|
||||
loading_status.update(
|
||||
fn_index,
|
||||
|
@ -30,6 +30,7 @@ export interface Dependency {
|
||||
queue: boolean | null;
|
||||
api_name: string | null;
|
||||
documentation?: Array<Array<Array<string>>>;
|
||||
cancels: Array<number>;
|
||||
}
|
||||
|
||||
export interface LayoutNode {
|
||||
|
@ -5,7 +5,7 @@
|
||||
export let style: Styles = {};
|
||||
export let elem_id: string = "";
|
||||
export let visible: boolean = true;
|
||||
export let variant: "primary" | "secondary" = "secondary";
|
||||
export let variant: "primary" | "secondary" | "stop" = "secondary";
|
||||
export let size: "sm" | "lg" = "lg";
|
||||
|
||||
$: ({ classes } = get_styles(style, ["full_width"]));
|
||||
|
@ -86,6 +86,11 @@
|
||||
dark:from-gray-600 dark:to-gray-700 dark:hover:to-gray-600 dark:text-white dark:border-gray-600;
|
||||
}
|
||||
|
||||
.gr-button-stop {
|
||||
@apply from-red-200/70 to-red-300/80 hover:to-red-200/90 text-red-600 border-red-200
|
||||
dark:from-red-700 dark:to-red-700 dark:hover:to-red-500 dark:text-white dark:border-red-600;
|
||||
}
|
||||
|
||||
.gr-button-sm {
|
||||
@apply px-3 py-1 text-sm rounded-md;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user