Queue maximum length (#2036)

* changes

* format

* changes

* changes
This commit is contained in:
aliabid94 2022-08-18 15:29:51 -07:00 committed by GitHub
parent 0474e460ad
commit 029637cef9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 106 additions and 47 deletions

View File

@ -28,4 +28,4 @@ demo = gr.Interface(
description="Here's a sample toy calculator. Enjoy!",
)
if __name__ == "__main__":
demo.launch()
demo.launch(show_error=True)

View File

@ -8,7 +8,7 @@ import gradio as gr
def fake_gan(*args):
time.sleep(8)
time.sleep(15)
image = random.choice(
[
"https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
@ -36,7 +36,7 @@ demo = gr.Interface(
[os.path.join(os.path.dirname(__file__), "files/zebra.jpg")],
],
)
demo.queue()
demo.queue(max_size=3)
if __name__ == "__main__":
demo.launch()

View File

@ -668,9 +668,8 @@ class Blocks(BlockContext):
"css": self.css,
"title": self.title or "Gradio",
"is_space": self.is_space,
"enable_queue": getattr(
self, "enable_queue", False
), # attribute set at launch
"enable_queue": getattr(self, "enable_queue", False), # launch attributes
"show_error": getattr(self, "show_error", False),
}
for _id, block in self.blocks.items():
config["components"].append(
@ -776,6 +775,7 @@ class Blocks(BlockContext):
status_update_rate: float | str = "auto",
client_position_to_load_data: int = 30,
default_enabled: bool = True,
max_size: Optional[int] = None,
):
"""
You can control the rate of processed requests by creating a queue. This will allow you to set the number of requests to be processed at one time, and will let users know their position in the queue.
@ -795,6 +795,7 @@ class Blocks(BlockContext):
concurrency_count=concurrency_count,
data_gathering_start=client_position_to_load_data,
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
max_size=max_size,
)
return self

View File

@ -38,6 +38,7 @@ class Queue:
QUEUE_DURATION = 1
LIVE_UPDATES = True
SLEEP_WHEN_FREE = 0.001
MAX_SIZE = None
@classmethod
def configure_queue(
@ -46,6 +47,7 @@ class Queue:
concurrency_count: int,
data_gathering_start: int,
update_intervals: int,
max_size: Optional[int],
):
"""
See Blocks.queue() docstring for the explanation of parameters.
@ -55,6 +57,7 @@ class Queue:
cls.DATA_GATHERING_STARTS_AT = data_gathering_start
cls.UPDATE_INTERVALS = update_intervals
cls.ACTIVE_JOBS = [None] * cls.MAX_THREAD_COUNT
cls.MAX_SIZE = max_size
@classmethod
def set_url(cls, url: str):
@ -104,14 +107,16 @@ class Queue:
run_coro_in_background(cls.gather_data_and_broadcast_estimations)
@classmethod
def push(cls, event: Event) -> int:
def push(cls, event: Event) -> int | None:
"""
Add event to queue
Add event to queue, or return None if Queue is full
Parameters:
event: Event to add to Queue
Returns:
rank of submitted Event
"""
if cls.MAX_SIZE is not None and len(cls.EVENT_QUEUE) >= cls.MAX_SIZE:
return None
cls.EVENT_QUEUE.append(event)
return len(cls.EVENT_QUEUE) - 1

View File

@ -283,6 +283,10 @@ class App(FastAPI):
return
event.hash = e_hash["hash"]
rank = Queue.push(event)
if rank is None:
await event.send_message({"msg": "queue_full"})
await event.disconnect()
return
estimation = Queue.get_estimation()
await Queue.send_estimation(event, estimation, rank)
while True:

View File

@ -1 +1 @@
3.1.5
3.1.5b1

View File

@ -35,6 +35,7 @@
export let target: HTMLElement;
export let id: number = 0;
export let autoscroll: boolean = false;
export let show_error: boolean = false;
let app_mode = window.__gradio_mode__ === "app";
let loading_status = create_loading_status_store();
@ -280,10 +281,7 @@
}
if (!(queue === null ? enable_queue : queue)) {
req.then(handle_update).catch((error) => {
console.error(error);
loading_status.update(i, "error", queue || false, 0, 0, 0);
});
req.then(handle_update);
}
handled_dependencies[i] = [-1];
@ -313,10 +311,7 @@
});
if (!(queue === null ? enable_queue : queue)) {
req.then(handle_update).catch((error) => {
console.error(error);
loading_status.update(i, "error", queue || false, 0, 0, 0);
});
req.then(handle_update);
}
});

View File

