mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
Properly dequeue cancelled events when multiple apps are rendered (#2540)
* Fix render + tests * Add comment + changelog
This commit is contained in:
parent
933feb48ad
commit
099e1e84ec
@ -59,6 +59,7 @@ inference time of 80 seconds).
|
||||
* Fixes a bug with `cancels` in event triggers so that it works properly if multiple
|
||||
Blocks are rendered by [@abidlabs](https://github.com/abidlabs) in [PR 2530](https://github.com/gradio-app/gradio/pull/2530)
|
||||
* Prevent invalid targets of events from crashing the whole application. [@pngwn](https://github.com/pngwn) in [PR 2534](https://github.com/gradio-app/gradio/pull/2534)
|
||||
* Properly dequeue cancelled events when multiple apps are rendered by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2540](https://github.com/gradio-app/gradio/pull/2540)
|
||||
|
||||
## Documentation Changes:
|
||||
* Added an example interactive dashboard to the "Tabular & Plots" section of the Demos page by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2508](https://github.com/gradio-app/gradio/pull/2508)
|
||||
|
@ -45,7 +45,7 @@ from gradio.documentation import (
|
||||
set_documentation_group,
|
||||
)
|
||||
from gradio.exceptions import DuplicateBlockError
|
||||
from gradio.utils import component_or_layout_class, delete_none
|
||||
from gradio.utils import component_or_layout_class, delete_none, get_cancel_function
|
||||
|
||||
set_documentation_group("blocks")
|
||||
|
||||
@ -622,7 +622,7 @@ class Blocks(BlockContext):
|
||||
Context.root_block.blocks.update(self.blocks)
|
||||
Context.root_block.fns.extend(self.fns)
|
||||
dependency_offset = len(Context.root_block.dependencies)
|
||||
for dependency in self.dependencies:
|
||||
for i, dependency in enumerate(self.dependencies):
|
||||
api_name = dependency["api_name"]
|
||||
if api_name is not None:
|
||||
api_name_ = utils.append_unique_suffix(
|
||||
@ -639,6 +639,18 @@ class Blocks(BlockContext):
|
||||
dependency["cancels"] = [
|
||||
c + dependency_offset for c in dependency["cancels"]
|
||||
]
|
||||
# Recreate the cancel function so that it has the latest
|
||||
# dependency fn indices. This is necessary to properly cancel
|
||||
# events in the backend
|
||||
if dependency["cancels"]:
|
||||
updated_cancels = [
|
||||
Context.root_block.dependencies[i]
|
||||
for i in dependency["cancels"]
|
||||
]
|
||||
new_fn = BlockFunction(
|
||||
get_cancel_function(updated_cancels)[0], False, True
|
||||
)
|
||||
Context.root_block.fns[dependency_offset + i] = new_fn
|
||||
Context.root_block.dependencies.append(dependency)
|
||||
Context.root_block.temp_dirs = Context.root_block.temp_dirs | self.temp_dirs
|
||||
|
||||
@ -650,7 +662,6 @@ class Blocks(BlockContext):
|
||||
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
block_fn = self.fns[fn_index]
|
||||
|
||||
if inspect.isasyncgenfunction(block_fn.fn):
|
||||
return False
|
||||
|
@ -1,45 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional
|
||||
|
||||
from gradio.blocks import Block, Context, update
|
||||
from gradio.blocks import Block
|
||||
from gradio.utils import get_cancel_function
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio.components import Component, StatusTracker
|
||||
|
||||
|
||||
def get_cancel_function(
|
||||
dependencies: List[Dict[str, Any]]
|
||||
) -> Tuple[Callable, List[int]]:
|
||||
fn_to_comp = {}
|
||||
for dep in dependencies:
|
||||
fn_index = next(
|
||||
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
|
||||
)
|
||||
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
|
||||
|
||||
async def cancel(session_hash: str) -> None:
|
||||
if sys.version_info < (3, 8):
|
||||
return None
|
||||
|
||||
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
|
||||
|
||||
matching_tasks = [
|
||||
task for task in asyncio.all_tasks() if task.get_name() in task_ids
|
||||
]
|
||||
for task in matching_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*matching_tasks, return_exceptions=True)
|
||||
|
||||
return (
|
||||
cancel,
|
||||
list(fn_to_comp.keys()),
|
||||
)
|
||||
|
||||
|
||||
def set_cancel_events(
|
||||
block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
|
||||
):
|
||||
|
@ -10,6 +10,7 @@ import json.decoder
|
||||
import os
|
||||
import pkgutil
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import StrictVersion
|
||||
@ -35,6 +36,7 @@ import requests
|
||||
from pydantic import BaseModel, Json, parse_obj_as
|
||||
|
||||
import gradio
|
||||
from gradio.context import Context
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio.blocks import BlockContext
|
||||
@ -685,3 +687,32 @@ def validate_url(possible_url: str) -> bool:
|
||||
|
||||
def is_update(val):
|
||||
return type(val) is dict and "update" in val.get("__type__", "")
|
||||
|
||||
|
||||
def get_cancel_function(
|
||||
dependencies: List[Dict[str, Any]]
|
||||
) -> Tuple[Callable, List[int]]:
|
||||
fn_to_comp = {}
|
||||
for dep in dependencies:
|
||||
fn_index = next(
|
||||
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
|
||||
)
|
||||
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
|
||||
|
||||
async def cancel(session_hash: str) -> None:
|
||||
if sys.version_info < (3, 8):
|
||||
return None
|
||||
|
||||
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
|
||||
|
||||
matching_tasks = [
|
||||
task for task in asyncio.all_tasks() if task.get_name() in task_ids
|
||||
]
|
||||
for task in matching_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*matching_tasks, return_exceptions=True)
|
||||
|
||||
return (
|
||||
cancel,
|
||||
list(fn_to_comp.keys()),
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ import wandb
|
||||
|
||||
import gradio as gr
|
||||
import gradio.events
|
||||
import gradio.utils
|
||||
from gradio.exceptions import DuplicateBlockError
|
||||
from gradio.routes import PredictBody
|
||||
from gradio.test_data.blocks_configs import XRAY_CONFIG
|
||||
@ -745,13 +746,13 @@ class TestCancel:
|
||||
await asyncio.sleep(10)
|
||||
print("HELLO FROM LONG JOB")
|
||||
|
||||
with gr.Blocks():
|
||||
with gr.Blocks() as demo:
|
||||
button = gr.Button(value="Start")
|
||||
click = button.click(long_job, None, None)
|
||||
cancel = gr.Button(value="Cancel")
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click])
|
||||
|
||||
cancel_fun = demo.fns[-1].fn
|
||||
task = asyncio.create_task(long_job())
|
||||
task.set_name("foo_0")
|
||||
# If cancel_fun didn't cancel long_job the message would be printed to the console
|
||||
@ -789,7 +790,7 @@ class TestCancel:
|
||||
cancel_fun = demo.fns[-1].fn
|
||||
|
||||
task = asyncio.create_task(long_job())
|
||||
task.set_name("foo_0")
|
||||
task.set_name("foo_1")
|
||||
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "HELLO FROM LONG JOB" not in captured.out
|
||||
|
Loading…
Reference in New Issue
Block a user