mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Python client properly handles hearbeat and log messages. Also handles responses longer than 65k (#6693)
* first commit * newlines * test * Fix depends * revert * add changeset * add changeset * Lint * queue full test * Add code * Update + fix * add changeset * Revert demo * Typo in success * Fix --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
a3cf90e57b
commit
34f943101b
7
.changeset/yummy-roses-decide.md
Normal file
7
.changeset/yummy-roses-decide.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/client": patch
|
||||
"gradio": patch
|
||||
"gradio_client": patch
|
||||
---
|
||||
|
||||
fix:Python client properly handles hearbeat and log messages. Also handles responses longer than 65k
|
@ -203,8 +203,16 @@ export function api_factory(
|
||||
} catch (e) {
|
||||
return [{ error: BROKEN_CONNECTION_MSG }, 500];
|
||||
}
|
||||
const output: PostResponse = await response.json();
|
||||
return [output, response.status];
|
||||
let output: PostResponse;
|
||||
let status: int;
|
||||
try {
|
||||
output = await response.json();
|
||||
status = response.status;
|
||||
} catch (e) {
|
||||
output = { error: `Could not parse server response: ${e}` };
|
||||
status = 500;
|
||||
}
|
||||
return [output, status];
|
||||
}
|
||||
|
||||
async function upload_files(
|
||||
@ -791,7 +799,17 @@ export function api_factory(
|
||||
},
|
||||
hf_token
|
||||
).then(([response, status]) => {
|
||||
if (status !== 200) {
|
||||
if (status === 503) {
|
||||
fire_event({
|
||||
type: "status",
|
||||
stage: "error",
|
||||
message: QUEUE_FULL_MSG,
|
||||
queue: true,
|
||||
endpoint: _endpoint,
|
||||
fn_index,
|
||||
time: new Date()
|
||||
});
|
||||
} else if (status !== 200) {
|
||||
fire_event({
|
||||
type: "status",
|
||||
stage: "error",
|
||||
@ -806,7 +824,6 @@ export function api_factory(
|
||||
if (!stream_open) {
|
||||
open_stream();
|
||||
}
|
||||
|
||||
let callback = async function (_data: object): void {
|
||||
const { type, status, data } = handle_message(
|
||||
_data,
|
||||
|
@ -37,6 +37,8 @@ from gradio_client.utils import (
|
||||
Communicator,
|
||||
JobStatus,
|
||||
Message,
|
||||
QueueError,
|
||||
ServerMessage,
|
||||
Status,
|
||||
StatusUpdate,
|
||||
)
|
||||
@ -169,7 +171,6 @@ class Client:
|
||||
async def stream_messages(self) -> None:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
|
||||
buffer = ""
|
||||
async with client.stream(
|
||||
"GET",
|
||||
self.sse_url,
|
||||
@ -177,33 +178,31 @@ class Client:
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
buffer += line
|
||||
while "\n\n" in buffer:
|
||||
message, buffer = buffer.split("\n\n", 1)
|
||||
if message.startswith("data:"):
|
||||
resp = json.loads(message[5:])
|
||||
if resp["msg"] == "heartbeat":
|
||||
continue
|
||||
elif resp["msg"] == "server_stopped":
|
||||
for (
|
||||
pending_messages
|
||||
) in self.pending_messages_per_event.values():
|
||||
pending_messages.append(resp)
|
||||
return
|
||||
event_id = resp["event_id"]
|
||||
if event_id not in self.pending_messages_per_event:
|
||||
self.pending_messages_per_event[event_id] = []
|
||||
self.pending_messages_per_event[event_id].append(resp)
|
||||
if resp["msg"] == "process_completed":
|
||||
self.pending_event_ids.remove(event_id)
|
||||
if len(self.pending_event_ids) == 0:
|
||||
self.stream_open = False
|
||||
return
|
||||
elif message == "":
|
||||
async for line in response.aiter_lines():
|
||||
line = line.rstrip("\n")
|
||||
if not len(line):
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
resp = json.loads(line[5:])
|
||||
if resp["msg"] == ServerMessage.heartbeat:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unexpected SSE line: '{message}'")
|
||||
elif resp["msg"] == ServerMessage.server_stopped:
|
||||
for (
|
||||
pending_messages
|
||||
) in self.pending_messages_per_event.values():
|
||||
pending_messages.append(resp)
|
||||
return
|
||||
event_id = resp["event_id"]
|
||||
if event_id not in self.pending_messages_per_event:
|
||||
self.pending_messages_per_event[event_id] = []
|
||||
self.pending_messages_per_event[event_id].append(resp)
|
||||
if resp["msg"] == ServerMessage.process_completed:
|
||||
self.pending_event_ids.remove(event_id)
|
||||
if len(self.pending_event_ids) == 0:
|
||||
self.stream_open = False
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Unexpected SSE line: '{line}'")
|
||||
except BaseException as e:
|
||||
import traceback
|
||||
|
||||
@ -218,6 +217,8 @@ class Client:
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
)
|
||||
if req.status_code == 503:
|
||||
raise QueueError("Queue is full! Please try again.")
|
||||
req.raise_for_status()
|
||||
resp = req.json()
|
||||
event_id = resp["event_id"]
|
||||
|
@ -102,6 +102,20 @@ class SpaceDuplicationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ServerMessage(str, Enum):
|
||||
send_hash = "send_hash"
|
||||
queue_full = "queue_full"
|
||||
estimation = "estimation"
|
||||
send_data = "send_data"
|
||||
process_starts = "process_starts"
|
||||
process_generating = "process_generating"
|
||||
process_completed = "process_completed"
|
||||
log = "log"
|
||||
progress = "progress"
|
||||
heartbeat = "heartbeat"
|
||||
server_stopped = "server_stopped"
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
"""Status codes presented to client users."""
|
||||
|
||||
@ -141,16 +155,17 @@ class Status(Enum):
|
||||
def msg_to_status(msg: str) -> Status:
|
||||
"""Map the raw message from the backend to the status code presented to users."""
|
||||
return {
|
||||
"send_hash": Status.JOINING_QUEUE,
|
||||
"queue_full": Status.QUEUE_FULL,
|
||||
"estimation": Status.IN_QUEUE,
|
||||
"send_data": Status.SENDING_DATA,
|
||||
"process_starts": Status.PROCESSING,
|
||||
"process_generating": Status.ITERATING,
|
||||
"process_completed": Status.FINISHED,
|
||||
"progress": Status.PROGRESS,
|
||||
"log": Status.LOG,
|
||||
}[msg]
|
||||
ServerMessage.send_hash: Status.JOINING_QUEUE,
|
||||
ServerMessage.queue_full: Status.QUEUE_FULL,
|
||||
ServerMessage.estimation: Status.IN_QUEUE,
|
||||
ServerMessage.send_data: Status.SENDING_DATA,
|
||||
ServerMessage.process_starts: Status.PROCESSING,
|
||||
ServerMessage.process_generating: Status.ITERATING,
|
||||
ServerMessage.process_completed: Status.FINISHED,
|
||||
ServerMessage.progress: Status.PROGRESS,
|
||||
ServerMessage.log: Status.LOG,
|
||||
ServerMessage.server_stopped: Status.FINISHED,
|
||||
}[msg] # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -436,9 +451,14 @@ async def stream_sse_v0(
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
) as response:
|
||||
async for line in response.aiter_text():
|
||||
async for line in response.aiter_lines():
|
||||
line = line.rstrip("\n")
|
||||
if len(line) == 0:
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
resp = json.loads(line[5:])
|
||||
if resp["msg"] in [ServerMessage.log, ServerMessage.heartbeat]:
|
||||
continue
|
||||
with helper.lock:
|
||||
has_progress = "progress_data" in resp
|
||||
status_update = StatusUpdate(
|
||||
@ -502,7 +522,7 @@ async def stream_sse_v1(
|
||||
|
||||
with helper.lock:
|
||||
log_message = None
|
||||
if msg["msg"] == "log":
|
||||
if msg["msg"] == ServerMessage.log:
|
||||
log = msg.get("log")
|
||||
level = msg.get("level")
|
||||
if log and level:
|
||||
@ -527,13 +547,10 @@ async def stream_sse_v1(
|
||||
result = [e]
|
||||
helper.job.outputs.append(result)
|
||||
helper.job.latest_status = status_update
|
||||
|
||||
if msg["msg"] == "queue_full":
|
||||
raise QueueError("Queue is full! Please try again.")
|
||||
elif msg["msg"] == "process_completed":
|
||||
if msg["msg"] == ServerMessage.process_completed:
|
||||
del pending_messages_per_event[event_id]
|
||||
return msg["output"]
|
||||
elif msg["msg"] == "server_stopped":
|
||||
elif msg["msg"] == ServerMessage.server_stopped:
|
||||
raise ValueError("Server stopped.")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
|
@ -381,3 +381,18 @@ def gradio_temp_dir(monkeypatch, tmp_path):
|
||||
"""
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_response_with_info():
|
||||
def long_response(x):
|
||||
gr.Info("Beginning long response")
|
||||
time.sleep(17)
|
||||
gr.Info("Done!")
|
||||
return "\ta\nb" * 90000
|
||||
|
||||
return gr.Interface(
|
||||
long_response,
|
||||
None,
|
||||
gr.Textbox(label="Output"),
|
||||
)
|
||||
|
@ -4,7 +4,7 @@ import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import CancelledError, TimeoutError
|
||||
from concurrent.futures import CancelledError, TimeoutError, wait
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
@ -21,7 +21,13 @@ from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from gradio_client import Client
|
||||
from gradio_client.client import DEFAULT_TEMP_DIR
|
||||
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
|
||||
from gradio_client.utils import (
|
||||
Communicator,
|
||||
ProgressUnit,
|
||||
QueueError,
|
||||
Status,
|
||||
StatusUpdate,
|
||||
)
|
||||
|
||||
HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
|
||||
|
||||
@ -488,6 +494,40 @@ class TestClientPredictions:
|
||||
assert demo.predict(api_name="/close") == 4
|
||||
assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5)
|
||||
|
||||
def test_long_response_time_with_gr_info_and_big_payload(
|
||||
self, long_response_with_info
|
||||
):
|
||||
with connect(long_response_with_info) as demo:
|
||||
assert demo.predict(api_name="/predict") == "\ta\nb" * 90000
|
||||
|
||||
def test_queue_full_raises_error(self):
|
||||
demo = gr.Interface(lambda s: f"Hello {s}", "textbox", "textbox").queue(
|
||||
max_size=1
|
||||
)
|
||||
with connect(demo) as client:
|
||||
with pytest.raises(QueueError):
|
||||
job1 = client.submit("Freddy", api_name="/predict")
|
||||
job2 = client.submit("Abubakar", api_name="/predict")
|
||||
job3 = client.submit("Pete", api_name="/predict")
|
||||
wait([job1, job2, job3])
|
||||
job1.result()
|
||||
job2.result()
|
||||
job3.result()
|
||||
|
||||
def test_json_parse_error(self):
|
||||
data = (
|
||||
"Bonjour Olivier, tu as l'air bien r\u00e9veill\u00e9 ce matin. Tu veux que je te pr\u00e9pare tes petits-d\u00e9j.\n",
|
||||
None,
|
||||
)
|
||||
|
||||
def return_bad():
|
||||
return data
|
||||
|
||||
demo = gr.Interface(return_bad, None, ["text", "text"])
|
||||
with connect(demo) as client:
|
||||
pred = client.predict(api_name="/predict")
|
||||
assert pred[0] == data[0]
|
||||
|
||||
|
||||
class TestStatusUpdates:
|
||||
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
|
||||
|
@ -11,6 +11,7 @@ from queue import Queue as ThreadQueue
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import fastapi
|
||||
from gradio_client.utils import ServerMessage
|
||||
from typing_extensions import Literal
|
||||
|
||||
from gradio import route_utils, routes
|
||||
@ -144,15 +145,16 @@ class Queue:
|
||||
|
||||
async def push(
|
||||
self, body: PredictBody, request: fastapi.Request, username: str | None
|
||||
):
|
||||
) -> tuple[bool, str]:
|
||||
if body.session_hash is None:
|
||||
raise ValueError("No session hash provided.")
|
||||
return False, "No session hash provided."
|
||||
if body.fn_index is None:
|
||||
raise ValueError("No function index provided.")
|
||||
return False, "No function index provided."
|
||||
queue_len = len(self.event_queue)
|
||||
if self.max_size is not None and queue_len >= self.max_size:
|
||||
raise ValueError(
|
||||
f"Queue is full. Max size is {self.max_size} and current size is {queue_len}."
|
||||
return (
|
||||
False,
|
||||
f"Queue is full. Max size is {self.max_size} and size is {queue_len}.",
|
||||
)
|
||||
|
||||
event = Event(body.session_hash, body.fn_index, request, username)
|
||||
@ -168,7 +170,7 @@ class Queue:
|
||||
estimation = self.get_estimation()
|
||||
await self.send_estimation(event, estimation, queue_len)
|
||||
|
||||
return event._id
|
||||
return True, event._id
|
||||
|
||||
def _cancel_asyncio_tasks(self):
|
||||
for task in self._asyncio_tasks:
|
||||
@ -286,7 +288,9 @@ class Queue:
|
||||
for event in events:
|
||||
if event.progress_pending and event.progress:
|
||||
event.progress_pending = False
|
||||
self.send_message(event, "progress", event.progress.model_dump())
|
||||
self.send_message(
|
||||
event, ServerMessage.progress, event.progress.model_dump()
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.progress_update_sleep_when_free)
|
||||
|
||||
@ -330,7 +334,7 @@ class Queue:
|
||||
log=log,
|
||||
level=level,
|
||||
)
|
||||
self.send_message(event, "log", log_message.model_dump())
|
||||
self.send_message(event, ServerMessage.log, log_message.model_dump())
|
||||
|
||||
async def clean_events(
|
||||
self, *, session_hash: str | None = None, event_id: str | None = None
|
||||
@ -390,7 +394,7 @@ class Queue:
|
||||
if None not in self.active_jobs:
|
||||
# Add estimated amount of time for a thread to get empty
|
||||
estimation.rank_eta += self.avg_concurrent_process_time
|
||||
self.send_message(event, "estimation", estimation.model_dump())
|
||||
self.send_message(event, ServerMessage.estimation, estimation.model_dump())
|
||||
return estimation
|
||||
|
||||
def update_estimation(self, duration: float) -> None:
|
||||
@ -484,7 +488,7 @@ class Queue:
|
||||
awake_events: list[Event] = []
|
||||
try:
|
||||
for event in events:
|
||||
self.send_message(event, "process_starts")
|
||||
self.send_message(event, ServerMessage.process_starts)
|
||||
awake_events.append(event)
|
||||
if not awake_events:
|
||||
return
|
||||
@ -499,7 +503,7 @@ class Queue:
|
||||
for event in awake_events:
|
||||
self.send_message(
|
||||
event,
|
||||
"process_completed",
|
||||
ServerMessage.process_completed,
|
||||
{
|
||||
"output": {
|
||||
"error": None
|
||||
@ -518,7 +522,7 @@ class Queue:
|
||||
for event in awake_events:
|
||||
self.send_message(
|
||||
event,
|
||||
"process_generating",
|
||||
ServerMessage.process_generating,
|
||||
{
|
||||
"output": old_response,
|
||||
"success": old_response is not None,
|
||||
@ -540,7 +544,7 @@ class Queue:
|
||||
relevant_response = old_response or old_err
|
||||
self.send_message(
|
||||
event,
|
||||
"process_completed",
|
||||
ServerMessage.process_completed,
|
||||
{
|
||||
"output": {"error": str(relevant_response)}
|
||||
if isinstance(relevant_response, Exception)
|
||||
@ -556,7 +560,7 @@ class Queue:
|
||||
output["data"] = list(zip(*response.get("data")))[e]
|
||||
self.send_message(
|
||||
event,
|
||||
"process_completed",
|
||||
ServerMessage.process_completed,
|
||||
{
|
||||
"output": output,
|
||||
"success": response is not None,
|
||||
|
@ -42,6 +42,7 @@ from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
from gradio_client.utils import ServerMessage
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
from multipart.multipart import parse_options_header
|
||||
from starlette.background import BackgroundTask
|
||||
@ -614,22 +615,25 @@ class App(FastAPI):
|
||||
except EmptyQueue:
|
||||
await asyncio.sleep(check_rate)
|
||||
if time.perf_counter() - last_heartbeat > heartbeat_rate:
|
||||
message = {"msg": "heartbeat"}
|
||||
message = {"msg": ServerMessage.heartbeat}
|
||||
# Need to reset last_heartbeat with perf_counter
|
||||
# otherwise only a single hearbeat msg will be sent
|
||||
# and then the stream will retry leading to infinite queue 😬
|
||||
last_heartbeat = time.perf_counter()
|
||||
|
||||
if blocks._queue.stopped:
|
||||
message = {"msg": "server_stopped", "success": False}
|
||||
message = {
|
||||
"msg": ServerMessage.server_stopped,
|
||||
"success": False,
|
||||
}
|
||||
if message:
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
if message["msg"] == "process_completed":
|
||||
if message["msg"] == ServerMessage.process_completed:
|
||||
blocks._queue.pending_event_ids_session[
|
||||
session_hash
|
||||
].remove(message["event_id"])
|
||||
if message["msg"] == "server_stopped" or (
|
||||
message["msg"] == "process_completed"
|
||||
if message["msg"] == ServerMessage.server_stopped or (
|
||||
message["msg"] == ServerMessage.process_completed
|
||||
and (
|
||||
len(
|
||||
blocks._queue.pending_event_ids_session[
|
||||
@ -659,7 +663,14 @@ class App(FastAPI):
|
||||
if blocks._queue.server_app is None:
|
||||
blocks._queue.set_server_app(app)
|
||||
|
||||
event_id = await blocks._queue.push(body, request, username)
|
||||
success, event_id = await blocks._queue.push(body, request, username)
|
||||
if not success:
|
||||
status_code = (
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
if "Queue is full." in event_id
|
||||
else status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=event_id)
|
||||
return {"event_id": event_id}
|
||||
|
||||
@app.post("/component_server", dependencies=[Depends(login_check)])
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
from concurrent.futures import wait
|
||||
|
||||
import gradio_client as grc
|
||||
import pytest
|
||||
@ -197,6 +198,20 @@ class TestQueueing:
|
||||
"PROCESSING",
|
||||
"PROCESSING",
|
||||
]
|
||||
wait(
|
||||
[
|
||||
add_job_1,
|
||||
add_job_2,
|
||||
add_job_3,
|
||||
sub_job_1,
|
||||
sub_job_2,
|
||||
sub_job_3,
|
||||
sub_job_3,
|
||||
mul_job_1,
|
||||
div_job_1,
|
||||
mul_job_2,
|
||||
]
|
||||
)
|
||||
|
||||
def test_every_does_not_block_queue(self):
|
||||
with gr.Blocks() as demo:
|
||||
|
Loading…
x
Reference in New Issue
Block a user