@ -31,21 +31,24 @@ interface Payload {
declare let BUILD_MODE: string;
declare let BACKEND_URL: string;
async function post_data<
Return extends Record<string, unknown> = Record<string, unknown>
>(url: string, body: unknown): Promise<Return> {
interface PostResponse {
error?: string;
[x: string]: unknown;
}
const QUEUE_FULL_MSG = "This application is too busy! Try again soon.";
async function post_data(
url: string,
body: unknown
): Promise<[PostResponse, number]> {
const response = await fetch(url, {
method: "POST",
body: JSON.stringify(body),
headers: { "Content-Type": "application/json" }
});
if (response.status !== 200) {
throw new Error(response.statusText);
}
const output: Return = await response.json();
return output;
const output: PostResponse = await response.json();
return [output, response.status];
}
interface UpdateOutput {
__type__: string;
@ -61,7 +64,12 @@ type Output = {
const ws_map = new Map();
export const fn =
(session_hash: string, api_endpoint: string, is_space: boolean) =>
(
session_hash: string,
api_endpoint: string,
is_space: boolean,
show_error: boolean
) =>
async ({
action,
payload,
@ -97,17 +105,18 @@ export const fn =
queue,
null,
null,
null,
null
);
function send_message(fn: number, data: any) {
ws_map.get(fn).connection.send(JSON.stringify(data));
}
var ws_protocol = api_endpoint.startsWith("https") ? "wss:" : "ws:";
var ws_endpoint = api_endpoint === "api/" ? location.href : api_endpoint;
var ws_protocol = ws_endpoint.startsWith("https") ? "wss:" : "ws:";
if (is_space) {
const SPACE_REGEX = /embed\/(.*)\/\+/g;
var ws_path = Array.from(api_endpoint.matchAll(SPACE_REGEX))[0][1];
var ws_path = Array.from(ws_endpoint.matchAll(SPACE_REGEX))[0][1];
var ws_host = "spaces.huggingface.tech/";
} else {
var ws_path = location.pathname === "/" ? "" : location.pathname;
@ -142,6 +151,18 @@ export const fn =
case "send_data":
send_message(fn_index, payload);
break;
case "queue_full":
loading_status.update(
fn_index,
"error",
queue,
null,
null,
null,
QUEUE_FULL_MSG
);
websocket_data.connection.close();
break;
case "estimation":
loading_status.update(
fn_index,
@ -149,7 +170,8 @@ export const fn =
queue,
data.queue_size,
data.rank,
data.rank_eta
data.rank_eta,
null
);
break;
case "process_completed":
@ -159,7 +181,8 @@ export const fn =
queue,
null,
null,
data.output.average_duration
data.output.average_duration,
null
);
queue_callback(data.output);
websocket_data.connection.close();
@ -171,6 +194,7 @@ export const fn =
queue,
data.rank,
0,
null,
null
);
break;
@ -183,23 +207,35 @@ export const fn =
queue,
null,
null,
null,
null
);
const output = await post_data(api_endpoint + action + "/", {
var [output, status_code] = await post_data(api_endpoint + action + "/", {
...payload,
session_hash
});
if (status_code == 200) {
loading_status.update(
fn_index,
"complete",
queue,
null,
null,
output.average_duration as number
output.average_duration as number,
null
);
} else {
loading_status.update(
fn_index,
"error",
queue,
null,
null,
null,
show_error ? output.error : null
);
}
return output;
}
};

View File

@ -56,6 +56,7 @@
export let scroll_to_output: boolean = false;
export let timer: boolean = true;
export let visible: boolean = true;
export let message: string | null = null;
let el: HTMLDivElement;
@ -150,6 +151,9 @@
{/if}
{:else if status === "error"}
<span class="error">ERROR</span>
{#if message}
<span class="status-message dark:text-gray-100">{message}</span>
{/if}
{/if}
</div>
@ -177,4 +181,8 @@
.error {
@apply text-red-400 font-mono font-semibold text-lg;
}
.status-message {
@apply font-mono p-2 whitespace-pre;
}
</style>

View File

@ -34,6 +34,7 @@ interface Config {
title: string;
version: string;
is_space: boolean;
show_error: boolean;
// allow_flagging: string;
// allow_interpretation: boolean;
// article: string;
@ -183,7 +184,12 @@ function mount_app(
});
} else {
let session_hash = Math.random().toString(36).substring(2);
config.fn = fn(session_hash, config.root + "api/", config.is_space);
config.fn = fn(
session_hash,
config.root + "api/",
config.is_space,
config.show_error
);
new Blocks({
target: wrapper,

View File

@ -7,6 +7,7 @@ export interface LoadingStatus {
queue_position: number | null;
queue_size: number | null;
fn_index: number;
message?: string | null;
scroll_to_output?: boolean;
visible?: boolean;
}
@ -30,7 +31,8 @@ export function create_loading_status_store() {
queue: LoadingStatus["queue"],
size: LoadingStatus["queue_size"],
position: LoadingStatus["queue_position"],
eta: LoadingStatus["eta"]
eta: LoadingStatus["eta"],
message: LoadingStatus["message"]
) {
const outputs = fn_outputs[fn_index];
const inputs = fn_inputs[fn_index];
@ -66,7 +68,8 @@ export function create_loading_status_store() {
queue_position: position,
queue_size: size,
eta: eta,
status: new_status
status: new_status,
message: message
};
});
@ -88,12 +91,13 @@ export function create_loading_status_store() {
store.update((outputs) => {
outputs_to_update.forEach(
({ id, queue_position, queue_size, eta, status }) => {
({ id, queue_position, queue_size, eta, status, message }) => {
outputs[id] = {
queue: queue,
queue_size: queue_size,
queue_position: queue_position,
eta: eta,
message,
status,
fn_index
};