Properly dequeue cancelled events when multiple apps are rendered (#2540)

* Fix render + tests

* Add comment + changelog
This commit is contained in:
Freddy Boulton 2022-10-26 16:29:47 -04:00 committed by GitHub
parent 933feb48ad
commit 099e1e84ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 39 deletions

View File

@ -59,6 +59,7 @@ inference time of 80 seconds).
* Fixes a bug with `cancels` in event triggers so that it works properly if multiple
Blocks are rendered by [@abidlabs](https://github.com/abidlabs) in [PR 2530](https://github.com/gradio-app/gradio/pull/2530)
* Prevent invalid targets of events from crashing the whole application. [@pngwn](https://github.com/pngwn) in [PR 2534](https://github.com/gradio-app/gradio/pull/2534)
* Properly dequeue cancelled events when multiple apps are rendered by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2540](https://github.com/gradio-app/gradio/pull/2540)
## Documentation Changes:
* Added an example interactive dashboard to the "Tabular & Plots" section of the Demos page by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2508](https://github.com/gradio-app/gradio/pull/2508)

View File

@ -45,7 +45,7 @@ from gradio.documentation import (
set_documentation_group,
)
from gradio.exceptions import DuplicateBlockError
from gradio.utils import component_or_layout_class, delete_none
from gradio.utils import component_or_layout_class, delete_none, get_cancel_function
set_documentation_group("blocks")
@ -622,7 +622,7 @@ class Blocks(BlockContext):
Context.root_block.blocks.update(self.blocks)
Context.root_block.fns.extend(self.fns)
dependency_offset = len(Context.root_block.dependencies)
for dependency in self.dependencies:
for i, dependency in enumerate(self.dependencies):
api_name = dependency["api_name"]
if api_name is not None:
api_name_ = utils.append_unique_suffix(
@ -639,6 +639,18 @@ class Blocks(BlockContext):
dependency["cancels"] = [
c + dependency_offset for c in dependency["cancels"]
]
# Recreate the cancel function so that it has the latest
# dependency fn indices. This is necessary to properly cancel
# events in the backend
if dependency["cancels"]:
updated_cancels = [
Context.root_block.dependencies[i]
for i in dependency["cancels"]
]
new_fn = BlockFunction(
get_cancel_function(updated_cancels)[0], False, True
)
Context.root_block.fns[dependency_offset + i] = new_fn
Context.root_block.dependencies.append(dependency)
Context.root_block.temp_dirs = Context.root_block.temp_dirs | self.temp_dirs
@ -650,7 +662,6 @@ class Blocks(BlockContext):
"""Checks if a particular Blocks function is callable (i.e. not stateful or a generator)."""
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
block_fn = self.fns[fn_index]
if inspect.isasyncgenfunction(block_fn.fn):
return False

View File

@ -1,45 +1,15 @@
from __future__ import annotations
import asyncio
import sys
import warnings
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional
from gradio.blocks import Block, Context, update
from gradio.blocks import Block
from gradio.utils import get_cancel_function
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.components import Component, StatusTracker
def get_cancel_function(
dependencies: List[Dict[str, Any]]
) -> Tuple[Callable, List[int]]:
fn_to_comp = {}
for dep in dependencies:
fn_index = next(
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
)
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
async def cancel(session_hash: str) -> None:
if sys.version_info < (3, 8):
return None
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
matching_tasks = [
task for task in asyncio.all_tasks() if task.get_name() in task_ids
]
for task in matching_tasks:
task.cancel()
await asyncio.gather(*matching_tasks, return_exceptions=True)
return (
cancel,
list(fn_to_comp.keys()),
)
def set_cancel_events(
block: Block, event_name: str, cancels: None | Dict[str, Any] | List[Dict[str, Any]]
):

View File

@ -10,6 +10,7 @@ import json.decoder
import os
import pkgutil
import random
import sys
import warnings
from contextlib import contextmanager
from distutils.version import StrictVersion
@ -35,6 +36,7 @@ import requests
from pydantic import BaseModel, Json, parse_obj_as
import gradio
from gradio.context import Context
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.blocks import BlockContext
@ -685,3 +687,32 @@ def validate_url(possible_url: str) -> bool:
def is_update(val):
return type(val) is dict and "update" in val.get("__type__", "")
def get_cancel_function(
dependencies: List[Dict[str, Any]]
) -> Tuple[Callable, List[int]]:
fn_to_comp = {}
for dep in dependencies:
fn_index = next(
i for i, d in enumerate(Context.root_block.dependencies) if d == dep
)
fn_to_comp[fn_index] = [Context.root_block.blocks[o] for o in dep["outputs"]]
async def cancel(session_hash: str) -> None:
if sys.version_info < (3, 8):
return None
task_ids = set([f"{session_hash}_{fn}" for fn in fn_to_comp])
matching_tasks = [
task for task in asyncio.all_tasks() if task.get_name() in task_ids
]
for task in matching_tasks:
task.cancel()
await asyncio.gather(*matching_tasks, return_exceptions=True)
return (
cancel,
list(fn_to_comp.keys()),
)

View File

@ -17,6 +17,7 @@ import wandb
import gradio as gr
import gradio.events
import gradio.utils
from gradio.exceptions import DuplicateBlockError
from gradio.routes import PredictBody
from gradio.test_data.blocks_configs import XRAY_CONFIG
@ -745,13 +746,13 @@ class TestCancel:
await asyncio.sleep(10)
print("HELLO FROM LONG JOB")
with gr.Blocks():
with gr.Blocks() as demo:
button = gr.Button(value="Start")
click = button.click(long_job, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
cancel_fun, _ = gradio.events.get_cancel_function(dependencies=[click])
cancel_fun = demo.fns[-1].fn
task = asyncio.create_task(long_job())
task.set_name("foo_0")
# If cancel_fun didn't cancel long_job the message would be printed to the console
@ -789,7 +790,7 @@ class TestCancel:
cancel_fun = demo.fns[-1].fn
task = asyncio.create_task(long_job())
task.set_name("foo_0")
task.set_name("foo_1")
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out