mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Fix api event drops (#6556)
* changes * changes * add changeset * changes * changes * changes * changs * chagnes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes~git push * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * change * changes * changes * changes * changes --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Ubuntu <ubuntu@ip-172-31-25-241.us-west-2.compute.internal> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
67ddd40b4b
commit
d76bcaaaf0
7
.changeset/ripe-spiders-love.md
Normal file
7
.changeset/ripe-spiders-love.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/client": patch
|
||||
"gradio": patch
|
||||
"gradio_client": patch
|
||||
---
|
||||
|
||||
fix:Fix api event drops
|
@ -278,6 +278,9 @@ export function api_factory(
|
||||
|
||||
const session_hash = Math.random().toString(36).substring(2);
|
||||
const last_status: Record<string, Status["stage"]> = {};
|
||||
let stream_open = false;
|
||||
let event_stream: EventSource | null = null;
|
||||
const event_callbacks: Record<string, () => Promise<void>> = {};
|
||||
let config: Config;
|
||||
let api_map: Record<string, number> = {};
|
||||
|
||||
@ -437,7 +440,7 @@ export function api_factory(
|
||||
|
||||
let websocket: WebSocket;
|
||||
let eventSource: EventSource;
|
||||
let protocol = config.protocol ?? "sse";
|
||||
let protocol = config.protocol ?? "ws";
|
||||
|
||||
const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
|
||||
let payload: Payload;
|
||||
@ -646,7 +649,7 @@ export function api_factory(
|
||||
websocket.send(JSON.stringify({ hash: session_hash }))
|
||||
);
|
||||
}
|
||||
} else {
|
||||
} else if (protocol == "sse") {
|
||||
fire_event({
|
||||
type: "status",
|
||||
stage: "pending",
|
||||
@ -766,6 +769,121 @@ export function api_factory(
|
||||
}
|
||||
}
|
||||
};
|
||||
} else if (protocol == "sse_v1") {
|
||||
fire_event({
|
||||
type: "status",
|
||||
stage: "pending",
|
||||
queue: true,
|
||||
endpoint: _endpoint,
|
||||
fn_index,
|
||||
time: new Date()
|
||||
});
|
||||
|
||||
post_data(
|
||||
`${http_protocol}//${resolve_root(
|
||||
host,
|
||||
config.path,
|
||||
true
|
||||
)}/queue/join?${url_params}`,
|
||||
{
|
||||
...payload,
|
||||
session_hash
|
||||
},
|
||||
hf_token
|
||||
).then(([response, status]) => {
|
||||
if (status !== 200) {
|
||||
fire_event({
|
||||
type: "status",
|
||||
stage: "error",
|
||||
message: BROKEN_CONNECTION_MSG,
|
||||
queue: true,
|
||||
endpoint: _endpoint,
|
||||
fn_index,
|
||||
time: new Date()
|
||||
});
|
||||
} else {
|
||||
event_id = response.event_id as string;
|
||||
if (!stream_open) {
|
||||
open_stream();
|
||||
}
|
||||
|
||||
// Per-event handler for messages arriving on the shared event stream.
// It decodes the raw message with `handle_message`, re-emits it to
// listeners through `fire_event` (status / log / data events), and, once the
// event reaches a terminal stage, unregisters itself from `event_callbacks`
// and closes the stream when no other event is still registered.
//
// Fixes relative to the previous version:
//  - an async function must be declared as returning `Promise<void>`;
//  - `status` is optional in the decoded message (the branches below already
//    guard it), so the terminal-stage check uses optional chaining instead of
//    throwing on messages that carry no status.
let callback = async function (_data: object): Promise<void> {
	const { type, status, data } = handle_message(
		_data,
		last_status[fn_index]
	);

	if (type === "update" && status && !complete) {
		// call 'status' listeners
		fire_event({
			type: "status",
			endpoint: _endpoint,
			fn_index,
			time: new Date(),
			...status
		});
	} else if (type === "complete") {
		// remember the completion status; it is re-emitted below together
		// with the final data payload
		complete = status;
	} else if (type === "log") {
		fire_event({
			type: "log",
			log: data.log,
			level: data.level,
			endpoint: _endpoint,
			fn_index
		});
	} else if (type === "generating") {
		fire_event({
			type: "status",
			time: new Date(),
			...status,
			stage: status?.stage!,
			queue: true,
			endpoint: _endpoint,
			fn_index
		});
	}

	if (data) {
		fire_event({
			type: "data",
			time: new Date(),
			data: transform_files
				? transform_output(
						data.data,
						api_info,
						config.root,
						config.root_url
				  )
				: data.data,
			endpoint: _endpoint,
			fn_index
		});

		if (complete) {
			fire_event({
				type: "status",
				time: new Date(),
				...complete,
				stage: status?.stage!,
				queue: true,
				endpoint: _endpoint,
				fn_index
			});
		}
	}

	// Terminal stage: drop this event's handler and close the stream once
	// nothing else is listening on it.
	if (status?.stage === "complete" || status?.stage === "error") {
		if (event_callbacks[event_id]) {
			delete event_callbacks[event_id];
			if (Object.keys(event_callbacks).length === 0) {
				close_stream();
			}
		}
	}
};
|
||||
event_callbacks[event_id] = callback;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
@ -864,6 +982,30 @@ export function api_factory(
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Opens the session-wide event stream (`/queue/data?session_hash=…`) and
 * routes every incoming message to the callback registered for its
 * `event_id` in `event_callbacks`.
 *
 * Messages that have no registered callback (no `event_id` on the message, or
 * an event whose callback has already been removed) are ignored rather than
 * invoked, which previously raised "… is not a function" inside the handler.
 */
function open_stream(): void {
	stream_open = true;

	const params = new URLSearchParams({ session_hash }).toString();
	const url = new URL(
		`${http_protocol}//${resolve_root(
			host,
			config.path,
			true
		)}/queue/data?${params}`
	);

	event_stream = new EventSource(url);
	event_stream.onmessage = async function (event) {
		const _data = JSON.parse(event.data);
		const callback = event_callbacks[_data.event_id];
		if (!callback) {
			// nothing is waiting for this message
			return;
		}
		await callback(_data);
	};
}
|
||||
|
||||
/**
 * Marks the shared event stream as closed and shuts down the underlying
 * `EventSource`, if there is one. The handle is cleared afterwards so a
 * closed stream is not kept around; `open_stream` creates a new one on demand.
 */
function close_stream(): void {
	stream_open = false;
	if (event_stream != null) {
		event_stream.close();
		event_stream = null;
	}
}
|
||||
|
||||
async function component_server(
|
||||
component_id: number,
|
||||
fn_name: string,
|
||||
|
@ -20,7 +20,7 @@ export interface Config {
|
||||
show_api: boolean;
|
||||
stylesheets: string[];
|
||||
path: string;
|
||||
protocol?: "sse" | "ws";
|
||||
protocol?: "sse_v1" | "sse" | "ws";
|
||||
}
|
||||
|
||||
export interface Payload {
|
||||
|
@ -36,6 +36,7 @@ from gradio_client.exceptions import SerializationSetupError
|
||||
from gradio_client.utils import (
|
||||
Communicator,
|
||||
JobStatus,
|
||||
Message,
|
||||
Status,
|
||||
StatusUpdate,
|
||||
)
|
||||
@ -124,25 +125,33 @@ class Client:
|
||||
if self.verbose:
|
||||
print(f"Loaded as API: {self.src} ✔")
|
||||
|
||||
if auth is not None:
|
||||
self._login(auth)
|
||||
|
||||
self.config = self._get_config()
|
||||
self.protocol: str = self.config.get("protocol", "ws")
|
||||
self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
|
||||
self.sse_url = urllib.parse.urljoin(self.src, utils.SSE_URL)
|
||||
self.sse_data_url = urllib.parse.urljoin(self.src, utils.SSE_DATA_URL)
|
||||
self.sse_url = urllib.parse.urljoin(
|
||||
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
|
||||
)
|
||||
self.sse_data_url = urllib.parse.urljoin(
|
||||
self.src,
|
||||
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
|
||||
)
|
||||
self.ws_url = urllib.parse.urljoin(
|
||||
self.src.replace("http", "ws", 1), utils.WS_URL
|
||||
)
|
||||
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
|
||||
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
|
||||
if auth is not None:
|
||||
self._login(auth)
|
||||
self.config = self._get_config()
|
||||
self.app_version = version.parse(self.config.get("version", "2.0"))
|
||||
self._info = self._get_api_info()
|
||||
self.session_hash = str(uuid.uuid4())
|
||||
|
||||
protocol = self.config.get("protocol")
|
||||
endpoint_class = Endpoint if protocol == "sse" else EndpointV3Compatibility
|
||||
endpoint_class = (
|
||||
Endpoint if self.protocol.startswith("sse") else EndpointV3Compatibility
|
||||
)
|
||||
self.endpoints = [
|
||||
endpoint_class(self, fn_index, dependency)
|
||||
endpoint_class(self, fn_index, dependency, self.protocol)
|
||||
for fn_index, dependency in enumerate(self.config["dependencies"])
|
||||
]
|
||||
|
||||
@ -152,6 +161,84 @@ class Client:
|
||||
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
|
||||
threading.Thread(target=self._telemetry_thread).start()
|
||||
|
||||
self.stream_open = False
|
||||
self.streaming_future: Future | None = None
|
||||
self.pending_messages_per_event: dict[str, list[Message | None]] = {}
|
||||
self.pending_event_ids: set[str] = set()
|
||||
|
||||
async def stream_messages(self) -> None:
    """Read this session's server-sent-event stream and file each message
    under its event id in ``self.pending_messages_per_event``.

    Heartbeats are skipped. A ``server_stopped`` message is appended to every
    pending per-event list and ends the stream. A ``process_completed``
    message removes its event id from ``self.pending_event_ids``; once that
    set is empty the stream is marked closed and the method returns.

    Any exception is printed with its traceback and re-raised.
    """
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
            buffer = ""
            async with client.stream(
                "GET",
                self.sse_url,
                params={"session_hash": self.session_hash},
                headers=self.headers,
                cookies=self.cookies,
            ) as response:
                async for chunk in response.aiter_text():
                    buffer += chunk
                    # messages are separated by a blank line; keep any
                    # trailing partial message in the buffer
                    while "\n\n" in buffer:
                        message, buffer = buffer.split("\n\n", 1)
                        if message == "":
                            continue
                        if not message.startswith("data:"):
                            raise ValueError(f"Unexpected SSE line: '{message}'")

                        resp = json.loads(message[5:])
                        kind = resp["msg"]
                        if kind == "heartbeat":
                            continue
                        if kind == "server_stopped":
                            for queued in self.pending_messages_per_event.values():
                                queued.append(resp)
                            return

                        event_id = resp["event_id"]
                        self.pending_messages_per_event.setdefault(
                            event_id, []
                        ).append(resp)
                        if kind == "process_completed":
                            self.pending_event_ids.remove(event_id)
                        # NOTE(review): nesting reconstructed from a flattened
                        # listing - confirm this check runs after every message
                        # and not only after a completion.
                        if len(self.pending_event_ids) == 0:
                            self.stream_open = False
                            return
    except BaseException as e:
        import traceback

        traceback.print_exc()
        raise e
|
||||
|
||||
async def send_data(self, data, hash_data):
    """POST the merged ``data`` and ``hash_data`` payload to
    ``self.sse_data_url`` and return the ``event_id`` from the JSON response.

    If no stream is marked open, the stream is marked open and, unless a
    previous streaming future is still running, ``stream_messages`` is
    submitted to the executor. When that future finishes, the stream is marked
    closed again and ``None`` is appended to every pending per-event list.

    Raises whatever ``raise_for_status`` raises for an unsuccessful response.
    """
    payload = {**data, **hash_data}
    async with httpx.AsyncClient() as client:
        req = await client.post(
            self.sse_data_url,
            json=payload,
            headers=self.headers,
            cookies=self.cookies,
        )
    req.raise_for_status()
    event_id = req.json()["event_id"]

    if self.stream_open:
        return event_id

    self.stream_open = True

    def open_stream():
        return utils.synchronize_async(self.stream_messages)

    def close_stream(_future):
        self.stream_open = False
        # a trailing None is appended to each per-event list when the stream ends
        for queued in self.pending_messages_per_event.values():
            queued.append(None)

    previous = self.streaming_future
    if previous is None or previous.done():
        self.streaming_future = self.executor.submit(open_stream)
        self.streaming_future.add_done_callback(close_stream)

    return event_id
|
||||
|
||||
@classmethod
|
||||
def duplicate(
|
||||
cls,
|
||||
@ -340,7 +427,7 @@ class Client:
|
||||
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
|
||||
|
||||
helper = None
|
||||
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"):
|
||||
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
|
||||
helper = self.new_helper(inferred_fn_index)
|
||||
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
|
||||
future = self.executor.submit(end_to_end_fn, *args)
|
||||
@ -806,7 +893,9 @@ class ReplaceMe:
|
||||
class Endpoint:
|
||||
"""Helper class for storing all the information about a single API endpoint."""
|
||||
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict):
|
||||
def __init__(
|
||||
self, client: Client, fn_index: int, dependency: dict, protocol: str = "sse_v1"
|
||||
):
|
||||
self.client: Client = client
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
@ -814,7 +903,7 @@ class Endpoint:
|
||||
self.api_name: str | Literal[False] | None = (
|
||||
"/" + api_name if isinstance(api_name, str) else api_name
|
||||
)
|
||||
self.protocol = "sse"
|
||||
self.protocol = protocol
|
||||
self.input_component_types = [
|
||||
self._get_component_type(id_) for id_ in dependency["inputs"]
|
||||
]
|
||||
@ -891,7 +980,20 @@ class Endpoint:
|
||||
"session_hash": self.client.session_hash,
|
||||
}
|
||||
|
||||
result = utils.synchronize_async(self._sse_fn, data, hash_data, helper)
|
||||
if self.protocol == "sse":
|
||||
result = utils.synchronize_async(
|
||||
self._sse_fn_v0, data, hash_data, helper
|
||||
)
|
||||
elif self.protocol == "sse_v1":
|
||||
event_id = utils.synchronize_async(
|
||||
self.client.send_data, data, hash_data
|
||||
)
|
||||
self.client.pending_event_ids.add(event_id)
|
||||
self.client.pending_messages_per_event[event_id] = []
|
||||
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported protocol: {self.protocol}")
|
||||
|
||||
if "error" in result:
|
||||
raise ValueError(result["error"])
|
||||
|
||||
@ -1068,24 +1170,33 @@ class Endpoint:
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
|
||||
async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator):
|
||||
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
|
||||
return await utils.get_pred_from_sse(
|
||||
return await utils.get_pred_from_sse_v0(
|
||||
client,
|
||||
data,
|
||||
hash_data,
|
||||
helper,
|
||||
sse_url=self.client.sse_url,
|
||||
sse_data_url=self.client.sse_data_url,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
self.client.sse_url,
|
||||
self.client.sse_data_url,
|
||||
self.client.headers,
|
||||
self.client.cookies,
|
||||
)
|
||||
|
||||
async def _sse_fn_v1(self, helper: Communicator, event_id: str):
    """Await the prediction for ``event_id`` via ``utils.get_pred_from_sse_v1``,
    passing along the owning client's headers, cookies and per-event message
    lists."""
    owner = self.client
    return await utils.get_pred_from_sse_v1(
        helper=helper,
        headers=owner.headers,
        cookies=owner.cookies,
        pending_messages_per_event=owner.pending_messages_per_event,
        event_id=event_id,
    )
|
||||
|
||||
|
||||
class EndpointV3Compatibility:
|
||||
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""
|
||||
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict):
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict, *args):
|
||||
self.client: Client = client
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
|
@ -17,7 +17,7 @@ from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, TypedDict
|
||||
|
||||
import fsspec.asyn
|
||||
import httpx
|
||||
@ -26,8 +26,10 @@ from huggingface_hub import SpaceStage
|
||||
from websockets.legacy.protocol import WebSocketCommonProtocol
|
||||
|
||||
API_URL = "api/predict/"
|
||||
SSE_URL = "queue/join"
|
||||
SSE_DATA_URL = "queue/data"
|
||||
SSE_URL_V0 = "queue/join"
|
||||
SSE_DATA_URL_V0 = "queue/data"
|
||||
SSE_URL = "queue/data"
|
||||
SSE_DATA_URL = "queue/join"
|
||||
WS_URL = "queue/join"
|
||||
UPLOAD_URL = "upload"
|
||||
LOGIN_URL = "login"
|
||||
@ -48,6 +50,19 @@ INVALID_RUNTIME = [
|
||||
]
|
||||
|
||||
|
||||
class Message(TypedDict, total=False):
    """Shape of one message on the event stream.

    Every key is optional (``total=False``).
    """

    # identification
    msg: str
    event_id: str

    # queue position information
    rank: int
    rank_eta: float
    queue_size: int

    # result and progress information
    output: dict[str, Any]
    success: bool
    progress_data: list[dict]

    # log messages
    log: str
    level: str
|
||||
|
||||
|
||||
def get_package_version() -> str:
|
||||
try:
|
||||
package_json_data = (
|
||||
@ -100,6 +115,7 @@ class Status(Enum):
|
||||
PROGRESS = "PROGRESS"
|
||||
FINISHED = "FINISHED"
|
||||
CANCELLED = "CANCELLED"
|
||||
LOG = "LOG"
|
||||
|
||||
@staticmethod
|
||||
def ordering(status: Status) -> int:
|
||||
@ -133,6 +149,7 @@ class Status(Enum):
|
||||
"process_generating": Status.ITERATING,
|
||||
"process_completed": Status.FINISHED,
|
||||
"progress": Status.PROGRESS,
|
||||
"log": Status.LOG,
|
||||
}[msg]
|
||||
|
||||
|
||||
@ -169,6 +186,7 @@ class StatusUpdate:
|
||||
success: bool | None
|
||||
time: datetime | None
|
||||
progress_data: list[ProgressUnit] | None
|
||||
log: tuple[str, str] | None = None
|
||||
|
||||
|
||||
def create_initial_status_update():
|
||||
@ -307,7 +325,7 @@ async def get_pred_from_ws(
|
||||
return resp["output"]
|
||||
|
||||
|
||||
async def get_pred_from_sse(
|
||||
async def get_pred_from_sse_v0(
|
||||
client: httpx.AsyncClient,
|
||||
data: dict,
|
||||
hash_data: dict,
|
||||
@ -315,21 +333,21 @@ async def get_pred_from_sse(
|
||||
sse_url: str,
|
||||
sse_data_url: str,
|
||||
headers: dict[str, str],
|
||||
cookies: dict[str, str] | None = None,
|
||||
cookies: dict[str, str] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
done, pending = await asyncio.wait(
|
||||
[
|
||||
asyncio.create_task(check_for_cancel(helper, cookies)),
|
||||
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
|
||||
asyncio.create_task(
|
||||
stream_sse(
|
||||
stream_sse_v0(
|
||||
client,
|
||||
data,
|
||||
hash_data,
|
||||
helper,
|
||||
sse_url,
|
||||
sse_data_url,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
headers,
|
||||
cookies,
|
||||
)
|
||||
),
|
||||
],
|
||||
@ -348,7 +366,42 @@ async def get_pred_from_sse(
|
||||
return task.result()
|
||||
|
||||
|
||||
async def check_for_cancel(helper: Communicator, cookies: dict[str, str] | None):
|
||||
async def get_pred_from_sse_v1(
    helper: Communicator,
    headers: dict[str, str],
    cookies: dict[str, str] | None,
    pending_messages_per_event: dict[str, list[Message | None]],
    event_id: str,
) -> dict[str, Any] | None:
    """Run the cancellation watcher and the message reader for ``event_id``
    concurrently and return the result of whichever finishes first.

    The task that is still pending is cancelled and awaited so that its
    cancellation has completed before this function returns.
    """
    cancel_watcher = asyncio.create_task(check_for_cancel(helper, headers, cookies))
    message_reader = asyncio.create_task(
        stream_sse_v1(helper, pending_messages_per_event, event_id)
    )
    done, pending = await asyncio.wait(
        [cancel_watcher, message_reader],
        return_when=asyncio.FIRST_COMPLETED,
    )

    for leftover in pending:
        leftover.cancel()
        try:
            await leftover
        except asyncio.CancelledError:
            pass

    assert len(done) == 1
    for finished in done:
        return finished.result()
|
||||
|
||||
|
||||
async def check_for_cancel(
|
||||
helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None
|
||||
):
|
||||
while True:
|
||||
await asyncio.sleep(0.05)
|
||||
with helper.lock:
|
||||
@ -357,12 +410,15 @@ async def check_for_cancel(helper: Communicator, cookies: dict[str, str] | None)
|
||||
if helper.event_id:
|
||||
async with httpx.AsyncClient() as http:
|
||||
await http.post(
|
||||
helper.reset_url, json={"event_id": helper.event_id}, cookies=cookies
|
||||
helper.reset_url,
|
||||
json={"event_id": helper.event_id},
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
)
|
||||
raise CancelledError()
|
||||
|
||||
|
||||
async def stream_sse(
|
||||
async def stream_sse_v0(
|
||||
client: httpx.AsyncClient,
|
||||
data: dict,
|
||||
hash_data: dict,
|
||||
@ -370,15 +426,15 @@ async def stream_sse(
|
||||
sse_url: str,
|
||||
sse_data_url: str,
|
||||
headers: dict[str, str],
|
||||
cookies: dict[str, str] | None = None,
|
||||
cookies: dict[str, str] | None,
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
sse_url,
|
||||
params=hash_data,
|
||||
cookies=cookies,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
if line.startswith("data:"):
|
||||
@ -413,8 +469,8 @@ async def stream_sse(
|
||||
req = await client.post(
|
||||
sse_data_url,
|
||||
json={"event_id": event_id, **data, **hash_data},
|
||||
cookies=cookies,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
)
|
||||
req.raise_for_status()
|
||||
elif resp["msg"] == "process_completed":
|
||||
@ -426,6 +482,64 @@ async def stream_sse(
|
||||
raise
|
||||
|
||||
|
||||
async def stream_sse_v1(
    helper: Communicator,
    pending_messages_per_event: dict[str, list[Message | None]],
    event_id: str,
) -> dict[str, Any]:
    """Consume the queued messages for ``event_id`` until the event finishes.

    Polls the event's message list every 0.05 s. For each message the job's
    latest status is updated under ``helper.lock`` and, for non-final messages
    that carry output data, the processed prediction is appended to the job's
    outputs.

    Returns the ``output`` of the ``process_completed`` message and removes
    the event's list from ``pending_messages_per_event``.

    Raises ``CancelledError`` on a ``None`` message, ``QueueError`` on
    ``queue_full`` and ``ValueError`` on ``server_stopped``.
    """
    pending_messages = pending_messages_per_event[event_id]

    while True:
        if not pending_messages:
            await asyncio.sleep(0.05)
            continue

        msg = pending_messages.pop(0)
        if msg is None:
            raise CancelledError()

        with helper.lock:
            log_message = None
            if msg["msg"] == "log":
                log = msg.get("log")
                level = msg.get("level")
                if log and level:
                    log_message = (log, level)

            progress = (
                ProgressUnit.from_msg(msg["progress_data"])
                if "progress_data" in msg
                else None
            )
            status_update = StatusUpdate(
                code=Status.msg_to_status(msg["msg"]),
                queue_size=msg.get("queue_size"),
                rank=msg.get("rank", None),
                success=msg.get("success"),
                time=datetime.now(),
                eta=msg.get("rank_eta"),
                progress_data=progress,
                log=log_message,
            )

            output = msg.get("output", {}).get("data", [])
            if output and status_update.code != Status.FINISHED:
                try:
                    result = helper.prediction_processor(*output)
                except Exception as e:
                    # the exception is stored as the result instead of raised
                    result = [e]
                helper.job.outputs.append(result)
            helper.job.latest_status = status_update

        kind = msg["msg"]
        if kind == "queue_full":
            raise QueueError("Queue is full! Please try again.")
        elif kind == "process_completed":
            del pending_messages_per_event[event_id]
            return msg["output"]
        elif kind == "server_stopped":
            raise ValueError("Server stopped.")
|
||||
|
||||
|
||||
########################
|
||||
# Data processing utils
|
||||
########################
|
||||
|
@ -95,6 +95,17 @@ class TestClientPredictions:
|
||||
output = client.predict("abc", api_name="/predict")
|
||||
assert output == "abc"
|
||||
|
||||
@pytest.mark.flaky
def test_private_space_v4_sse_v1(self):
    """The client can run a prediction against the private v4 ``sse_v1``
    test space."""
    space_id = "gradio-tests/not-actually-private-spacev4-sse-v1"
    space_info = huggingface_hub.HfApi().space_info(space_id)
    assert space_info.private

    client = Client(space_id)
    assert client.predict("abc", api_name="/predict") == "abc"
|
||||
|
||||
def test_state(self, increment_demo):
|
||||
with connect(increment_demo) as client:
|
||||
output = client.predict(api_name="/increment_without_queue")
|
||||
|
@ -893,6 +893,13 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
fn = get_continuous_fn(fn, every)
|
||||
elif every:
|
||||
raise ValueError("Cannot set a value for `every` without a `fn`.")
|
||||
if every and concurrency_limit is not None:
|
||||
if concurrency_limit == "default":
|
||||
concurrency_limit = None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot set a value for `concurrency_limit` with `every`."
|
||||
)
|
||||
|
||||
if _targets[0][1] == "change" and trigger_mode is None:
|
||||
trigger_mode = "always_last"
|
||||
@ -1581,7 +1588,7 @@ Received outputs:
|
||||
"is_colab": utils.colab_check(),
|
||||
"stylesheets": self.stylesheets,
|
||||
"theme": self.theme.name,
|
||||
"protocol": "sse",
|
||||
"protocol": "sse_v1",
|
||||
}
|
||||
|
||||
def get_layout(block):
|
||||
@ -2169,7 +2176,7 @@ Received outputs:
|
||||
try:
|
||||
if wasm_utils.IS_WASM:
|
||||
# NOTE:
|
||||
# Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/join` endpoint function)
|
||||
# Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/data` endpoint function)
|
||||
# are running in an event loop in the server thread,
|
||||
# so they will be cancelled by `self.server.close()` below.
|
||||
# However, in the Wasm env, we don't have the `server` and
|
||||
|
@ -23,7 +23,7 @@ from gradio.data_classes import (
|
||||
)
|
||||
from gradio.exceptions import Error
|
||||
from gradio.helpers import TrackedIterable
|
||||
from gradio.utils import run_coro_in_background, safe_get_lock, set_task_name
|
||||
from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import BlockFunction
|
||||
@ -37,7 +37,6 @@ class Event:
|
||||
request: fastapi.Request,
|
||||
username: str | None,
|
||||
):
|
||||
self.message_queue = ThreadQueue()
|
||||
self.session_hash = session_hash
|
||||
self.fn_index = fn_index
|
||||
self.request = request
|
||||
@ -48,28 +47,6 @@ class Event:
|
||||
self.progress_pending: bool = False
|
||||
self.alive = True
|
||||
|
||||
def send_message(
    self,
    message_type: str,
    data: dict | None = None,
    final: bool = False,
):
    """Put ``{"msg": message_type, **data}`` on this event's message queue.

    When ``final`` is true, a trailing ``None`` is queued after the message.
    """
    extras = {} if data is None else data
    outbox = self.message_queue
    outbox.put_nowait({"msg": message_type, **extras})
    if final:
        outbox.put_nowait(None)
|
||||
|
||||
async def get_data(self, timeout=5) -> bool:
    """Send a ``send_data`` message for this event, then poll every 0.05 s
    until ``self.data`` is set, ``timeout`` seconds have elapsed, or the
    event is no longer alive.

    Returns True if ``self.data`` is set when polling stops.
    """
    self.send_message("send_data", {"event_id": self._id})
    poll_interval = 0.05
    elapsed = 0
    while elapsed < timeout and self.alive and self.data is None:
        await asyncio.sleep(poll_interval)
        elapsed += poll_interval
    return self.data is not None
|
||||
|
||||
|
||||
class Queue:
|
||||
def __init__(
|
||||
@ -81,6 +58,9 @@ class Queue:
|
||||
block_fns: list[BlockFunction],
|
||||
default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
|
||||
):
|
||||
self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
|
||||
self.pending_event_ids_session: dict[str, set[str]] = {}
|
||||
self.pending_message_lock = safe_get_lock()
|
||||
self.event_queue: list[Event] = []
|
||||
self.awaiting_data_events: dict[str, Event] = {}
|
||||
self.stopped = False
|
||||
@ -132,6 +112,16 @@ class Queue:
|
||||
def close(self):
|
||||
self.stopped = True
|
||||
|
||||
def send_message(
    self,
    event: Event,
    message_type: str,
    data: dict | None = None,
):
    """Queue a message for ``event`` on its session's pending-message queue.

    The queued dict carries ``msg`` and ``event_id`` followed by the entries
    of ``data`` (if any).
    """
    session_queue = self.pending_messages_per_session[event.session_hash]
    extras = data if data is not None else {}
    session_queue.put_nowait(
        {"msg": message_type, "event_id": event._id, **extras}
    )
|
||||
|
||||
def _resolve_concurrency_limit(self, default_concurrency_limit):
|
||||
"""
|
||||
Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination
|
||||
@ -152,13 +142,33 @@ class Queue:
|
||||
else:
|
||||
return 1
|
||||
|
||||
def attach_data(self, body: PredictBody):
    """Store ``body`` as the data of the event that is waiting for it.

    Raises ValueError("Event not found", event_id) when no event with
    ``body.event_id`` is awaiting data.
    """
    target_id = body.event_id
    waiting = self.awaiting_data_events
    if target_id not in waiting:
        raise ValueError("Event not found", target_id)
    waiting[target_id].data = body
|
||||
async def push(
    self, body: PredictBody, request: fastapi.Request, username: str | None
):
    """Create an event for ``body``, append it to the event queue and return
    its id.

    Before queuing, the session's pending-message queue and pending-event-id
    set are created if missing. After queuing, an estimation is sent for the
    new event.

    Raises ValueError when ``body`` has no session hash or function index, or
    when the queue already holds ``max_size`` events.
    """
    session_hash = body.session_hash
    if session_hash is None:
        raise ValueError("No session hash provided.")
    if body.fn_index is None:
        raise ValueError("No function index provided.")

    queue_len = len(self.event_queue)
    if self.max_size is not None and queue_len >= self.max_size:
        raise ValueError(
            f"Queue is full. Max size is {self.max_size} and current size is {queue_len}."
        )

    event = Event(session_hash, body.fn_index, request, username)
    event.data = body

    async with self.pending_message_lock:
        if session_hash not in self.pending_messages_per_session:
            self.pending_messages_per_session[session_hash] = ThreadQueue()
        if session_hash not in self.pending_event_ids_session:
            self.pending_event_ids_session[session_hash] = set()
    self.pending_event_ids_session[session_hash].add(event._id)
    self.event_queue.append(event)

    # queue_len is the queue length measured before this event was appended
    estimation = self.get_estimation()
    await self.send_estimation(event, estimation, queue_len)

    return event._id
|
||||
|
||||
def _cancel_asyncio_tasks(self):
|
||||
for task in self._asyncio_tasks:
|
||||
@ -276,7 +286,7 @@ class Queue:
|
||||
for event in events:
|
||||
if event.progress_pending and event.progress:
|
||||
event.progress_pending = False
|
||||
event.send_message("progress", event.progress.model_dump())
|
||||
self.send_message(event, "progress", event.progress.model_dump())
|
||||
|
||||
await asyncio.sleep(self.progress_update_sleep_when_free)
|
||||
|
||||
@ -320,34 +330,23 @@ class Queue:
|
||||
log=log,
|
||||
level=level,
|
||||
)
|
||||
event.send_message("log", log_message.model_dump())
|
||||
self.send_message(event, "log", log_message.model_dump())
|
||||
|
||||
def push(self, event: Event) -> int | None:
    """
    Append an event to the queue unless the queue is full.
    Parameters:
        event: Event to add to Queue
    Returns:
        rank of submitted Event (the queue length before it was added), or None if the Queue is full
    """
    position = len(self.event_queue)
    has_room = self.max_size is None or position < self.max_size
    if not has_room:
        return None
    self.event_queue.append(event)
    return position
|
||||
async def clean_events(
|
||||
self, *, session_hash: str | None = None, event_id: str | None = None
|
||||
) -> None:
|
||||
for job_set in self.active_jobs:
|
||||
if job_set:
|
||||
for job in job_set:
|
||||
if job.session_hash == session_hash or job._id == event_id:
|
||||
job.alive = False
|
||||
|
||||
async def clean_event(self, event: Event | str) -> None:
|
||||
if isinstance(event, str):
|
||||
for job_set in self.active_jobs:
|
||||
if job_set:
|
||||
for job in job_set:
|
||||
if job._id == event:
|
||||
event = job
|
||||
break
|
||||
if isinstance(event, str):
|
||||
raise ValueError("Event not found", event)
|
||||
event.alive = False
|
||||
if event in self.event_queue:
|
||||
events_to_remove = []
|
||||
for event in self.event_queue:
|
||||
if event.session_hash == session_hash or event._id == event_id:
|
||||
events_to_remove.append(event)
|
||||
|
||||
for event in events_to_remove:
|
||||
async with self.delete_lock:
|
||||
self.event_queue.remove(event)
|
||||
|
||||
@ -391,7 +390,7 @@ class Queue:
|
||||
if None not in self.active_jobs:
|
||||
# Add estimated amount of time for a thread to get empty
|
||||
estimation.rank_eta += self.avg_concurrent_process_time
|
||||
event.send_message("estimation", estimation.model_dump())
|
||||
self.send_message(event, "estimation", estimation.model_dump())
|
||||
return estimation
|
||||
|
||||
def update_estimation(self, duration: float) -> None:
|
||||
@ -485,14 +484,7 @@ class Queue:
|
||||
awake_events: list[Event] = []
|
||||
try:
|
||||
for event in events:
|
||||
if not event.data:
|
||||
self.awaiting_data_events[event._id] = event
|
||||
client_awake = await event.get_data()
|
||||
del self.awaiting_data_events[event._id]
|
||||
if not client_awake:
|
||||
await self.clean_event(event)
|
||||
continue
|
||||
event.send_message("process_starts")
|
||||
self.send_message(event, "process_starts")
|
||||
awake_events.append(event)
|
||||
if not awake_events:
|
||||
return
|
||||
@ -505,7 +497,8 @@ class Queue:
|
||||
response = None
|
||||
err = e
|
||||
for event in awake_events:
|
||||
event.send_message(
|
||||
self.send_message(
|
||||
event,
|
||||
"process_completed",
|
||||
{
|
||||
"output": {
|
||||
@ -515,7 +508,6 @@ class Queue:
|
||||
},
|
||||
"success": False,
|
||||
},
|
||||
final=True,
|
||||
)
|
||||
if response and response.get("is_generating", False):
|
||||
old_response = response
|
||||
@ -524,7 +516,8 @@ class Queue:
|
||||
old_response = response
|
||||
old_err = err
|
||||
for event in awake_events:
|
||||
event.send_message(
|
||||
self.send_message(
|
||||
event,
|
||||
"process_generating",
|
||||
{
|
||||
"output": old_response,
|
||||
@ -545,7 +538,8 @@ class Queue:
|
||||
relevant_response = err
|
||||
else:
|
||||
relevant_response = old_response or old_err
|
||||
event.send_message(
|
||||
self.send_message(
|
||||
event,
|
||||
"process_completed",
|
||||
{
|
||||
"output": {"error": str(relevant_response)}
|
||||
@ -554,20 +548,19 @@ class Queue:
|
||||
"success": relevant_response
|
||||
and not isinstance(relevant_response, Exception),
|
||||
},
|
||||
final=True,
|
||||
)
|
||||
elif response:
|
||||
output = copy.deepcopy(response)
|
||||
for e, event in enumerate(awake_events):
|
||||
if batch and "data" in output:
|
||||
output["data"] = list(zip(*response.get("data")))[e]
|
||||
event.send_message(
|
||||
self.send_message(
|
||||
event,
|
||||
"process_completed",
|
||||
{
|
||||
"output": output,
|
||||
"success": response is not None,
|
||||
},
|
||||
final=True,
|
||||
)
|
||||
end_time = time.time()
|
||||
if response is not None:
|
||||
|
@ -55,7 +55,7 @@ from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
|
||||
from gradio.exceptions import Error
|
||||
from gradio.helpers import CACHED_FOLDER
|
||||
from gradio.oauth import attach_oauth
|
||||
from gradio.queueing import Estimation, Event
|
||||
from gradio.queueing import Estimation
|
||||
from gradio.route_utils import ( # noqa: F401
|
||||
FileUploadProgress,
|
||||
GradioMultiPartParser,
|
||||
@ -65,10 +65,7 @@ from gradio.route_utils import ( # noqa: F401
|
||||
)
|
||||
from gradio.state_holder import StateHolder
|
||||
from gradio.utils import (
|
||||
cancel_tasks,
|
||||
get_package_version,
|
||||
run_coro_in_background,
|
||||
set_task_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -532,7 +529,7 @@ class App(FastAPI):
|
||||
async with app.lock:
|
||||
del app.iterators[body.event_id]
|
||||
app.iterators_to_reset.add(body.event_id)
|
||||
await app.get_blocks()._queue.clean_event(body.event_id)
|
||||
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
|
||||
return {"success": True}
|
||||
|
||||
# had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility
|
||||
@ -582,63 +579,38 @@ class App(FastAPI):
|
||||
)
|
||||
return output
|
||||
|
||||
@app.get("/queue/join", dependencies=[Depends(login_check)])
|
||||
async def queue_join(
|
||||
fn_index: int,
|
||||
session_hash: str,
|
||||
@app.get("/queue/data", dependencies=[Depends(login_check)])
|
||||
async def queue_data(
|
||||
request: fastapi.Request,
|
||||
username: str = Depends(get_current_user),
|
||||
data: Optional[str] = None,
|
||||
session_hash: str,
|
||||
):
|
||||
blocks = app.get_blocks()
|
||||
if blocks._queue.server_app is None:
|
||||
blocks._queue.set_server_app(app)
|
||||
|
||||
event = Event(session_hash, fn_index, request, username)
|
||||
if data is not None:
|
||||
input_data = json.loads(data)
|
||||
event.data = PredictBody(
|
||||
session_hash=session_hash,
|
||||
fn_index=fn_index,
|
||||
data=input_data,
|
||||
request=request,
|
||||
)
|
||||
|
||||
# Continuous events are not put in the queue so that they do not
|
||||
# occupy the queue's resource as they are expected to run forever
|
||||
if blocks.dependencies[event.fn_index].get("every", 0):
|
||||
await cancel_tasks({f"{event.session_hash}_{event.fn_index}"})
|
||||
await blocks._queue.reset_iterators(event._id)
|
||||
blocks._queue.continuous_tasks.append(event)
|
||||
task = run_coro_in_background(
|
||||
blocks._queue.process_events, [event], False
|
||||
)
|
||||
set_task_name(task, event.session_hash, event.fn_index, batch=False)
|
||||
app._asyncio_tasks.append(task)
|
||||
else:
|
||||
rank = blocks._queue.push(event)
|
||||
if rank is None:
|
||||
event.send_message("queue_full", final=True)
|
||||
else:
|
||||
estimation = blocks._queue.get_estimation()
|
||||
await blocks._queue.send_estimation(event, estimation, rank)
|
||||
|
||||
async def sse_stream(request: fastapi.Request):
|
||||
try:
|
||||
last_heartbeat = time.perf_counter()
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
await blocks._queue.clean_event(event)
|
||||
if not event.alive:
|
||||
await blocks._queue.clean_events(session_hash=session_hash)
|
||||
return
|
||||
|
||||
if (
|
||||
session_hash
|
||||
not in blocks._queue.pending_messages_per_session
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Session not found.",
|
||||
)
|
||||
|
||||
heartbeat_rate = 15
|
||||
check_rate = 0.05
|
||||
message = None
|
||||
try:
|
||||
message = event.message_queue.get_nowait()
|
||||
if message is None: # end of stream marker
|
||||
return
|
||||
messages = blocks._queue.pending_messages_per_session[
|
||||
session_hash
|
||||
]
|
||||
message = messages.get_nowait()
|
||||
except EmptyQueue:
|
||||
await asyncio.sleep(check_rate)
|
||||
if time.perf_counter() - last_heartbeat > heartbeat_rate:
|
||||
@ -648,10 +620,29 @@ class App(FastAPI):
|
||||
# and then the stream will retry leading to infinite queue 😬
|
||||
last_heartbeat = time.perf_counter()
|
||||
|
||||
if blocks._queue.stopped:
|
||||
message = {"msg": "server_stopped", "success": False}
|
||||
if message:
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
if message["msg"] == "process_completed":
|
||||
blocks._queue.pending_event_ids_session[
|
||||
session_hash
|
||||
].remove(message["event_id"])
|
||||
if message["msg"] == "server_stopped" or (
|
||||
message["msg"] == "process_completed"
|
||||
and (
|
||||
len(
|
||||
blocks._queue.pending_event_ids_session[
|
||||
session_hash
|
||||
]
|
||||
)
|
||||
== 0
|
||||
)
|
||||
):
|
||||
return
|
||||
except asyncio.CancelledError as e:
|
||||
await blocks._queue.clean_event(event)
|
||||
del blocks._queue.pending_messages_per_session[session_hash]
|
||||
await blocks._queue.clean_events(session_hash=session_hash)
|
||||
raise e
|
||||
|
||||
return StreamingResponse(
|
||||
@ -659,14 +650,17 @@ class App(FastAPI):
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@app.post("/queue/data", dependencies=[Depends(login_check)])
|
||||
async def queue_data(
|
||||
@app.post("/queue/join", dependencies=[Depends(login_check)])
|
||||
async def queue_join(
|
||||
body: PredictBody,
|
||||
request: fastapi.Request,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
blocks = app.get_blocks()
|
||||
blocks._queue.attach_data(body)
|
||||
if blocks._queue.server_app is None:
|
||||
blocks._queue.set_server_app(app)
|
||||
|
||||
event_id = await blocks._queue.push(body, request, username)
|
||||
return {"event_id": event_id}
|
||||
|
||||
@app.post("/component_server", dependencies=[Depends(login_check)])
|
||||
@app.post("/component_server/", dependencies=[Depends(login_check)])
|
||||
|
@ -19,6 +19,7 @@ import typing
|
||||
import urllib.parse
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from io import BytesIO
|
||||
from numbers import Number
|
||||
@ -28,6 +29,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Optional,
|
||||
@ -997,3 +999,20 @@ def convert_to_dict_if_dataclass(value):
|
||||
if dataclasses.is_dataclass(value):
|
||||
return dataclasses.asdict(value)
|
||||
return value
|
||||
|
||||
|
||||
K = TypeVar("K")
V = TypeVar("V")


class LRUCache(OrderedDict, Generic[K, V]):
    """Size-bounded ordered mapping that evicts its oldest entry when full.

    Recency is refreshed on *writes* only: assigning to an existing key moves
    it to the newest position, while plain reads leave the order untouched.
    When a new key is inserted into a full cache, the entry at the oldest
    position is dropped first.
    """

    def __init__(self, max_size: int = 100):
        """Create an empty cache.

        Parameters:
            max_size: maximum number of entries to retain; must be at least 1.

        Raises:
            ValueError: if ``max_size`` is smaller than 1. Without this check
                the first insertion would call ``popitem`` on an empty dict
                and fail with an unhelpful ``KeyError``.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be at least 1, got {max_size}")
        super().__init__()
        self.max_size: int = max_size

    def __setitem__(self, key: K, value: V) -> None:
        """Store ``value`` under ``key``, evicting the oldest entry if needed."""
        if key in self:
            # Rewriting an existing key counts as a use: make it the newest.
            self.move_to_end(key)
        elif len(self) >= self.max_size:
            # Cache is full and the key is new: drop the oldest entry.
            self.popitem(last=False)
        super().__setitem__(key, value)

    def copy(self) -> "LRUCache[K, V]":
        """Return a shallow copy with the same ``max_size`` and entry order.

        ``OrderedDict.copy`` builds the clone via ``self.__class__(self)``,
        which would hand the mapping to ``__init__`` as ``max_size``; this
        override constructs the clone with the real bound instead.
        """
        clone = type(self)(self.max_size)
        clone.update(self)
        return clone
|
||||
|
@ -1,20 +1,10 @@
|
||||
import { test, expect } from "@gradio/tootils";
|
||||
|
||||
test(".success should not run if function fails", async ({ page }) => {
|
||||
let last_iteration;
|
||||
const textbox = page.getByLabel("Result");
|
||||
await expect(textbox).toHaveValue("");
|
||||
|
||||
page.on("websocket", (ws) => {
|
||||
last_iteration = ws.waitForEvent("framereceived", {
|
||||
predicate: (event) => {
|
||||
return JSON.parse(event.payload as string).msg === "process_completed";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
await page.click("text=Trigger Failure");
|
||||
await last_iteration;
|
||||
expect(textbox).toHaveValue("");
|
||||
});
|
||||
|
||||
@ -38,17 +28,7 @@ test("Consecutive .success event is triggered successfully", async ({
|
||||
});
|
||||
|
||||
test("gr.Error makes the toast show up", async ({ page }) => {
|
||||
let complete;
|
||||
page.on("websocket", (ws) => {
|
||||
complete = ws.waitForEvent("framereceived", {
|
||||
predicate: (event) => {
|
||||
return JSON.parse(event.payload as string).msg === "process_completed";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
await page.click("text=Trigger Failure");
|
||||
await complete;
|
||||
|
||||
const toast = page.getByTestId("toast-body");
|
||||
expect(toast).toContainText("error");
|
||||
@ -60,17 +40,7 @@ test("gr.Error makes the toast show up", async ({ page }) => {
|
||||
test("ValueError makes the toast show up when show_error=True", async ({
|
||||
page
|
||||
}) => {
|
||||
let complete;
|
||||
page.on("websocket", (ws) => {
|
||||
complete = ws.waitForEvent("framereceived", {
|
||||
predicate: (event) => {
|
||||
return JSON.parse(event.payload as string).msg === "process_completed";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
await page.click("text=Trigger Failure With ValueError");
|
||||
await complete;
|
||||
|
||||
const toast = page.getByTestId("toast-body");
|
||||
expect(toast).toContainText("error");
|
||||
|
@ -107,7 +107,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
host = f"{demo.local_url.replace('http', 'ws')}queue/join"
|
||||
host = f"{demo.local_url.replace('http', 'ws')}queue/data"
|
||||
data = asyncio.run(main(host, n_results=args.n_jobs))
|
||||
data = dict(zip(data["fn_to_hit"], data["duration"]))
|
||||
|
||||
|
@ -8,7 +8,7 @@ import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import gradio_client as grc
|
||||
import pytest
|
||||
from gradio_client import media_data, utils
|
||||
from pydub import AudioSegment
|
||||
@ -660,50 +660,29 @@ class TestProgressBar:
|
||||
button.click(greet, name, greeting)
|
||||
demo.queue(max_size=1).launch(prevent_thread_lock=True)
|
||||
|
||||
progress_updates = []
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
f"http://localhost:{demo.server_port}/queue/join",
|
||||
params={"fn_index": 0, "session_hash": "shdce"},
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
if line.startswith("data:"):
|
||||
msg = json.loads(line[5:])
|
||||
if msg["msg"] == "send_data":
|
||||
event_id = msg["event_id"]
|
||||
req = await client.post(
|
||||
f"http://localhost:{demo.server_port}/queue/data",
|
||||
json={
|
||||
"event_id": event_id,
|
||||
"data": [0],
|
||||
"fn_index": 0,
|
||||
},
|
||||
)
|
||||
if not req.is_success:
|
||||
raise ValueError(
|
||||
f"Could not send payload to endpoint: {req.text}"
|
||||
)
|
||||
if msg["msg"] == "progress":
|
||||
progress_updates.append(msg["progress_data"])
|
||||
if msg["msg"] == "process_completed":
|
||||
break
|
||||
client = grc.Client(demo.local_url)
|
||||
job = client.submit("Gradio")
|
||||
|
||||
assert progress_updates == [
|
||||
[
|
||||
{
|
||||
"index": None,
|
||||
"length": None,
|
||||
"unit": "steps",
|
||||
"progress": 0.0,
|
||||
"desc": "start",
|
||||
}
|
||||
],
|
||||
[{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
status_updates = []
|
||||
while not job.done():
|
||||
status = job.status()
|
||||
update = (
|
||||
status.progress_data[0].index if status.progress_data else None,
|
||||
status.progress_data[0].desc if status.progress_data else None,
|
||||
)
|
||||
if update != (None, None) and (
|
||||
len(status_updates) == 0 or status_updates[-1] != update
|
||||
):
|
||||
status_updates.append(update)
|
||||
time.sleep(0.05)
|
||||
|
||||
assert status_updates == [
|
||||
(None, "start"),
|
||||
(0, None),
|
||||
(1, None),
|
||||
(2, None),
|
||||
(3, None),
|
||||
(4, None),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -726,77 +705,32 @@ class TestProgressBar:
|
||||
button.click(greet, name, greeting)
|
||||
demo.queue(max_size=1).launch(prevent_thread_lock=True)
|
||||
|
||||
progress_updates = []
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
f"http://localhost:{demo.server_port}/queue/join",
|
||||
params={"fn_index": 0, "session_hash": "shdce"},
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
if line.startswith("data:"):
|
||||
msg = json.loads(line[5:])
|
||||
if msg["msg"] == "send_data":
|
||||
event_id = msg["event_id"]
|
||||
req = await client.post(
|
||||
f"http://localhost:{demo.server_port}/queue/data",
|
||||
json={
|
||||
"event_id": event_id,
|
||||
"data": [0],
|
||||
"fn_index": 0,
|
||||
},
|
||||
)
|
||||
if not req.is_success:
|
||||
raise ValueError(
|
||||
f"Could not send payload to endpoint: {req.text}"
|
||||
)
|
||||
if msg["msg"] == "progress":
|
||||
progress_updates.append(msg["progress_data"])
|
||||
if msg["msg"] == "process_completed":
|
||||
break
|
||||
client = grc.Client(demo.local_url)
|
||||
job = client.submit("Gradio")
|
||||
|
||||
assert progress_updates == [
|
||||
[
|
||||
{
|
||||
"index": None,
|
||||
"length": None,
|
||||
"unit": "steps",
|
||||
"progress": 0.0,
|
||||
"desc": "start",
|
||||
}
|
||||
],
|
||||
[{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}],
|
||||
[
|
||||
{
|
||||
"index": 0,
|
||||
"length": 3,
|
||||
"unit": "steps",
|
||||
"progress": None,
|
||||
"desc": "alphabet",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"length": 3,
|
||||
"unit": "steps",
|
||||
"progress": None,
|
||||
"desc": "alphabet",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"index": 2,
|
||||
"length": 3,
|
||||
"unit": "steps",
|
||||
"progress": None,
|
||||
"desc": "alphabet",
|
||||
}
|
||||
],
|
||||
status_updates = []
|
||||
while not job.done():
|
||||
status = job.status()
|
||||
update = (
|
||||
status.progress_data[0].index if status.progress_data else None,
|
||||
status.progress_data[0].desc if status.progress_data else None,
|
||||
)
|
||||
if update != (None, None) and (
|
||||
len(status_updates) == 0 or status_updates[-1] != update
|
||||
):
|
||||
status_updates.append(update)
|
||||
time.sleep(0.05)
|
||||
|
||||
assert status_updates == [
|
||||
(None, "start"),
|
||||
(0, None),
|
||||
(1, None),
|
||||
(2, None),
|
||||
(3, None),
|
||||
(4, None),
|
||||
(0, "alphabet"),
|
||||
(1, "alphabet"),
|
||||
(2, "alphabet"),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -811,63 +745,29 @@ class TestProgressBar:
|
||||
demo = gr.Interface(greet, "text", "text")
|
||||
demo.queue().launch(prevent_thread_lock=True)
|
||||
|
||||
progress_updates = []
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
f"http://localhost:{demo.server_port}/queue/join",
|
||||
params={"fn_index": 0, "session_hash": "shdce"},
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
if line.startswith("data:"):
|
||||
msg = json.loads(line[5:])
|
||||
if msg["msg"] == "send_data":
|
||||
event_id = msg["event_id"]
|
||||
req = await client.post(
|
||||
f"http://localhost:{demo.server_port}/queue/data",
|
||||
json={
|
||||
"event_id": event_id,
|
||||
"data": ["abc"],
|
||||
"fn_index": 0,
|
||||
},
|
||||
)
|
||||
if not req.is_success:
|
||||
raise ValueError(
|
||||
f"Could not send payload to endpoint: {req.text}"
|
||||
)
|
||||
if msg["msg"] == "progress":
|
||||
progress_updates.append(msg["progress_data"])
|
||||
if msg["msg"] == "process_completed":
|
||||
break
|
||||
client = grc.Client(demo.local_url)
|
||||
job = client.submit("Gradio")
|
||||
|
||||
assert progress_updates == [
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"length": 3,
|
||||
"unit": "steps",
|
||||
"progress": None,
|
||||
"desc": None,
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"index": 2,
|
||||
"length": 3,
|
||||
"unit": "steps",
|
||||
"progress": None,
|
||||
"desc": None,
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"index": 3,
|
||||
"length": 3,
|
||||
"unit": "steps",
|
||||
"progress": None,
|
||||
"desc": None,
|
||||
}
|
||||
],
|
||||
status_updates = []
|
||||
while not job.done():
|
||||
status = job.status()
|
||||
update = (
|
||||
status.progress_data[0].index if status.progress_data else None,
|
||||
status.progress_data[0].unit if status.progress_data else None,
|
||||
)
|
||||
if update != (None, None) and (
|
||||
len(status_updates) == 0 or status_updates[-1] != update
|
||||
):
|
||||
status_updates.append(update)
|
||||
time.sleep(0.05)
|
||||
|
||||
assert status_updates == [
|
||||
(1, "steps"),
|
||||
(2, "steps"),
|
||||
(3, "steps"),
|
||||
(4, "steps"),
|
||||
(5, "steps"),
|
||||
(6, "steps"),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -878,45 +778,30 @@ class TestProgressBar:
|
||||
time.sleep(0.15)
|
||||
if len(s) < 5:
|
||||
gr.Warning("Too short!")
|
||||
time.sleep(0.15)
|
||||
return f"Hello, {s}!"
|
||||
|
||||
demo = gr.Interface(greet, "text", "text")
|
||||
demo.queue().launch(prevent_thread_lock=True)
|
||||
|
||||
log_messages = []
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
f"http://localhost:{demo.server_port}/queue/join",
|
||||
params={"fn_index": 0, "session_hash": "shdce"},
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
if line.startswith("data:"):
|
||||
msg = json.loads(line[5:])
|
||||
if msg["msg"] == "send_data":
|
||||
event_id = msg["event_id"]
|
||||
req = await client.post(
|
||||
f"http://localhost:{demo.server_port}/queue/data",
|
||||
json={
|
||||
"event_id": event_id,
|
||||
"data": ["abc"],
|
||||
"fn_index": 0,
|
||||
},
|
||||
)
|
||||
if not req.is_success:
|
||||
raise ValueError(
|
||||
f"Could not send payload to endpoint: {req.text}"
|
||||
)
|
||||
if msg["msg"] == "log":
|
||||
log_messages.append([msg["log"], msg["level"]])
|
||||
if msg["msg"] == "process_completed":
|
||||
break
|
||||
client = grc.Client(demo.local_url)
|
||||
job = client.submit("Jon")
|
||||
|
||||
assert log_messages == [
|
||||
["Letter a", "info"],
|
||||
["Letter b", "info"],
|
||||
["Letter c", "info"],
|
||||
["Too short!", "warning"],
|
||||
status_updates = []
|
||||
while not job.done():
|
||||
status = job.status()
|
||||
update = status.log
|
||||
if update is not None and (
|
||||
len(status_updates) == 0 or status_updates[-1] != update
|
||||
):
|
||||
status_updates.append(update)
|
||||
time.sleep(0.05)
|
||||
|
||||
assert status_updates == [
|
||||
("Letter J", "info"),
|
||||
("Letter o", "info"),
|
||||
("Letter n", "info"),
|
||||
("Too short!", "warning"),
|
||||
]
|
||||
|
||||
|
||||
@ -926,11 +811,13 @@ async def test_info_isolation(async_handler: bool):
|
||||
async def greet_async(name):
|
||||
await asyncio.sleep(2)
|
||||
gr.Info(f"Hello {name}")
|
||||
await asyncio.sleep(1)
|
||||
return name
|
||||
|
||||
def greet_sync(name):
|
||||
time.sleep(2)
|
||||
gr.Info(f"Hello {name}")
|
||||
time.sleep(1)
|
||||
return name
|
||||
|
||||
demo = gr.Interface(
|
||||
@ -942,42 +829,24 @@ async def test_info_isolation(async_handler: bool):
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
|
||||
async def session_interaction(name, delay=0):
|
||||
await asyncio.sleep(delay)
|
||||
client = grc.Client(demo.local_url)
|
||||
job = client.submit(name)
|
||||
|
||||
log_messages = []
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
"GET",
|
||||
f"http://localhost:{demo.server_port}/queue/join",
|
||||
params={"fn_index": 0, "session_hash": name},
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
if line.startswith("data:"):
|
||||
msg = json.loads(line[5:])
|
||||
if msg["msg"] == "send_data":
|
||||
event_id = msg["event_id"]
|
||||
req = await client.post(
|
||||
f"http://localhost:{demo.server_port}/queue/data",
|
||||
json={
|
||||
"event_id": event_id,
|
||||
"data": [name],
|
||||
"fn_index": 0,
|
||||
},
|
||||
)
|
||||
if not req.is_success:
|
||||
raise ValueError(
|
||||
f"Could not send payload to endpoint: {req.text}"
|
||||
)
|
||||
if msg["msg"] == "log":
|
||||
log_messages.append(msg["log"])
|
||||
if msg["msg"] == "process_completed":
|
||||
break
|
||||
return log_messages
|
||||
status_updates = []
|
||||
while not job.done():
|
||||
status = job.status()
|
||||
update = status.log
|
||||
if update is not None and (
|
||||
len(status_updates) == 0 or status_updates[-1] != update
|
||||
):
|
||||
status_updates.append(update)
|
||||
time.sleep(0.05)
|
||||
return status_updates[-1][0] if status_updates else None
|
||||
|
||||
alice_logs, bob_logs = await asyncio.gather(
|
||||
session_interaction("Alice"),
|
||||
session_interaction("Bob", delay=1),
|
||||
)
|
||||
|
||||
assert alice_logs == ["Hello Alice"]
|
||||
assert bob_logs == ["Hello Bob"]
|
||||
assert alice_logs == "Hello Alice"
|
||||
assert bob_logs == "Hello Bob"
|
||||
|
@ -18,8 +18,6 @@ class TestQueueing:
|
||||
|
||||
name.submit(greet, name, output)
|
||||
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
|
||||
with connect(demo) as client:
|
||||
job = client.submit("x", fn_index=0)
|
||||
assert job.result() == "Hello, x!"
|
||||
@ -92,7 +90,7 @@ class TestQueueing:
|
||||
|
||||
@add_btn.click(inputs=[a, b], outputs=output)
|
||||
def add(x, y):
|
||||
time.sleep(2)
|
||||
time.sleep(4)
|
||||
return x + y
|
||||
|
||||
demo.queue(default_concurrency_limit=default_concurrency_limit)
|
||||
@ -105,7 +103,7 @@ class TestQueueing:
|
||||
add_job_2 = client.submit(1, 1, fn_index=0)
|
||||
add_job_3 = client.submit(1, 1, fn_index=0)
|
||||
|
||||
time.sleep(1)
|
||||
time.sleep(2)
|
||||
|
||||
add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
|
||||
assert sorted([s.code.value for s in add_job_statuses]) == statuses
|
||||
@ -161,12 +159,11 @@ class TestQueueing:
|
||||
sub_job_1 = client.submit(1, 1, fn_index=1)
|
||||
sub_job_2 = client.submit(1, 1, fn_index=1)
|
||||
sub_job_3 = client.submit(1, 1, fn_index=1)
|
||||
sub_job_3 = client.submit(1, 1, fn_index=1)
|
||||
mul_job_1 = client.submit(1, 1, fn_index=2)
|
||||
div_job_1 = client.submit(1, 1, fn_index=3)
|
||||
mul_job_2 = client.submit(1, 1, fn_index=2)
|
||||
|
||||
time.sleep(1)
|
||||
time.sleep(2)
|
||||
|
||||
add_job_statuses = [
|
||||
add_job_1.status(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user