mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Refactor Cancelling Logic To Use /cancel (#8370)
* Cancel refactor * add changeset * add changeset * types * Add code * Fix types --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
96d8de2312
commit
48eeea4eaa
7
.changeset/deep-weeks-show.md
Normal file
7
.changeset/deep-weeks-show.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/app": patch
|
||||
"@gradio/client": patch
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
feat:Refactor Cancelling Logic To Use /cancel
|
@ -106,7 +106,7 @@ export function transform_api_info(
|
||||
dependencyIndex !== -1
|
||||
? config.dependencies.find((dep) => dep.id == dependencyIndex)
|
||||
?.types
|
||||
: { continuous: false, generator: false };
|
||||
: { continuous: false, generator: false, cancel: false };
|
||||
|
||||
if (
|
||||
dependencyIndex !== -1 &&
|
||||
|
@ -46,7 +46,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
|
||||
component: "Textbox"
|
||||
}
|
||||
],
|
||||
type: { continuous: false, generator: false }
|
||||
type: { continuous: false, generator: false, cancel: false }
|
||||
}
|
||||
},
|
||||
unnamed_endpoints: {
|
||||
@ -68,7 +68,7 @@ export const transformed_api_info: ApiInfo<ApiData> = {
|
||||
component: "Textbox"
|
||||
}
|
||||
],
|
||||
type: { continuous: false, generator: false }
|
||||
type: { continuous: false, generator: false, cancel: false }
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -395,7 +395,8 @@ export const config_response: Config = {
|
||||
cancels: [],
|
||||
types: {
|
||||
continuous: false,
|
||||
generator: false
|
||||
generator: false,
|
||||
cancel: false
|
||||
},
|
||||
collects_event_data: false,
|
||||
trigger_after: null,
|
||||
@ -421,7 +422,8 @@ export const config_response: Config = {
|
||||
cancels: [],
|
||||
types: {
|
||||
continuous: false,
|
||||
generator: false
|
||||
generator: false,
|
||||
cancel: false
|
||||
},
|
||||
collects_event_data: false,
|
||||
trigger_after: null,
|
||||
@ -447,7 +449,8 @@ export const config_response: Config = {
|
||||
cancels: [],
|
||||
types: {
|
||||
continuous: false,
|
||||
generator: false
|
||||
generator: false,
|
||||
cancel: false
|
||||
},
|
||||
collects_event_data: false,
|
||||
trigger_after: null,
|
||||
|
@ -235,6 +235,7 @@ export interface Dependency {
|
||||
export interface DependencyTypes {
|
||||
continuous: boolean;
|
||||
generator: boolean;
|
||||
cancel: boolean;
|
||||
}
|
||||
|
||||
export interface Payload {
|
||||
|
@ -122,6 +122,7 @@ export function submit(
|
||||
fn_index: fn_index
|
||||
});
|
||||
|
||||
let reset_request = {};
|
||||
let cancel_request = {};
|
||||
if (protocol === "ws") {
|
||||
if (websocket && websocket.readyState === 0) {
|
||||
@ -131,10 +132,11 @@ export function submit(
|
||||
} else {
|
||||
websocket.close();
|
||||
}
|
||||
cancel_request = { fn_index, session_hash };
|
||||
reset_request = { fn_index, session_hash };
|
||||
} else {
|
||||
stream?.close();
|
||||
cancel_request = { event_id };
|
||||
reset_request = { event_id };
|
||||
cancel_request = { event_id, session_hash, fn_index };
|
||||
}
|
||||
|
||||
try {
|
||||
@ -142,10 +144,18 @@ export function submit(
|
||||
throw new Error("Could not resolve app config");
|
||||
}
|
||||
|
||||
if ("event_id" in cancel_request) {
|
||||
await fetch(`${config.root}/cancel`, {
|
||||
headers: { "Content-Type": "application/json" },
|
||||
method: "POST",
|
||||
body: JSON.stringify(cancel_request)
|
||||
});
|
||||
}
|
||||
|
||||
await fetch(`${config.root}/reset`, {
|
||||
headers: { "Content-Type": "application/json" },
|
||||
method: "POST",
|
||||
body: JSON.stringify(cancel_request)
|
||||
body: JSON.stringify(reset_request)
|
||||
});
|
||||
} catch (e) {
|
||||
console.warn(
|
||||
|
@ -73,7 +73,7 @@ from gradio.utils import (
|
||||
TupleNoPrint,
|
||||
check_function_inputs_match,
|
||||
component_or_layout_class,
|
||||
get_cancel_function,
|
||||
get_cancelled_fn_indices,
|
||||
get_continuous_fn,
|
||||
get_package_version,
|
||||
get_upload_folder,
|
||||
@ -541,12 +541,7 @@ class BlockFunction:
|
||||
self.rendered_in = rendered_in
|
||||
|
||||
# We need to keep track of which events are cancel events
|
||||
# in two places:
|
||||
# 1. So that we can skip postprocessing for cancel events.
|
||||
# They return event_ids that have been cancelled but there
|
||||
# are no output components
|
||||
# 2. So that we can place the ProcessCompletedMessage in the
|
||||
# event stream so that clients can close the stream when necessary
|
||||
# so that the client can call the /cancel route directly
|
||||
self.is_cancel_function = is_cancel_function
|
||||
|
||||
self.spaces_auto_wrap()
|
||||
@ -589,6 +584,7 @@ class BlockFunction:
|
||||
"types": {
|
||||
"continuous": self.types_continuous,
|
||||
"generator": self.types_generator,
|
||||
"cancel": self.is_cancel_function,
|
||||
},
|
||||
"collects_event_data": self.collects_event_data,
|
||||
"trigger_after": self.trigger_after,
|
||||
@ -1377,7 +1373,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
updated_cancels = [
|
||||
root_context.fns[i].get_config() for i in dependency.cancels
|
||||
]
|
||||
dependency.fn = get_cancel_function(updated_cancels)[0]
|
||||
dependency.cancels = get_cancelled_fn_indices(updated_cancels)
|
||||
root_context.fns[root_context.fn_id] = dependency
|
||||
root_context.fn_id += 1
|
||||
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
|
||||
@ -1694,17 +1690,9 @@ Received outputs:
|
||||
block_fn: BlockFunction,
|
||||
predictions: list | dict,
|
||||
state: SessionState | None,
|
||||
) -> Any:
|
||||
) -> list[Any]:
|
||||
state = state or SessionState(self)
|
||||
|
||||
# If the function is a cancel function, 'predictions' are the ids of
|
||||
# the event in the queue that has been cancelled. We need these
|
||||
# so that the server can put the ProcessCompleted message in the event stream
|
||||
# Cancel events have no output components, so we need to return early otherise the output
|
||||
# be None.
|
||||
if block_fn.is_cancel_function:
|
||||
return predictions
|
||||
|
||||
if isinstance(predictions, dict) and len(predictions) > 0:
|
||||
predictions = convert_component_dict_to_list(
|
||||
[block._id for block in block_fn.outputs], predictions
|
||||
@ -1920,7 +1908,7 @@ Received outputs:
|
||||
for o in zip(*preds)
|
||||
]
|
||||
if root_path is not None:
|
||||
data = processing_utils.add_root_url(data, root_path, None)
|
||||
data = processing_utils.add_root_url(data, root_path, None) # type: ignore
|
||||
data = list(zip(*data))
|
||||
is_generating, iterator = None, None
|
||||
else:
|
||||
|
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from gradio.blocks import Block, Component
|
||||
|
||||
from gradio.context import get_blocks_context
|
||||
from gradio.utils import get_cancel_function
|
||||
from gradio.utils import get_cancelled_fn_indices
|
||||
|
||||
|
||||
def set_cancel_events(
|
||||
@ -36,7 +36,7 @@ def set_cancel_events(
|
||||
if cancels:
|
||||
if not isinstance(cancels, list):
|
||||
cancels = [cancels]
|
||||
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
|
||||
fn_indices_to_cancel = get_cancelled_fn_indices(cancels)
|
||||
|
||||
root_block = get_blocks_context()
|
||||
if root_block is None:
|
||||
@ -44,7 +44,7 @@ def set_cancel_events(
|
||||
|
||||
root_block.set_event_trigger(
|
||||
triggers,
|
||||
cancel_fn,
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
queue=False,
|
||||
|
@ -624,13 +624,8 @@ class App(FastAPI):
|
||||
|
||||
@app.post("/reset/")
|
||||
@app.post("/reset")
|
||||
async def reset_iterator(body: ResetBody):
|
||||
if body.event_id not in app.iterators:
|
||||
return {"success": False}
|
||||
async with app.lock:
|
||||
del app.iterators[body.event_id]
|
||||
app.iterators_to_reset.add(body.event_id)
|
||||
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
|
||||
async def reset_iterator(body: ResetBody): # noqa: ARG001
|
||||
# No-op, all the cancelling/reset logic handled by /cancel
|
||||
return {"success": True}
|
||||
|
||||
@app.get("/heartbeat/{session_hash}")
|
||||
@ -739,18 +734,6 @@ class App(FastAPI):
|
||||
fn=fn,
|
||||
root_path=root_path,
|
||||
)
|
||||
if fn.is_cancel_function:
|
||||
# Need to complete the job so that the client disconnects
|
||||
blocks = app.get_blocks()
|
||||
if body.session_hash in blocks._queue.pending_messages_per_session:
|
||||
for event_id in output["data"]:
|
||||
message = ProcessCompletedMessage(
|
||||
output={}, success=True, event_id=event_id
|
||||
)
|
||||
blocks._queue.pending_messages_per_session[ # type: ignore
|
||||
body.session_hash
|
||||
].put_nowait(message)
|
||||
|
||||
except BaseException as error:
|
||||
show_error = app.get_blocks().show_error or isinstance(error, Error)
|
||||
traceback.print_exc()
|
||||
@ -823,13 +806,24 @@ class App(FastAPI):
|
||||
await cancel_tasks({f"{body.session_hash}_{body.fn_index}"})
|
||||
blocks = app.get_blocks()
|
||||
# Need to complete the job so that the client disconnects
|
||||
if body.session_hash in blocks._queue.pending_messages_per_session:
|
||||
session_open = (
|
||||
body.session_hash in blocks._queue.pending_messages_per_session
|
||||
)
|
||||
event_running = (
|
||||
body.event_id
|
||||
in blocks._queue.pending_event_ids_session.get(body.session_hash, {})
|
||||
)
|
||||
if session_open and event_running:
|
||||
message = ProcessCompletedMessage(
|
||||
output={}, success=True, event_id=body.event_id
|
||||
)
|
||||
blocks._queue.pending_messages_per_session[
|
||||
body.session_hash
|
||||
].put_nowait(message)
|
||||
if body.event_id in app.iterators:
|
||||
async with app.lock:
|
||||
del app.iterators[body.event_id]
|
||||
app.iterators_to_reset.add(body.event_id)
|
||||
return {"success": True}
|
||||
|
||||
@app.get("/call/{api_name}/{event_id}", dependencies=[Depends(login_check)])
|
||||
|
@ -872,7 +872,7 @@ def get_function_with_locals(
|
||||
|
||||
async def cancel_tasks(task_ids: set[str]) -> list[str]:
|
||||
tasks = [(task, task.get_name()) for task in asyncio.all_tasks()]
|
||||
event_ids = []
|
||||
event_ids: list[str] = []
|
||||
matching_tasks = []
|
||||
for task, name in tasks:
|
||||
if "<gradio-sep>" not in name:
|
||||
@ -891,27 +891,19 @@ def set_task_name(task, session_hash: str, fn_index: int, event_id: str, batch:
|
||||
task.set_name(f"{session_hash}_{fn_index}<gradio-sep>{event_id}")
|
||||
|
||||
|
||||
def get_cancel_function(
|
||||
def get_cancelled_fn_indices(
|
||||
dependencies: list[dict[str, Any]],
|
||||
) -> tuple[Callable, list[int]]:
|
||||
fn_to_comp = {}
|
||||
) -> list[int]:
|
||||
fn_indices = []
|
||||
for dep in dependencies:
|
||||
root_block = get_blocks_context()
|
||||
if root_block:
|
||||
fn_index = next(
|
||||
i for i, d in root_block.fns.items() if d.get_config() == dep
|
||||
)
|
||||
fn_to_comp[fn_index] = [root_block.blocks[o] for o in dep["outputs"]]
|
||||
fn_indices.append(fn_index)
|
||||
|
||||
async def cancel(session_hash: str) -> list[str]:
|
||||
task_ids = {f"{session_hash}_{fn}" for fn in fn_to_comp}
|
||||
event_ids = await cancel_tasks(task_ids)
|
||||
return event_ids
|
||||
|
||||
return (
|
||||
cancel,
|
||||
list(fn_to_comp.keys()),
|
||||
)
|
||||
return fn_indices
|
||||
|
||||
|
||||
def get_type_hints(fn):
|
||||
|
@ -207,15 +207,6 @@
|
||||
|
||||
const current_status = loading_status.get_status_for_fn(dep_index);
|
||||
messages = messages.filter(({ fn_index }) => fn_index !== dep_index);
|
||||
if (dep.cancels) {
|
||||
await Promise.all(
|
||||
dep.cancels.map(async (fn_index) => {
|
||||
const submission = submit_map.get(fn_index);
|
||||
submission?.cancel();
|
||||
return submission;
|
||||
})
|
||||
);
|
||||
}
|
||||
if (current_status === "pending" || current_status === "generating") {
|
||||
dep.pending_request = true;
|
||||
}
|
||||
@ -242,6 +233,14 @@
|
||||
handle_update(v, dep_index);
|
||||
}
|
||||
});
|
||||
} else if (dep.types.cancel && dep.cancels) {
|
||||
await Promise.all(
|
||||
dep.cancels.map(async (fn_index) => {
|
||||
const submission = submit_map.get(fn_index);
|
||||
submission?.cancel();
|
||||
return submission;
|
||||
})
|
||||
);
|
||||
} else {
|
||||
if (dep.backend_fn) {
|
||||
if (dep.trigger_mode === "once") {
|
||||
|
@ -35,6 +35,7 @@ export interface ComponentMeta {
|
||||
export interface DependencyTypes {
|
||||
continuous: boolean;
|
||||
generator: boolean;
|
||||
cancel: boolean;
|
||||
}
|
||||
|
||||
/** An event payload that is sent with an API request */
|
||||
|
@ -24,7 +24,7 @@ import gradio as gr
|
||||
from gradio.data_classes import GradioModel, GradioRootModel
|
||||
from gradio.events import SelectData
|
||||
from gradio.exceptions import DuplicateBlockError
|
||||
from gradio.utils import assert_configs_are_equivalent_besides_ids
|
||||
from gradio.utils import assert_configs_are_equivalent_besides_ids, cancel_tasks
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
@ -335,13 +335,29 @@ class TestBlocksMethods:
|
||||
|
||||
for i, dependency in enumerate(demo.config["dependencies"]):
|
||||
if i == 3:
|
||||
assert dependency["types"] == {"continuous": True, "generator": True}
|
||||
assert dependency["types"] == {
|
||||
"continuous": True,
|
||||
"generator": True,
|
||||
"cancel": False,
|
||||
}
|
||||
if i == 0:
|
||||
assert dependency["types"] == {"continuous": False, "generator": False}
|
||||
assert dependency["types"] == {
|
||||
"continuous": False,
|
||||
"generator": False,
|
||||
"cancel": False,
|
||||
}
|
||||
if i == 1:
|
||||
assert dependency["types"] == {"continuous": False, "generator": True}
|
||||
assert dependency["types"] == {
|
||||
"continuous": False,
|
||||
"generator": True,
|
||||
"cancel": False,
|
||||
}
|
||||
if i == 2:
|
||||
assert dependency["types"] == {"continuous": True, "generator": True}
|
||||
assert dependency["types"] == {
|
||||
"continuous": True,
|
||||
"generator": True,
|
||||
"cancel": False,
|
||||
}
|
||||
|
||||
@patch(
|
||||
"gradio.themes.ThemeClass.from_hub",
|
||||
@ -1265,18 +1281,17 @@ class TestCancel:
|
||||
await asyncio.sleep(10)
|
||||
print("HELLO FROM LONG JOB")
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Blocks():
|
||||
button = gr.Button(value="Start")
|
||||
click = button.click(long_job, None, None)
|
||||
cancel = gr.Button(value="Cancel")
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
|
||||
cancel_fun = demo.fns[demo.default_config.fn_id - 1].fn
|
||||
task = asyncio.create_task(long_job())
|
||||
task.set_name("foo_0<gradio-sep>event")
|
||||
# If cancel_fun didn't cancel long_job the message would be printed to the console
|
||||
# The test would also take 10 seconds
|
||||
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
|
||||
await asyncio.gather(task, cancel_tasks({"foo_0"}), return_exceptions=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "HELLO FROM LONG JOB" not in captured.out
|
||||
|
||||
@ -1296,17 +1311,15 @@ class TestCancel:
|
||||
cancel = gr.Button(value="Cancel")
|
||||
cancel.click(None, None, None, cancels=[click])
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Blocks():
|
||||
with gr.Tab("Demo 1"):
|
||||
demo1.render()
|
||||
with gr.Tab("Demo 2"):
|
||||
demo2.render()
|
||||
|
||||
cancel_fun = demo.fns[demo.default_config.fn_id - 1].fn
|
||||
|
||||
task = asyncio.create_task(long_job())
|
||||
task.set_name("foo_1<gradio-sep>event")
|
||||
await asyncio.gather(task, cancel_fun("foo"), return_exceptions=True)
|
||||
await asyncio.gather(task, cancel_tasks({"foo_1"}), return_exceptions=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "HELLO FROM LONG JOB" not in captured.out
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user