Fix api event drops (#6556)

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-25-241.us-west-2.compute.internal>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
aliabid94 2023-12-12 15:24:46 -08:00 committed by GitHub
parent 67ddd40b4b
commit d76bcaaaf0
14 changed files with 671 additions and 437 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/client": patch
"gradio": patch
"gradio_client": patch
---
fix:Fix api event drops
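
For orientation: this patch replaces the per-event SSE connection with a single stream per session. The client now POSTs its payload to /queue/join and receives an event_id, then reads every message for the session from one GET /queue/data stream. A minimal sketch of that handshake with httpx (URL, fn_index, and payload are illustrative, not taken from this diff):

import json
import httpx

BASE = "http://localhost:7860"  # assumed local gradio app
session_hash = "demo-session"

with httpx.Client(timeout=None) as client:
    # 1. Join the queue; the input data now travels with the join request.
    resp = client.post(
        f"{BASE}/queue/join",
        json={"data": ["hello"], "fn_index": 0, "session_hash": session_hash},
    )
    event_id = resp.json()["event_id"]

    # 2. Read all messages for this session from the shared SSE stream.
    with client.stream(
        "GET", f"{BASE}/queue/data", params={"session_hash": session_hash}
    ) as stream:
        for line in stream.iter_lines():
            if not line.startswith("data:"):
                continue
            msg = json.loads(line[len("data:"):])
            if msg.get("event_id") != event_id:
                continue  # heartbeat or another event in this session
            if msg["msg"] == "process_completed":
                print(msg["output"])
                break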

View File

@ -278,6 +278,9 @@ export function api_factory(
const session_hash = Math.random().toString(36).substring(2);
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, (data: object) => Promise<void>> = {};
let config: Config;
let api_map: Record<string, number> = {};
@ -437,7 +440,7 @@ export function api_factory(
let websocket: WebSocket;
let eventSource: EventSource;
let protocol = config.protocol ?? "sse";
let protocol = config.protocol ?? "ws";
const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
let payload: Payload;
@ -646,7 +649,7 @@ export function api_factory(
websocket.send(JSON.stringify({ hash: session_hash }))
);
}
} else {
} else if (protocol == "sse") {
fire_event({
type: "status",
stage: "pending",
@ -766,6 +769,121 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1") {
fire_event({
type: "status",
stage: "pending",
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
post_data(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/join?${url_params}`,
{
...payload,
session_hash
},
hf_token
).then(([response, status]) => {
if (status !== 200) {
fire_event({
type: "status",
stage: "error",
message: BROKEN_CONNECTION_MSG,
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
} else {
event_id = response.event_id as string;
if (!stream_open) {
open_stream();
}
let callback = async function (_data: object): Promise<void> {
const { type, status, data } = handle_message(
_data,
last_status[fn_index]
);
if (type === "update" && status && !complete) {
// call 'status' listeners
fire_event({
type: "status",
endpoint: _endpoint,
fn_index,
time: new Date(),
...status
});
} else if (type === "complete") {
complete = status;
} else if (type === "log") {
fire_event({
type: "log",
log: data.log,
level: data.level,
endpoint: _endpoint,
fn_index
});
} else if (type === "generating") {
fire_event({
type: "status",
time: new Date(),
...status,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
if (data) {
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
endpoint: _endpoint,
fn_index
});
if (complete) {
fire_event({
type: "status",
time: new Date(),
...complete,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
}
if (status.stage === "complete" || status.stage === "error") {
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
if (Object.keys(event_callbacks).length === 0) {
close_stream();
}
}
}
};
event_callbacks[event_id] = callback;
}
});
}
});
@ -864,6 +982,30 @@ export function api_factory(
};
}
function open_stream(): void {
stream_open = true;
let params = new URLSearchParams({
session_hash: session_hash
}).toString();
let url = new URL(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/data?${params}`
);
event_stream = new EventSource(url);
event_stream.onmessage = async function (event) {
let _data = JSON.parse(event.data);
await event_callbacks[_data.event_id](_data);
};
}
function close_stream(): void {
stream_open = false;
event_stream?.close();
}
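
open_stream and close_stream keep exactly one EventSource per session: callbacks are registered per event_id in the sse_v1 branch above, and the stream is torn down once the registry empties. The same refcount-by-registry idea as a tiny Python sketch (all names here are illustrative):

from typing import Callable

event_callbacks: dict[str, Callable] = {}
stream_open = False

def register(event_id: str, callback: Callable) -> None:
    global stream_open
    event_callbacks[event_id] = callback
    if not stream_open:
        stream_open = True  # this is where open_stream() starts the EventSource

def unregister(event_id: str) -> None:
    global stream_open
    event_callbacks.pop(event_id, None)
    if not event_callbacks:  # last consumer gone: close the shared stream
        stream_open = False  # this is where close_stream() closes it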
async function component_server(
component_id: number,
fn_name: string,

View File

@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse" | "ws";
protocol?: "sse_v1" | "sse" | "ws";
}
export interface Payload {

View File

@ -36,6 +36,7 @@ from gradio_client.exceptions import SerializationSetupError
from gradio_client.utils import (
Communicator,
JobStatus,
Message,
Status,
StatusUpdate,
)
@ -124,25 +125,33 @@ class Client:
if self.verbose:
print(f"Loaded as API: {self.src}")
if auth is not None:
self._login(auth)
self.config = self._get_config()
self.protocol: str = self.config.get("protocol", "ws")
self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
self.sse_url = urllib.parse.urljoin(self.src, utils.SSE_URL)
self.sse_data_url = urllib.parse.urljoin(self.src, utils.SSE_DATA_URL)
self.sse_url = urllib.parse.urljoin(
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
)
self.sse_data_url = urllib.parse.urljoin(
self.src,
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
)
self.ws_url = urllib.parse.urljoin(
self.src.replace("http", "ws", 1), utils.WS_URL
)
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
if auth is not None:
self._login(auth)
self.config = self._get_config()
self.app_version = version.parse(self.config.get("version", "2.0"))
self._info = self._get_api_info()
self.session_hash = str(uuid.uuid4())
protocol = self.config.get("protocol")
endpoint_class = Endpoint if protocol == "sse" else EndpointV3Compatibility
endpoint_class = (
Endpoint if self.protocol.startswith("sse") else EndpointV3Compatibility
)
self.endpoints = [
endpoint_class(self, fn_index, dependency)
endpoint_class(self, fn_index, dependency, self.protocol)
for fn_index, dependency in enumerate(self.config["dependencies"])
]
@ -152,6 +161,84 @@ class Client:
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
threading.Thread(target=self._telemetry_thread).start()
self.stream_open = False
self.streaming_future: Future | None = None
self.pending_messages_per_event: dict[str, list[Message | None]] = {}
self.pending_event_ids: set[str] = set()
async def stream_messages(self) -> None:
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
buffer = ""
async with client.stream(
"GET",
self.sse_url,
params={"session_hash": self.session_hash},
headers=self.headers,
cookies=self.cookies,
) as response:
async for line in response.aiter_text():
buffer += line
while "\n\n" in buffer:
message, buffer = buffer.split("\n\n", 1)
if message.startswith("data:"):
resp = json.loads(message[5:])
if resp["msg"] == "heartbeat":
continue
elif resp["msg"] == "server_stopped":
for (
pending_messages
) in self.pending_messages_per_event.values():
pending_messages.append(resp)
return
event_id = resp["event_id"]
if event_id not in self.pending_messages_per_event:
self.pending_messages_per_event[event_id] = []
self.pending_messages_per_event[event_id].append(resp)
if resp["msg"] == "process_completed":
self.pending_event_ids.remove(event_id)
if len(self.pending_event_ids) == 0:
self.stream_open = False
return
elif message == "":
continue
else:
raise ValueError(f"Unexpected SSE line: '{message}'")
except BaseException as e:
import traceback
traceback.print_exc()
raise e
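
stream_messages reassembles SSE frames by hand: network chunks are appended to a buffer, and complete frames are split off at the blank-line delimiter before the "data:" prefix is stripped. The framing logic in isolation, run on sample chunks (chunk boundaries are arbitrary):

import json

def parse_sse_chunks(chunks):
    """Yield decoded payloads from 'data: ...' frames split across chunks."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        while "\n\n" in buffer:  # a blank line terminates an SSE frame
            frame, buffer = buffer.split("\n\n", 1)
            if frame.startswith("data:"):
                yield json.loads(frame[len("data:"):])

frames = list(parse_sse_chunks(
    ['data: {"msg": "esti', 'mation"}\n\ndata: {"msg": "heartbeat"}\n\n']
))
assert [f["msg"] for f in frames] == ["estimation", "heartbeat"]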
async def send_data(self, data, hash_data):
async with httpx.AsyncClient() as client:
req = await client.post(
self.sse_data_url,
json={**data, **hash_data},
headers=self.headers,
cookies=self.cookies,
)
req.raise_for_status()
resp = req.json()
event_id = resp["event_id"]
if not self.stream_open:
self.stream_open = True
def open_stream():
return utils.synchronize_async(self.stream_messages)
def close_stream(_):
self.stream_open = False
for _, pending_messages in self.pending_messages_per_event.items():
pending_messages.append(None)
if self.streaming_future is None or self.streaming_future.done():
self.streaming_future = self.executor.submit(open_stream)
self.streaming_future.add_done_callback(close_stream)
return event_id
@classmethod
def duplicate(
cls,
@ -340,7 +427,7 @@ class Client:
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
helper = None
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"):
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
@ -806,7 +893,9 @@ class ReplaceMe:
class Endpoint:
"""Helper class for storing all the information about a single API endpoint."""
def __init__(self, client: Client, fn_index: int, dependency: dict):
def __init__(
self, client: Client, fn_index: int, dependency: dict, protocol: str = "sse_v1"
):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
@ -814,7 +903,7 @@ class Endpoint:
self.api_name: str | Literal[False] | None = (
"/" + api_name if isinstance(api_name, str) else api_name
)
self.protocol = "sse"
self.protocol = protocol
self.input_component_types = [
self._get_component_type(id_) for id_ in dependency["inputs"]
]
@ -891,7 +980,20 @@ class Endpoint:
"session_hash": self.client.session_hash,
}
result = utils.synchronize_async(self._sse_fn, data, hash_data, helper)
if self.protocol == "sse":
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol == "sse_v1":
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")
if "error" in result:
raise ValueError(result["error"])
@ -1068,24 +1170,33 @@ class Endpoint:
predictions = self.reduce_singleton_output(*predictions)
return predictions
async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator):
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
return await utils.get_pred_from_sse(
return await utils.get_pred_from_sse_v0(
client,
data,
hash_data,
helper,
sse_url=self.client.sse_url,
sse_data_url=self.client.sse_data_url,
headers=self.client.headers,
cookies=self.client.cookies,
self.client.sse_url,
self.client.sse_data_url,
self.client.headers,
self.client.cookies,
)
async def _sse_fn_v1(self, helper: Communicator, event_id: str):
return await utils.get_pred_from_sse_v1(
helper,
self.client.headers,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
)
class EndpointV3Compatibility:
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""
def __init__(self, client: Client, fn_index: int, dependency: dict):
def __init__(self, client: Client, fn_index: int, dependency: dict, *args):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency

View File

@ -17,7 +17,7 @@ from datetime import datetime
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TypedDict
import fsspec.asyn
import httpx
@ -26,8 +26,10 @@ from huggingface_hub import SpaceStage
from websockets.legacy.protocol import WebSocketCommonProtocol
API_URL = "api/predict/"
SSE_URL = "queue/join"
SSE_DATA_URL = "queue/data"
SSE_URL_V0 = "queue/join"
SSE_DATA_URL_V0 = "queue/data"
SSE_URL = "queue/data"
SSE_DATA_URL = "queue/join"
WS_URL = "queue/join"
UPLOAD_URL = "upload"
LOGIN_URL = "login"
@ -48,6 +50,19 @@ INVALID_RUNTIME = [
]
class Message(TypedDict, total=False):
msg: str
output: dict[str, Any]
event_id: str
rank: int
rank_eta: float
queue_size: int
success: bool
progress_data: list[dict]
log: str
level: str
def get_package_version() -> str:
try:
package_json_data = (
@ -100,6 +115,7 @@ class Status(Enum):
PROGRESS = "PROGRESS"
FINISHED = "FINISHED"
CANCELLED = "CANCELLED"
LOG = "LOG"
@staticmethod
def ordering(status: Status) -> int:
@ -133,6 +149,7 @@ class Status(Enum):
"process_generating": Status.ITERATING,
"process_completed": Status.FINISHED,
"progress": Status.PROGRESS,
"log": Status.LOG,
}[msg]
@ -169,6 +186,7 @@ class StatusUpdate:
success: bool | None
time: datetime | None
progress_data: list[ProgressUnit] | None
log: tuple[str, str] | None = None
def create_initial_status_update():
@ -307,7 +325,7 @@ async def get_pred_from_ws(
return resp["output"]
async def get_pred_from_sse(
async def get_pred_from_sse_v0(
client: httpx.AsyncClient,
data: dict,
hash_data: dict,
@ -315,21 +333,21 @@ async def get_pred_from_sse(
sse_url: str,
sse_data_url: str,
headers: dict[str, str],
cookies: dict[str, str] | None = None,
cookies: dict[str, str] | None,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, cookies)),
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(
stream_sse(
stream_sse_v0(
client,
data,
hash_data,
helper,
sse_url,
sse_data_url,
headers=headers,
cookies=cookies,
headers,
cookies,
)
),
],
@ -348,7 +366,42 @@ async def get_pred_from_sse(
return task.result()
async def check_for_cancel(helper: Communicator, cookies: dict[str, str] | None):
async def get_pred_from_sse_v1(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(
stream_sse_v1(
helper,
pending_messages_per_event,
event_id,
)
),
],
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
assert len(done) == 1
for task in done:
return task.result()
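
Both get_pred_from_sse_v0 and get_pred_from_sse_v1 race the message stream against a cancellation watcher via asyncio.wait(..., return_when=FIRST_COMPLETED), then cancel and drain whichever task lost. The pattern in isolation, with stand-in coroutines:

import asyncio

async def watch_for_cancel():  # stand-in for check_for_cancel
    await asyncio.sleep(10)

async def stream_result():  # stand-in for stream_sse_v1
    await asyncio.sleep(0.1)
    return {"msg": "process_completed"}

async def main():
    done, pending = await asyncio.wait(
        [
            asyncio.create_task(watch_for_cancel()),
            asyncio.create_task(stream_result()),
        ],
        return_when=asyncio.FIRST_COMPLETED,
    )
    for task in pending:  # cancel and drain the losing task
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    return done.pop().result()

print(asyncio.run(main()))  # {'msg': 'process_completed'}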
async def check_for_cancel(
helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None
):
while True:
await asyncio.sleep(0.05)
with helper.lock:
@ -357,12 +410,15 @@ async def check_for_cancel(helper: Communicator, cookies: dict[str, str] | None)
if helper.event_id:
async with httpx.AsyncClient() as http:
await http.post(
helper.reset_url, json={"event_id": helper.event_id}, cookies=cookies
helper.reset_url,
json={"event_id": helper.event_id},
headers=headers,
cookies=cookies,
)
raise CancelledError()
async def stream_sse(
async def stream_sse_v0(
client: httpx.AsyncClient,
data: dict,
hash_data: dict,
@ -370,15 +426,15 @@ async def stream_sse(
sse_url: str,
sse_data_url: str,
headers: dict[str, str],
cookies: dict[str, str] | None = None,
cookies: dict[str, str] | None,
) -> dict[str, Any]:
try:
async with client.stream(
"GET",
sse_url,
params=hash_data,
cookies=cookies,
headers=headers,
cookies=cookies,
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
@ -413,8 +469,8 @@ async def stream_sse(
req = await client.post(
sse_data_url,
json={"event_id": event_id, **data, **hash_data},
cookies=cookies,
headers=headers,
cookies=cookies,
)
req.raise_for_status()
elif resp["msg"] == "process_completed":
@ -426,6 +482,64 @@ async def stream_sse(
raise
async def stream_sse_v1(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
) -> dict[str, Any]:
try:
pending_messages = pending_messages_per_event[event_id]
while True:
if len(pending_messages) > 0:
msg = pending_messages.pop(0)
else:
await asyncio.sleep(0.05)
continue
if msg is None:
raise CancelledError()
with helper.lock:
log_message = None
if msg["msg"] == "log":
log = msg.get("log")
level = msg.get("level")
if log and level:
log_message = (log, level)
status_update = StatusUpdate(
code=Status.msg_to_status(msg["msg"]),
queue_size=msg.get("queue_size"),
rank=msg.get("rank", None),
success=msg.get("success"),
time=datetime.now(),
eta=msg.get("rank_eta"),
progress_data=ProgressUnit.from_msg(msg["progress_data"])
if "progress_data" in msg
else None,
log=log_message,
)
output = msg.get("output", {}).get("data", [])
if output and status_update.code != Status.FINISHED:
try:
result = helper.prediction_processor(*output)
except Exception as e:
result = [e]
helper.job.outputs.append(result)
helper.job.latest_status = status_update
if msg["msg"] == "queue_full":
raise QueueError("Queue is full! Please try again.")
elif msg["msg"] == "process_completed":
del pending_messages_per_event[event_id]
return msg["output"]
elif msg["msg"] == "server_stopped":
raise ValueError("Server stopped.")
except asyncio.CancelledError:
raise
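
stream_sse_v1 treats its per-event list as a polling queue; the background reader appends None as an end-of-stream sentinel when the stream closes (see close_stream in client.py above). A stripped-down version of the consumer loop:

import asyncio

async def consume(pending_messages: list) -> dict:
    while True:
        if not pending_messages:
            await asyncio.sleep(0.05)  # same poll interval as the client
            continue
        msg = pending_messages.pop(0)
        if msg is None:  # sentinel: the reader closed the stream
            raise asyncio.CancelledError()
        if msg["msg"] == "process_completed":
            return msg["output"]

msgs = [{"msg": "estimation"}, {"msg": "process_completed", "output": {"data": [42]}}]
print(asyncio.run(consume(msgs)))  # {'data': [42]}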
########################
# Data processing utils
########################

View File

@ -95,6 +95,17 @@ class TestClientPredictions:
output = client.predict("abc", api_name="/predict")
assert output == "abc"
@pytest.mark.flaky
def test_private_space_v4_sse_v1(self):
space_id = "gradio-tests/not-actually-private-spacev4-sse-v1"
api = huggingface_hub.HfApi()
assert api.space_info(space_id).private
client = Client(
space_id,
)
output = client.predict("abc", api_name="/predict")
assert output == "abc"
def test_state(self, increment_demo):
with connect(increment_demo) as client:
output = client.predict(api_name="/increment_without_queue")

View File

@ -893,6 +893,13 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
fn = get_continuous_fn(fn, every)
elif every:
raise ValueError("Cannot set a value for `every` without a `fn`.")
if every and concurrency_limit is not None:
if concurrency_limit == "default":
concurrency_limit = None
else:
raise ValueError(
"Cannot set a value for `concurrency_limit` with `every`."
)
if _targets[0][1] == "change" and trigger_mode is None:
trigger_mode = "always_last"
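
In effect, an event with every= keeps the implicit unlimited concurrency, while an explicit concurrency_limit on such an event now fails fast. A hypothetical snippet against this behavior:

import time
import gradio as gr

with gr.Blocks() as demo:
    clock = gr.Textbox()
    # fine: `every=` events silently drop the "default" concurrency limit
    demo.load(lambda: time.strftime("%X"), None, clock, every=1)
    # raises ValueError("Cannot set a value for `concurrency_limit` with `every`.")
    # demo.load(lambda: time.strftime("%X"), None, clock, every=1, concurrency_limit=2)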
@ -1581,7 +1588,7 @@ Received outputs:
"is_colab": utils.colab_check(),
"stylesheets": self.stylesheets,
"theme": self.theme.name,
"protocol": "sse",
"protocol": "sse_v1",
}
def get_layout(block):
@ -2169,7 +2176,7 @@ Received outputs:
try:
if wasm_utils.IS_WASM:
# NOTE:
# Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/join` endpoint function)
# Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/data` endpoint function)
# are running in an event loop in the server thread,
# so they will be cancelled by `self.server.close()` below.
# However, in the Wasm env, we don't have the `server` and

View File

@ -23,7 +23,7 @@ from gradio.data_classes import (
)
from gradio.exceptions import Error
from gradio.helpers import TrackedIterable
from gradio.utils import run_coro_in_background, safe_get_lock, set_task_name
from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name
if TYPE_CHECKING:
from gradio.blocks import BlockFunction
@ -37,7 +37,6 @@ class Event:
request: fastapi.Request,
username: str | None,
):
self.message_queue = ThreadQueue()
self.session_hash = session_hash
self.fn_index = fn_index
self.request = request
@ -48,28 +47,6 @@ class Event:
self.progress_pending: bool = False
self.alive = True
def send_message(
self,
message_type: str,
data: dict | None = None,
final: bool = False,
):
data = {} if data is None else data
self.message_queue.put_nowait({"msg": message_type, **data})
if final:
self.message_queue.put_nowait(None)
async def get_data(self, timeout=5) -> bool:
self.send_message("send_data", {"event_id": self._id})
sleep_interval = 0.05
wait_time = 0
while wait_time < timeout and self.alive:
if self.data is not None:
break
await asyncio.sleep(sleep_interval)
wait_time += sleep_interval
return self.data is not None
class Queue:
def __init__(
@ -81,6 +58,9 @@ class Queue:
block_fns: list[BlockFunction],
default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
):
self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
self.pending_event_ids_session: dict[str, set[str]] = {}
self.pending_message_lock = safe_get_lock()
self.event_queue: list[Event] = []
self.awaiting_data_events: dict[str, Event] = {}
self.stopped = False
@ -132,6 +112,16 @@ class Queue:
def close(self):
self.stopped = True
def send_message(
self,
event: Event,
message_type: str,
data: dict | None = None,
):
data = {} if data is None else data
messages = self.pending_messages_per_session[event.session_hash]
messages.put_nowait({"msg": message_type, "event_id": event._id, **data})
def _resolve_concurrency_limit(self, default_concurrency_limit):
"""
Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination
@ -152,13 +142,33 @@ class Queue:
else:
return 1
def attach_data(self, body: PredictBody):
event_id = body.event_id
if event_id in self.awaiting_data_events:
event = self.awaiting_data_events[event_id]
event.data = body
else:
raise ValueError("Event not found", event_id)
async def push(
self, body: PredictBody, request: fastapi.Request, username: str | None
):
if body.session_hash is None:
raise ValueError("No session hash provided.")
if body.fn_index is None:
raise ValueError("No function index provided.")
queue_len = len(self.event_queue)
if self.max_size is not None and queue_len >= self.max_size:
raise ValueError(
f"Queue is full. Max size is {self.max_size} and current size is {queue_len}."
)
event = Event(body.session_hash, body.fn_index, request, username)
event.data = body
async with self.pending_message_lock:
if body.session_hash not in self.pending_messages_per_session:
self.pending_messages_per_session[body.session_hash] = ThreadQueue()
if body.session_hash not in self.pending_event_ids_session:
self.pending_event_ids_session[body.session_hash] = set()
self.pending_event_ids_session[body.session_hash].add(event._id)
self.event_queue.append(event)
estimation = self.get_estimation()
await self.send_estimation(event, estimation, queue_len)
return event._id
def _cancel_asyncio_tasks(self):
for task in self._asyncio_tasks:
@ -276,7 +286,7 @@ class Queue:
for event in events:
if event.progress_pending and event.progress:
event.progress_pending = False
event.send_message("progress", event.progress.model_dump())
self.send_message(event, "progress", event.progress.model_dump())
await asyncio.sleep(self.progress_update_sleep_when_free)
@ -320,34 +330,23 @@ class Queue:
log=log,
level=level,
)
event.send_message("log", log_message.model_dump())
self.send_message(event, "log", log_message.model_dump())
def push(self, event: Event) -> int | None:
"""
Add event to queue, or return None if Queue is full
Parameters:
event: Event to add to Queue
Returns:
rank of submitted Event
"""
queue_len = len(self.event_queue)
if self.max_size is not None and queue_len >= self.max_size:
return None
self.event_queue.append(event)
return queue_len
async def clean_events(
self, *, session_hash: str | None = None, event_id: str | None = None
) -> None:
for job_set in self.active_jobs:
if job_set:
for job in job_set:
if job.session_hash == session_hash or job._id == event_id:
job.alive = False
async def clean_event(self, event: Event | str) -> None:
if isinstance(event, str):
for job_set in self.active_jobs:
if job_set:
for job in job_set:
if job._id == event:
event = job
break
if isinstance(event, str):
raise ValueError("Event not found", event)
event.alive = False
if event in self.event_queue:
events_to_remove = []
for event in self.event_queue:
if event.session_hash == session_hash or event._id == event_id:
events_to_remove.append(event)
for event in events_to_remove:
async with self.delete_lock:
self.event_queue.remove(event)
@ -391,7 +390,7 @@ class Queue:
if None not in self.active_jobs:
# Add estimated amount of time for a thread to get empty
estimation.rank_eta += self.avg_concurrent_process_time
event.send_message("estimation", estimation.model_dump())
self.send_message(event, "estimation", estimation.model_dump())
return estimation
def update_estimation(self, duration: float) -> None:
@ -485,14 +484,7 @@ class Queue:
awake_events: list[Event] = []
try:
for event in events:
if not event.data:
self.awaiting_data_events[event._id] = event
client_awake = await event.get_data()
del self.awaiting_data_events[event._id]
if not client_awake:
await self.clean_event(event)
continue
event.send_message("process_starts")
self.send_message(event, "process_starts")
awake_events.append(event)
if not awake_events:
return
@ -505,7 +497,8 @@ class Queue:
response = None
err = e
for event in awake_events:
event.send_message(
self.send_message(
event,
"process_completed",
{
"output": {
@ -515,7 +508,6 @@ class Queue:
},
"success": False,
},
final=True,
)
if response and response.get("is_generating", False):
old_response = response
@ -524,7 +516,8 @@ class Queue:
old_response = response
old_err = err
for event in awake_events:
event.send_message(
self.send_message(
event,
"process_generating",
{
"output": old_response,
@ -545,7 +538,8 @@ class Queue:
relevant_response = err
else:
relevant_response = old_response or old_err
event.send_message(
self.send_message(
event,
"process_completed",
{
"output": {"error": str(relevant_response)}
@ -554,20 +548,19 @@ class Queue:
"success": relevant_response
and not isinstance(relevant_response, Exception),
},
final=True,
)
elif response:
output = copy.deepcopy(response)
for e, event in enumerate(awake_events):
if batch and "data" in output:
output["data"] = list(zip(*response.get("data")))[e]
event.send_message(
self.send_message(
event,
"process_completed",
{
"output": output,
"success": response is not None,
},
final=True,
)
end_time = time.time()
if response is not None:

View File

@ -55,7 +55,7 @@ from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
from gradio.exceptions import Error
from gradio.helpers import CACHED_FOLDER
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation, Event
from gradio.queueing import Estimation
from gradio.route_utils import ( # noqa: F401
FileUploadProgress,
GradioMultiPartParser,
@ -65,10 +65,7 @@ from gradio.route_utils import ( # noqa: F401
)
from gradio.state_holder import StateHolder
from gradio.utils import (
cancel_tasks,
get_package_version,
run_coro_in_background,
set_task_name,
)
if TYPE_CHECKING:
@ -532,7 +529,7 @@ class App(FastAPI):
async with app.lock:
del app.iterators[body.event_id]
app.iterators_to_reset.add(body.event_id)
await app.get_blocks()._queue.clean_event(body.event_id)
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
return {"success": True}
# had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility
@ -582,63 +579,38 @@ class App(FastAPI):
)
return output
@app.get("/queue/join", dependencies=[Depends(login_check)])
async def queue_join(
fn_index: int,
session_hash: str,
@app.get("/queue/data", dependencies=[Depends(login_check)])
async def queue_data(
request: fastapi.Request,
username: str = Depends(get_current_user),
data: Optional[str] = None,
session_hash: str,
):
blocks = app.get_blocks()
if blocks._queue.server_app is None:
blocks._queue.set_server_app(app)
event = Event(session_hash, fn_index, request, username)
if data is not None:
input_data = json.loads(data)
event.data = PredictBody(
session_hash=session_hash,
fn_index=fn_index,
data=input_data,
request=request,
)
# Continuous events are not put in the queue so that they do not
# occupy the queue's resource as they are expected to run forever
if blocks.dependencies[event.fn_index].get("every", 0):
await cancel_tasks({f"{event.session_hash}_{event.fn_index}"})
await blocks._queue.reset_iterators(event._id)
blocks._queue.continuous_tasks.append(event)
task = run_coro_in_background(
blocks._queue.process_events, [event], False
)
set_task_name(task, event.session_hash, event.fn_index, batch=False)
app._asyncio_tasks.append(task)
else:
rank = blocks._queue.push(event)
if rank is None:
event.send_message("queue_full", final=True)
else:
estimation = blocks._queue.get_estimation()
await blocks._queue.send_estimation(event, estimation, rank)
async def sse_stream(request: fastapi.Request):
try:
last_heartbeat = time.perf_counter()
while True:
if await request.is_disconnected():
await blocks._queue.clean_event(event)
if not event.alive:
await blocks._queue.clean_events(session_hash=session_hash)
return
if (
session_hash
not in blocks._queue.pending_messages_per_session
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Session not found.",
)
heartbeat_rate = 15
check_rate = 0.05
message = None
try:
message = event.message_queue.get_nowait()
if message is None: # end of stream marker
return
messages = blocks._queue.pending_messages_per_session[
session_hash
]
message = messages.get_nowait()
except EmptyQueue:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
@ -648,10 +620,29 @@ class App(FastAPI):
# and then the stream will retry leading to infinite queue 😬
last_heartbeat = time.perf_counter()
if blocks._queue.stopped:
message = {"msg": "server_stopped", "success": False}
if message:
yield f"data: {json.dumps(message)}\n\n"
if message["msg"] == "process_completed":
blocks._queue.pending_event_ids_session[
session_hash
].remove(message["event_id"])
if message["msg"] == "server_stopped" or (
message["msg"] == "process_completed"
and (
len(
blocks._queue.pending_event_ids_session[
session_hash
]
)
== 0
)
):
return
except asyncio.CancelledError as e:
await blocks._queue.clean_event(event)
del blocks._queue.pending_messages_per_session[session_hash]
await blocks._queue.clean_events(session_hash=session_hash)
raise e
return StreamingResponse(
@ -659,14 +650,17 @@ class App(FastAPI):
media_type="text/event-stream",
)
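
Server-side, the session stream is a plain StreamingResponse over an async generator that interleaves queued messages with heartbeats. A minimal standalone version of that shape (route name and message payloads are illustrative, not the full gradio logic):

import asyncio
import json
import time

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
pending: list[dict] = []  # stand-in for the per-session message queue

@app.get("/stream")
async def stream():
    async def sse_gen():
        last_heartbeat = time.perf_counter()
        while True:
            if pending:
                msg = pending.pop(0)
                yield f"data: {json.dumps(msg)}\n\n"
                if msg["msg"] == "process_completed":
                    return
            else:
                await asyncio.sleep(0.05)
            if time.perf_counter() - last_heartbeat > 15:
                # heartbeats keep proxies from closing an idle stream
                yield 'data: {"msg": "heartbeat"}\n\n'
                last_heartbeat = time.perf_counter()

    return StreamingResponse(sse_gen(), media_type="text/event-stream")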
@app.post("/queue/data", dependencies=[Depends(login_check)])
async def queue_data(
@app.post("/queue/join", dependencies=[Depends(login_check)])
async def queue_join(
body: PredictBody,
request: fastapi.Request,
username: str = Depends(get_current_user),
):
blocks = app.get_blocks()
blocks._queue.attach_data(body)
if blocks._queue.server_app is None:
blocks._queue.set_server_app(app)
event_id = await blocks._queue.push(body, request, username)
return {"event_id": event_id}
@app.post("/component_server", dependencies=[Depends(login_check)])
@app.post("/component_server/", dependencies=[Depends(login_check)])

View File

@ -19,6 +19,7 @@ import typing
import urllib.parse
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from io import BytesIO
from numbers import Number
@ -28,6 +29,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
Iterator,
Optional,
@ -997,3 +999,20 @@ def convert_to_dict_if_dataclass(value):
if dataclasses.is_dataclass(value):
return dataclasses.asdict(value)
return value
K = TypeVar("K")
V = TypeVar("V")
class LRUCache(OrderedDict, Generic[K, V]):
def __init__(self, max_size: int = 100):
super().__init__()
self.max_size: int = max_size
def __setitem__(self, key: K, value: V) -> None:
if key in self:
self.move_to_end(key)
elif len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)
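
LRUCache caps pending_messages_per_session at 2000 sessions by evicting the least-recently-written entry; note that only writes refresh recency here, since __getitem__ is not overridden. For example:

cache = LRUCache(max_size=2)
cache["a"] = 1
cache["b"] = 2
cache["a"] = 10  # re-setting "a" moves it to the most-recent end
cache["c"] = 3   # evicts "b", the least recently written key
assert list(cache.keys()) == ["a", "c"]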

View File

@ -1,20 +1,10 @@
import { test, expect } from "@gradio/tootils";
test(".success should not run if function fails", async ({ page }) => {
let last_iteration;
const textbox = page.getByLabel("Result");
await expect(textbox).toHaveValue("");
page.on("websocket", (ws) => {
last_iteration = ws.waitForEvent("framereceived", {
predicate: (event) => {
return JSON.parse(event.payload as string).msg === "process_completed";
}
});
});
await page.click("text=Trigger Failure");
await last_iteration;
expect(textbox).toHaveValue("");
});
@ -38,17 +28,7 @@ test("Consecutive .success event is triggered successfully", async ({
});
test("gr.Error makes the toast show up", async ({ page }) => {
let complete;
page.on("websocket", (ws) => {
complete = ws.waitForEvent("framereceived", {
predicate: (event) => {
return JSON.parse(event.payload as string).msg === "process_completed";
}
});
});
await page.click("text=Trigger Failure");
await complete;
const toast = page.getByTestId("toast-body");
expect(toast).toContainText("error");
@ -60,17 +40,7 @@ test("gr.Error makes the toast show up", async ({ page }) => {
test("ValueError makes the toast show up when show_error=True", async ({
page
}) => {
let complete;
page.on("websocket", (ws) => {
complete = ws.waitForEvent("framereceived", {
predicate: (event) => {
return JSON.parse(event.payload as string).msg === "process_completed";
}
});
});
await page.click("text=Trigger Failure With ValueError");
await complete;
const toast = page.getByTestId("toast-body");
expect(toast).toContainText("error");

View File

@ -107,7 +107,7 @@ if __name__ == "__main__":
parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False)
args = parser.parse_args()
host = f"{demo.local_url.replace('http', 'ws')}queue/join"
host = f"{demo.local_url.replace('http', 'ws')}queue/data"
data = asyncio.run(main(host, n_results=args.n_jobs))
data = dict(zip(data["fn_to_hit"], data["duration"]))

View File

@ -8,7 +8,7 @@ import time
from pathlib import Path
from unittest.mock import patch
import httpx
import gradio_client as grc
import pytest
from gradio_client import media_data, utils
from pydub import AudioSegment
@ -660,50 +660,29 @@ class TestProgressBar:
button.click(greet, name, greeting)
demo.queue(max_size=1).launch(prevent_thread_lock=True)
progress_updates = []
async with httpx.AsyncClient() as client:
async with client.stream(
"GET",
f"http://localhost:{demo.server_port}/queue/join",
params={"fn_index": 0, "session_hash": "shdce"},
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
msg = json.loads(line[5:])
if msg["msg"] == "send_data":
event_id = msg["event_id"]
req = await client.post(
f"http://localhost:{demo.server_port}/queue/data",
json={
"event_id": event_id,
"data": [0],
"fn_index": 0,
},
)
if not req.is_success:
raise ValueError(
f"Could not send payload to endpoint: {req.text}"
)
if msg["msg"] == "progress":
progress_updates.append(msg["progress_data"])
if msg["msg"] == "process_completed":
break
client = grc.Client(demo.local_url)
job = client.submit("Gradio")
assert progress_updates == [
[
{
"index": None,
"length": None,
"unit": "steps",
"progress": 0.0,
"desc": "start",
}
],
[{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}],
status_updates = []
while not job.done():
status = job.status()
update = (
status.progress_data[0].index if status.progress_data else None,
status.progress_data[0].desc if status.progress_data else None,
)
if update != (None, None) and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
(None, "start"),
(0, None),
(1, None),
(2, None),
(3, None),
(4, None),
]
@pytest.mark.asyncio
@ -726,77 +705,32 @@ class TestProgressBar:
button.click(greet, name, greeting)
demo.queue(max_size=1).launch(prevent_thread_lock=True)
progress_updates = []
async with httpx.AsyncClient() as client:
async with client.stream(
"GET",
f"http://localhost:{demo.server_port}/queue/join",
params={"fn_index": 0, "session_hash": "shdce"},
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
msg = json.loads(line[5:])
if msg["msg"] == "send_data":
event_id = msg["event_id"]
req = await client.post(
f"http://localhost:{demo.server_port}/queue/data",
json={
"event_id": event_id,
"data": [0],
"fn_index": 0,
},
)
if not req.is_success:
raise ValueError(
f"Could not send payload to endpoint: {req.text}"
)
if msg["msg"] == "progress":
progress_updates.append(msg["progress_data"])
if msg["msg"] == "process_completed":
break
client = grc.Client(demo.local_url)
job = client.submit("Gradio")
assert progress_updates == [
[
{
"index": None,
"length": None,
"unit": "steps",
"progress": 0.0,
"desc": "start",
}
],
[{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}],
[
{
"index": 0,
"length": 3,
"unit": "steps",
"progress": None,
"desc": "alphabet",
}
],
[
{
"index": 1,
"length": 3,
"unit": "steps",
"progress": None,
"desc": "alphabet",
}
],
[
{
"index": 2,
"length": 3,
"unit": "steps",
"progress": None,
"desc": "alphabet",
}
],
status_updates = []
while not job.done():
status = job.status()
update = (
status.progress_data[0].index if status.progress_data else None,
status.progress_data[0].desc if status.progress_data else None,
)
if update != (None, None) and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
(None, "start"),
(0, None),
(1, None),
(2, None),
(3, None),
(4, None),
(0, "alphabet"),
(1, "alphabet"),
(2, "alphabet"),
]
@pytest.mark.asyncio
@ -811,63 +745,29 @@ class TestProgressBar:
demo = gr.Interface(greet, "text", "text")
demo.queue().launch(prevent_thread_lock=True)
progress_updates = []
async with httpx.AsyncClient() as client:
async with client.stream(
"GET",
f"http://localhost:{demo.server_port}/queue/join",
params={"fn_index": 0, "session_hash": "shdce"},
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
msg = json.loads(line[5:])
if msg["msg"] == "send_data":
event_id = msg["event_id"]
req = await client.post(
f"http://localhost:{demo.server_port}/queue/data",
json={
"event_id": event_id,
"data": ["abc"],
"fn_index": 0,
},
)
if not req.is_success:
raise ValueError(
f"Could not send payload to endpoint: {req.text}"
)
if msg["msg"] == "progress":
progress_updates.append(msg["progress_data"])
if msg["msg"] == "process_completed":
break
client = grc.Client(demo.local_url)
job = client.submit("Gradio")
assert progress_updates == [
[
{
"index": 1,
"length": 3,
"unit": "steps",
"progress": None,
"desc": None,
}
],
[
{
"index": 2,
"length": 3,
"unit": "steps",
"progress": None,
"desc": None,
}
],
[
{
"index": 3,
"length": 3,
"unit": "steps",
"progress": None,
"desc": None,
}
],
status_updates = []
while not job.done():
status = job.status()
update = (
status.progress_data[0].index if status.progress_data else None,
status.progress_data[0].unit if status.progress_data else None,
)
if update != (None, None) and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
(1, "steps"),
(2, "steps"),
(3, "steps"),
(4, "steps"),
(5, "steps"),
(6, "steps"),
]
@pytest.mark.asyncio
@ -878,45 +778,30 @@ class TestProgressBar:
time.sleep(0.15)
if len(s) < 5:
gr.Warning("Too short!")
time.sleep(0.15)
return f"Hello, {s}!"
demo = gr.Interface(greet, "text", "text")
demo.queue().launch(prevent_thread_lock=True)
log_messages = []
async with httpx.AsyncClient() as client:
async with client.stream(
"GET",
f"http://localhost:{demo.server_port}/queue/join",
params={"fn_index": 0, "session_hash": "shdce"},
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
msg = json.loads(line[5:])
if msg["msg"] == "send_data":
event_id = msg["event_id"]
req = await client.post(
f"http://localhost:{demo.server_port}/queue/data",
json={
"event_id": event_id,
"data": ["abc"],
"fn_index": 0,
},
)
if not req.is_success:
raise ValueError(
f"Could not send payload to endpoint: {req.text}"
)
if msg["msg"] == "log":
log_messages.append([msg["log"], msg["level"]])
if msg["msg"] == "process_completed":
break
client = grc.Client(demo.local_url)
job = client.submit("Jon")
assert log_messages == [
["Letter a", "info"],
["Letter b", "info"],
["Letter c", "info"],
["Too short!", "warning"],
status_updates = []
while not job.done():
status = job.status()
update = status.log
if update is not None and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
("Letter J", "info"),
("Letter o", "info"),
("Letter n", "info"),
("Too short!", "warning"),
]
@ -926,11 +811,13 @@ async def test_info_isolation(async_handler: bool):
async def greet_async(name):
await asyncio.sleep(2)
gr.Info(f"Hello {name}")
await asyncio.sleep(1)
return name
def greet_sync(name):
time.sleep(2)
gr.Info(f"Hello {name}")
time.sleep(1)
return name
demo = gr.Interface(
@ -942,42 +829,24 @@ async def test_info_isolation(async_handler: bool):
demo.launch(prevent_thread_lock=True)
async def session_interaction(name, delay=0):
await asyncio.sleep(delay)
client = grc.Client(demo.local_url)
job = client.submit(name)
log_messages = []
async with httpx.AsyncClient() as client:
async with client.stream(
"GET",
f"http://localhost:{demo.server_port}/queue/join",
params={"fn_index": 0, "session_hash": name},
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
msg = json.loads(line[5:])
if msg["msg"] == "send_data":
event_id = msg["event_id"]
req = await client.post(
f"http://localhost:{demo.server_port}/queue/data",
json={
"event_id": event_id,
"data": [name],
"fn_index": 0,
},
)
if not req.is_success:
raise ValueError(
f"Could not send payload to endpoint: {req.text}"
)
if msg["msg"] == "log":
log_messages.append(msg["log"])
if msg["msg"] == "process_completed":
break
return log_messages
status_updates = []
while not job.done():
status = job.status()
update = status.log
if update is not None and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
return status_updates[-1][0] if status_updates else None
alice_logs, bob_logs = await asyncio.gather(
session_interaction("Alice"),
session_interaction("Bob", delay=1),
)
assert alice_logs == ["Hello Alice"]
assert bob_logs == ["Hello Bob"]
assert alice_logs == "Hello Alice"
assert bob_logs == "Hello Bob"

View File

@ -18,8 +18,6 @@ class TestQueueing:
name.submit(greet, name, output)
demo.launch(prevent_thread_lock=True)
with connect(demo) as client:
job = client.submit("x", fn_index=0)
assert job.result() == "Hello, x!"
@ -92,7 +90,7 @@ class TestQueueing:
@add_btn.click(inputs=[a, b], outputs=output)
def add(x, y):
time.sleep(2)
time.sleep(4)
return x + y
demo.queue(default_concurrency_limit=default_concurrency_limit)
@ -105,7 +103,7 @@ class TestQueueing:
add_job_2 = client.submit(1, 1, fn_index=0)
add_job_3 = client.submit(1, 1, fn_index=0)
time.sleep(1)
time.sleep(2)
add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
assert sorted([s.code.value for s in add_job_statuses]) == statuses
@ -161,12 +159,11 @@ class TestQueueing:
sub_job_1 = client.submit(1, 1, fn_index=1)
sub_job_2 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
sub_job_3 = client.submit(1, 1, fn_index=1)
mul_job_1 = client.submit(1, 1, fn_index=2)
div_job_1 = client.submit(1, 1, fn_index=3)
mul_job_2 = client.submit(1, 1, fn_index=2)
time.sleep(1)
time.sleep(2)
add_job_statuses = [
add_job_1.status(),