fix cancels (#4225)

* fix cancels

* changelog

* refactor to make it work
This commit is contained in:
pngwn 2023-05-16 22:37:28 +01:00 committed by GitHub
parent a26e9afde3
commit c3dadaebbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 57 deletions

View File

@ -14,6 +14,7 @@
- Gradio will no longer send any analytics or call home if analytics are disabled with the GRADIO_ANALYTICS_ENABLED environment variable. By [@akx](https://github.com/akx) in [PR 4194](https://github.com/gradio-app/gradio/pull/4194) and [PR 4236](https://github.com/gradio-app/gradio/pull/4236)
- The deprecation warnings for kwargs now show the actual stack level for the invocation, by [@akx](https://github.com/akx) in [PR 4203](https://github.com/gradio-app/gradio/pull/4203).
- Fix "TypeError: issubclass() arg 1 must be a class" When use Optional[Types] by [@lingfengchencn](https://github.com/lingfengchencn) in [PR 4200](https://github.com/gradio-app/gradio/pull/4200).
- Ensure cancelling functions work correctly by [@pngwn](https://github.com/pngwn) in [PR 4225](https://github.com/gradio-app/gradio/pull/4225)
- Fixes a bug with typing.get_type_hints() on Python 3.9 by [@abidlabs](https://github.com/abidlabs) in [PR 4228](https://github.com/gradio-app/gradio/pull/4228).
## Other Changes:

View File

@ -51,7 +51,7 @@ type client_return = {
type SubmitReturn = {
on: event;
off: event;
cancel: () => void;
cancel: () => Promise<void>;
destroy: () => void;
};
@ -230,8 +230,6 @@ export async function client(
let config: Config;
let api_map: Record<string, number> = {};
const listener_map: ListenerMap<EventType> = {};
let jwt: false | string = false;
if (hf_token && space_id) {
@ -335,6 +333,7 @@ export async function client(
): SubmitReturn {
let fn_index: number;
let api_info;
if (typeof endpoint === "number") {
fn_index = endpoint;
api_info = api.unnamed_endpoints[fn_index];
@ -355,6 +354,8 @@ export async function client(
const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
let payload: Payload;
let complete: false | Record<string, any> = false;
const listener_map: ListenerMap<EventType> = {};
//@ts-ignore
handle_blob(
@ -394,6 +395,14 @@ export async function client(
)
: output.data;
if (status_code == 200) {
fire_event({
type: "data",
endpoint: _endpoint,
fn_index,
data: output.data,
time: new Date()
});
fire_event({
type: "status",
endpoint: _endpoint,
@ -403,14 +412,6 @@ export async function client(
queue: false,
time: new Date()
});
fire_event({
type: "data",
endpoint: _endpoint,
fn_index,
data,
time: new Date()
});
} else {
fire_event({
type: "status",
@ -474,7 +475,7 @@ export async function client(
last_status[fn_index]
);
if (type === "update" && status) {
if (type === "update" && status && !complete) {
// call 'status' listeners
fire_event({
type: "status",
@ -492,16 +493,7 @@ export async function client(
} else if (type === "data") {
websocket.send(JSON.stringify({ ...payload, session_hash }));
} else if (type === "complete") {
fire_event({
type: "status",
time: new Date(),
...status,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
websocket.close();
complete = status;
} else if (type === "generating") {
fire_event({
type: "status",
@ -528,6 +520,19 @@ export async function client(
endpoint: _endpoint,
fn_index
});
if (complete) {
fire_event({
type: "status",
time: new Date(),
...complete,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
websocket.close();
}
}
};
@ -572,26 +577,19 @@ export async function client(
}
async function cancel() {
fire_event({
type: "status",
endpoint: _endpoint,
fn_index: fn_index,
const _status: Status = {
stage: "complete",
queue: false,
time: new Date()
};
complete = _status;
fire_event({
..._status,
type: "status",
endpoint: _endpoint,
fn_index: fn_index
});
try {
await fetch(`${http_protocol}//${host + config.path}/reset`, {
method: "POST",
body: JSON.stringify(session_hash)
});
} catch (e) {
console.warn(
"The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
);
}
if (websocket && websocket.readyState === 0) {
websocket.addEventListener("open", () => {
websocket.close();
@ -600,7 +598,17 @@ export async function client(
websocket.close();
}
destroy();
try {
await fetch(`${http_protocol}//${host + config.path}/reset`, {
headers: { "Content-Type": "application/json" },
method: "POST",
body: JSON.stringify({ fn_index, session_hash })
});
} catch (e) {
console.warn(
"The `/reset` endpoint could not be called. Subsequent endpoint results may be unreliable."
);
}
}
function destroy() {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 762 KiB

After

Width:  |  Height:  |  Size: 1.7 MiB

View File

@ -245,21 +245,27 @@
}
let handled_dependencies: Array<number[]> = [];
const trigger_api_call = async (dep_index: number, event_data: unknown) => {
const trigger_api_call = async (
dep_index: number,
event_data: unknown = null
) => {
let dep = dependencies[dep_index];
const current_status = loading_status.get_status_for_fn(dep_index);
if (current_status === "pending" || current_status === "generating") {
return;
}
if (dep.cancels) {
await Promise.all(
dep.cancels.map((fn_index) => {
submit_map.get(fn_index)?.cancel();
dep.cancels.map(async (fn_index) => {
const submission = submit_map.get(fn_index);
submission?.cancel();
return submission;
})
);
}
if (current_status === "pending" || current_status === "generating") {
return;
}
let payload = {
fn_index: dep_index,
data: dep.inputs.map((id) => instance_map[id].props.value),
@ -292,15 +298,6 @@
.submit(payload.fn_index, payload.data as unknown[], payload.event_data)
.on("data", ({ data, fn_index }) => {
handle_update(data, fn_index);
let status = loading_status.get_status_for_fn(fn_index);
if (status === "complete") {
dependencies.forEach((dep, i) => {
if (dep.trigger_after === fn_index) {
trigger_api_call(i, null);
}
});
}
})
.on("status", ({ fn_index, ...status }) => {
loading_status.update({
@ -310,16 +307,27 @@
fn_index
});
if (status.stage === "complete") {
dependencies.map(async (dep, i) => {
if (dep.trigger_after === fn_index) {
trigger_api_call(i);
}
});
submission.destroy();
}
if (status.stage === "error") {
// handle failed .then here, since "data" listener won't trigger
dependencies.forEach((dep, i) => {
dependencies.map(async (dep, i) => {
if (
dep.trigger_after === fn_index &&
!dep.trigger_only_on_success
) {
trigger_api_call(i, null);
trigger_api_call(i);
}
});
submission.destroy();
}
});
@ -353,7 +361,7 @@
outputs.every((v) => instance_map?.[v].instance) &&
inputs.every((v) => instance_map?.[v].instance)
) {
trigger_api_call(i, null);
trigger_api_call(i);
handled_dependencies[i] = [-1];
}