Refactor Cancelling Logic To Use /cancel (#8370)

* Cancel refactor

* add changeset

* add changeset

* types

* Add code

* Fix types

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-06-05 15:32:24 -04:00 committed by GitHub
parent 96d8de2312
commit 48eeea4eaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 93 additions and 85 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/app": patch
"@gradio/client": patch
"gradio": patch
---
feat:Refactor Cancelling Logic To Use /cancel

View File

@ -106,7 +106,7 @@ export function transform_api_info(
dependencyIndex !== -1
? config.dependencies.find((dep) => dep.id == dependencyIndex)
?.types
: { continuous: false, generator: false };
: { continuous: false, generator: false, cancel: false };
if (
dependencyIndex !== -1 &&

View File

@ -46,7 +46,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
component: "Textbox"
}
],
type: { continuous: false, generator: false }
type: { continuous: false, generator: false, cancel: false }
}
},
unnamed_endpoints: {
@ -68,7 +68,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
component: "Textbox"
}
],
type: { continuous: false, generator: false }
type: { continuous: false, generator: false, cancel: false }
}
}
};
@ -395,7 +395,8 @@ export const config_response: Config = {
cancels: [],
types: {
continuous: false,
generator: false
generator: false,
cancel: false
},
collects_event_data: false,
trigger_after: null,
@ -421,7 +422,8 @@ export const config_response: Config = {
cancels: [],
types: {
continuous: false,
generator: false
generator: false,
cancel: false
},
collects_event_data: false,
trigger_after: null,
@ -447,7 +449,8 @@ export const config_response: Config = {
cancels: [],
types: {
continuous: false,
generator: false
generator: false,
cancel: false
},
collects_event_data: false,
trigger_after: null,

View File

@ -235,6 +235,7 @@ export interface Dependency {
export interface DependencyTypes {
continuous: boolean;
generator: boolean;
cancel: boolean;
}
export interface Payload {

View File

@ -122,6 +122,7 @@ export function submit(
fn_index: fn_index
});
let reset_request = {};
let cancel_request = {};
if (protocol === "ws") {
if (websocket && websocket.readyState === 0) {
@ -131,10 +132,11 @@ export function submit(
} else {
websocket.close();
}
cancel_request = { fn_index, session_hash };
reset_request = { fn_index, session_hash };
} else {
stream?.close();
cancel_request = { event_id };
reset_request = { event_id };
cancel_request = { event_id, session_hash, fn_index };
}
try {
@ -142,10 +144,18 @@ export function submit(
throw new Error("Could not resolve app config");
}
if ("event_id" in cancel_request) {
await fetch(`${config.root}/cancel`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify(cancel_request)
});
}
await fetch(`${config.root}/reset`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify(cancel_request)
body: JSON.stringify(reset_request)
});
} catch (e) {
console.warn(

View File

@ -73,7 +73,7 @@ from gradio.utils import (
TupleNoPrint,
check_function_inputs_match,
component_or_layout_class,
get_cancel_function,
get_cancelled_fn_indices,
get_continuous_fn,
get_package_version,
get_upload_folder,
@ -541,12 +541,7 @@ class BlockFunction:
self.rendered_in = rendered_in
# We need to keep track of which events are cancel events
# in two places:
# 1. So that we can skip postprocessing for cancel events.
# They return event_ids that have been cancelled but there
# are no output components
# 2. So that we can place the ProcessCompletedMessage in the
# event stream so that clients can close the stream when necessary
# so that the client can call the /cancel route directly
self.is_cancel_function = is_cancel_function
self.spaces_auto_wrap()
@ -589,6 +584,7 @@ class BlockFunction:
"types": {
"continuous": self.types_continuous,
"generator": self.types_generator,
"cancel": self.is_cancel_function,
},
"collects_event_data": self.collects_event_data,
"trigger_after": self.trigger_after,
@ -1377,7 +1373,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
updated_cancels = [
root_context.fns[i].get_config() for i in dependency.cancels
]
dependency.fn = get_cancel_function(updated_cancels)[0]
dependency.cancels = get_cancelled_fn_indices(updated_cancels)
root_context.fns[root_context.fn_id] = dependency
root_context.fn_id += 1
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
@ -1694,17 +1690,9 @@ Received outputs:
block_fn: BlockFunction,
predictions: list | dict,
state: SessionState | None,
) -> Any:
) -> list[Any]:
state = state or SessionState(self)
# If the function is a cancel function, 'predictions' are the ids of
# the event in the queue that has been cancelled. We need these
# so that the server can put the ProcessCompleted message in the event stream
# Cancel events have no output components, so we need to return early otherise the output
# be None.
if block_fn.is_cancel_function:
return predictions
if isinstance(predictions, dict) and len(predictions) > 0:
predictions = convert_component_dict_to_list(
[block._id for block in block_fn.outputs], predictions
@ -1920,7 +1908,7 @@ Received outputs:
for o in zip(*preds)
]
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
data = processing_utils.add_root_url(data, root_path, None) # type: ignore
data = list(zip(*data))
is_generating, iterator = None, None
else:

View File

@ -26,7 +26,7 @@ if TYPE_CHECKING:
from gradio.blocks import Block, Component
from gradio.context import get_blocks_context
from gradio.utils import get_cancel_function
from gradio.utils import get_cancelled_fn_indices
def set_cancel_events(
@ -36,7 +36,7 @@ def set_cancel_events(
if cancels:
if not isinstance(cancels, list):
cancels = [cancels]
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
fn_indices_to_cancel = get_cancelled_fn_indices(cancels)
root_block = get_blocks_context()
if root_block is None:
@ -44,7 +44,7 @@ def set_cancel_events(
root_block.set_event_trigger(
triggers,
cancel_fn,
fn=None,
inputs=None,
outputs=None,
queue=False,

View File

@ -624,13 +624,8 @@ class App(FastAPI):
@app.post("/reset/")
@app.post("/reset")
async def reset_iterator(body: ResetBody):
if body.event_id not in app.iterators:
return {"success": False}
async with app.lock:
del app.iterators[body.event_id]
app.iterators_to_reset.add(body.event_id)
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
async def reset_iterator(body: ResetBody): # noqa: ARG001
# No-op, all the cancelling/reset logic handled by /cancel
return {"success": True}
@app.get("/heartbeat/{session_hash}")
@ -739,18 +734,6 @@ class App(FastAPI):
fn=fn,
root_path=root_path,
)
if fn.is_cancel_function:
# Need to complete the job so that the client disconnects
blocks = app.get_blocks()
if body.session_hash in blocks._queue.pending_messages_per_session:
for event_id in output["data"]:
message = ProcessCompletedMessage(
output={}, success=True, event_id=event_id
)
blocks._queue.pending_messages_per_session[ # type: ignore
body.session_hash
].put_nowait(message)
except BaseException as error:
show_error = app.get_blocks().show_error or isinstance(error, Error)
traceback.print_exc()
@ -823,13 +806,24 @@ class App(FastAPI):
await cancel_tasks({f"{body.session_hash}_{body.fn_index}"})
blocks = app.get_blocks()
# Need to complete the job so that the client disconnects
if body.session_hash in blocks._queue.pending_messages_per_session:
session_open = (
body.session_hash in blocks._queue.pending_messages_per_session
)
event_running = (
body.event_id
in blocks._queue.pending_event_ids_session.get(body.session_hash, {})
)
if session_open and event_running:
message = ProcessCompletedMessage(
output={}, success=True, event_id=body.event_id
)
blocks._queue.pending_messages_per_session[
body.session_hash
].put_nowait(message)
if body.event_id in app.iterators:
async with app.lock:
del app.iterators[body.event_id]
app.iterators_to_reset.add(body.event_id)
return {"success": True}
@app.get("/call/{api_name}/{event_id}", dependencies=[Depends(login_check)])

View File

@ -872,7 +872,7 @@ def get_function_with_locals(
async def cancel_tasks(task_ids: set[str]) -> list[str]:
tasks = [(task, task.get_name()) for task in asyncio.all_tasks()]
event_ids = []
event_ids: list[str] = []
matching_tasks = []
for task, name in tasks:
if "<gradio-sep>" not in name:
@ -891,27 +891,19 @@ def set_task_name(task, session_hash: str, fn_index: int, event_id: str, batch:
task.set_name(f"{session_hash}_{fn_index}<gradio-sep>{event_id}")
def get_cancel_function(
def get_cancelled_fn_indices(
dependencies: list[dict[str, Any]],
) -> tuple[Callable, list[int]]:
fn_to_comp = {}
) -> list[int]:
fn_indices = []
for dep in dependencies:
root_block = get_blocks_context()
if root_block:
fn_index = next(
i for i, d in root_block.fns.items() if d.get_config() == dep
)
fn_to_comp[fn_index] = [root_block.blocks[o] for o in dep["outputs"]]
fn_indices.append(fn_index)
async def cancel(session_hash: str) -> list[str]:
task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
event_ids = await cancel_tasks(task_ids)
return event_ids
return (
cancel,
list(fn_to_comp.keys()),
)
return fn_indices
def get_type_hints(fn):

View File

@ -207,15 +207,6 @@
const current_status = loading_status.get_status_for_fn(dep_index);
messages = messages.filter(({ fn_index }) => fn_index !== dep_index);
if (dep.cancels) {
await Promise.all(
dep.cancels.map(async (fn_index) => {
const submission = submit_map.get(fn_index);
submission?.cancel();
return submission;
})
);
}
if (current_status === "pending" || current_status === "generating") {
dep.pending_request = true;
}
@ -242,6 +233,14 @@
handle_update(v, dep_index);
}
});
} else if (dep.types.cancel && dep.cancels) {
await Promise.all(
dep.cancels.map(async (fn_index) => {
const submission = submit_map.get(fn_index);
submission?.cancel();
return submission;
})
);
} else {
if (dep.backend_fn) {
if (dep.trigger_mode === "once") {

View File

@ -35,6 +35,7 @@ export interface ComponentMeta {
export interface DependencyTypes {
continuous: boolean;
generator: boolean;
cancel: boolean;
}
/** An event payload that is sent with an API request */

View File

@ -24,7 +24,7 @@ import gradio as gr
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import SelectData
from gradio.exceptions import DuplicateBlockError
from gradio.utils import assert_configs_are_equivalent_besides_ids
from gradio.utils import assert_configs_are_equivalent_besides_ids, cancel_tasks
pytest_plugins = ("pytest_asyncio",)
@ -335,13 +335,29 @@ class TestBlocksMethods:
for i, dependency in enumerate(demo.config["dependencies"]):
if i == 3:
assert dependency["types"] == {"continuous": True, "generator": True}
assert dependency["types"] == {
"continuous": True,
"generator": True,
"cancel": False,
}
if i == 0:
assert dependency["types"] == {"continuous": False, "generator": False}
assert dependency["types"] == {
"continuous": False,
"generator": False,
"cancel": False,
}
if i == 1:
assert dependency["types"] == {"continuous": False, "generator": True}
assert dependency["types"] == {
"continuous": False,
"generator": True,
"cancel": False,
}
if i == 2:
assert dependency["types"] == {"continuous": True, "generator": True}
assert dependency["types"] == {
"continuous": True,
"generator": True,
"cancel": False,
}
@patch(
"gradio.themes.ThemeClass.from_hub",
@ -1265,18 +1281,17 @@ class TestCancel:
await asyncio.sleep(10)
print("HELLO FROM LONG JOB")
with gr.Blocks() as demo:
with gr.Blocks():
button = gr.Button(value="Start")
click = button.click(long_job, None, None)
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
cancel_fun = demo.fns[demo.default_config.fn_id - 1].fn
task = asyncio.create_task(long_job())
task.set_name("foo_0<gradio-sep>event")
# If cancel_fun didn't cancel long_job the message would be printed to the console
# The test would also take 10 seconds
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
await asyncio.gather(task, cancel_tasks({"foo_0"}), return_exceptions=True)
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out
@ -1296,17 +1311,15 @@ class TestCancel:
cancel = gr.Button(value="Cancel")
cancel.click(None, None, None, cancels=[click])
with gr.Blocks() as demo:
with gr.Blocks():
with gr.Tab("Demo 1"):
demo1.render()
with gr.Tab("Demo 2"):
demo2.render()
cancel_fun = demo.fns[demo.default_config.fn_id - 1].fn
task = asyncio.create_task(long_job())
task.set_name("foo_1<gradio-sep>event")
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
await asyncio.gather(task, cancel_tasks({"foo_1"}), return_exceptions=True)
captured = capsys.readouterr()
assert "HELLO FROM LONG JOB" not in captured.out