Python client properly handles heartbeat and log messages. Also handles responses longer than 65k (#6693)

* first commit

* newlines

* test

* Fix depends

* revert

* add changeset

* add changeset

* Lint

* queue full test

* Add code

* Update + fix

* add changeset

* Revert demo

* Typo in success

* Fix

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-12-13 17:47:06 -05:00 committed by GitHub
parent a3cf90e57b
commit 34f943101b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 197 additions and 70 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/client": patch
"gradio": patch
"gradio_client": patch
---
fix:Python client properly handles heartbeat and log messages. Also handles responses longer than 65k

View File

@ -203,8 +203,16 @@ export function api_factory(
} catch (e) {
return [{ error: BROKEN_CONNECTION_MSG }, 500];
}
const output: PostResponse = await response.json();
return [output, response.status];
let output: PostResponse;
let status: int;
try {
output = await response.json();
status = response.status;
} catch (e) {
output = { error: `Could not parse server response: ${e}` };
status = 500;
}
return [output, status];
}
async function upload_files(
@ -791,7 +799,17 @@ export function api_factory(
},
hf_token
).then(([response, status]) => {
if (status !== 200) {
if (status === 503) {
fire_event({
type: "status",
stage: "error",
message: QUEUE_FULL_MSG,
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
} else if (status !== 200) {
fire_event({
type: "status",
stage: "error",
@ -806,7 +824,6 @@ export function api_factory(
if (!stream_open) {
open_stream();
}
let callback = async function (_data: object): void {
const { type, status, data } = handle_message(
_data,

View File

@ -37,6 +37,8 @@ from gradio_client.utils import (
Communicator,
JobStatus,
Message,
QueueError,
ServerMessage,
Status,
StatusUpdate,
)
@ -169,7 +171,6 @@ class Client:
async def stream_messages(self) -> None:
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
buffer = ""
async with client.stream(
"GET",
self.sse_url,
@ -177,33 +178,31 @@ class Client:
headers=self.headers,
cookies=self.cookies,
) as response:
async for line in response.aiter_text():
buffer += line
while "\n\n" in buffer:
message, buffer = buffer.split("\n\n", 1)
if message.startswith("data:"):
resp = json.loads(message[5:])
if resp["msg"] == "heartbeat":
continue
elif resp["msg"] == "server_stopped":
for (
pending_messages
) in self.pending_messages_per_event.values():
pending_messages.append(resp)
return
event_id = resp["event_id"]
if event_id not in self.pending_messages_per_event:
self.pending_messages_per_event[event_id] = []
self.pending_messages_per_event[event_id].append(resp)
if resp["msg"] == "process_completed":
self.pending_event_ids.remove(event_id)
if len(self.pending_event_ids) == 0:
self.stream_open = False
return
elif message == "":
async for line in response.aiter_lines():
line = line.rstrip("\n")
if not len(line):
continue
if line.startswith("data:"):
resp = json.loads(line[5:])
if resp["msg"] == ServerMessage.heartbeat:
continue
else:
raise ValueError(f"Unexpected SSE line: '{message}'")
elif resp["msg"] == ServerMessage.server_stopped:
for (
pending_messages
) in self.pending_messages_per_event.values():
pending_messages.append(resp)
return
event_id = resp["event_id"]
if event_id not in self.pending_messages_per_event:
self.pending_messages_per_event[event_id] = []
self.pending_messages_per_event[event_id].append(resp)
if resp["msg"] == ServerMessage.process_completed:
self.pending_event_ids.remove(event_id)
if len(self.pending_event_ids) == 0:
self.stream_open = False
return
else:
raise ValueError(f"Unexpected SSE line: '{line}'")
except BaseException as e:
import traceback
@ -218,6 +217,8 @@ class Client:
headers=self.headers,
cookies=self.cookies,
)
if req.status_code == 503:
raise QueueError("Queue is full! Please try again.")
req.raise_for_status()
resp = req.json()
event_id = resp["event_id"]

View File

@ -102,6 +102,20 @@ class SpaceDuplicationError(Exception):
pass
class ServerMessage(str, Enum):
    """Message types sent by the server over the SSE / queue stream.

    Inherits from ``str`` so members compare equal to the raw ``msg`` string
    found in each JSON payload (e.g. ``resp["msg"] == ServerMessage.heartbeat``),
    letting callers match messages without calling ``.value``.
    """

    send_hash = "send_hash"
    queue_full = "queue_full"
    estimation = "estimation"
    send_data = "send_data"
    process_starts = "process_starts"
    process_generating = "process_generating"
    process_completed = "process_completed"
    log = "log"
    progress = "progress"
    heartbeat = "heartbeat"
    server_stopped = "server_stopped"
class Status(Enum):
"""Status codes presented to client users."""
@ -141,16 +155,17 @@ class Status(Enum):
def msg_to_status(msg: str) -> Status:
"""Map the raw message from the backend to the status code presented to users."""
return {
"send_hash": Status.JOINING_QUEUE,
"queue_full": Status.QUEUE_FULL,
"estimation": Status.IN_QUEUE,
"send_data": Status.SENDING_DATA,
"process_starts": Status.PROCESSING,
"process_generating": Status.ITERATING,
"process_completed": Status.FINISHED,
"progress": Status.PROGRESS,
"log": Status.LOG,
}[msg]
ServerMessage.send_hash: Status.JOINING_QUEUE,
ServerMessage.queue_full: Status.QUEUE_FULL,
ServerMessage.estimation: Status.IN_QUEUE,
ServerMessage.send_data: Status.SENDING_DATA,
ServerMessage.process_starts: Status.PROCESSING,
ServerMessage.process_generating: Status.ITERATING,
ServerMessage.process_completed: Status.FINISHED,
ServerMessage.progress: Status.PROGRESS,
ServerMessage.log: Status.LOG,
ServerMessage.server_stopped: Status.FINISHED,
}[msg] # type: ignore
@dataclass
@ -436,9 +451,14 @@ async def stream_sse_v0(
headers=headers,
cookies=cookies,
) as response:
async for line in response.aiter_text():
async for line in response.aiter_lines():
line = line.rstrip("\n")
if len(line) == 0:
continue
if line.startswith("data:"):
resp = json.loads(line[5:])
if resp["msg"] in [ServerMessage.log, ServerMessage.heartbeat]:
continue
with helper.lock:
has_progress = "progress_data" in resp
status_update = StatusUpdate(
@ -502,7 +522,7 @@ async def stream_sse_v1(
with helper.lock:
log_message = None
if msg["msg"] == "log":
if msg["msg"] == ServerMessage.log:
log = msg.get("log")
level = msg.get("level")
if log and level:
@ -527,13 +547,10 @@ async def stream_sse_v1(
result = [e]
helper.job.outputs.append(result)
helper.job.latest_status = status_update
if msg["msg"] == "queue_full":
raise QueueError("Queue is full! Please try again.")
elif msg["msg"] == "process_completed":
if msg["msg"] == ServerMessage.process_completed:
del pending_messages_per_event[event_id]
return msg["output"]
elif msg["msg"] == "server_stopped":
elif msg["msg"] == ServerMessage.server_stopped:
raise ValueError("Server stopped.")
except asyncio.CancelledError:

View File

@ -381,3 +381,18 @@ def gradio_temp_dir(monkeypatch, tmp_path):
"""
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
return tmp_path
@pytest.fixture
def long_response_with_info():
    """Interface whose endpoint is slow, emits gr.Info messages, and returns
    a payload much larger than 64KB (exercises long-running requests and
    large SSE responses)."""

    def long_response(x):
        gr.Info("Beginning long response")
        # Sleep long enough that the client must sit through server
        # heartbeat/log messages before the result arrives
        # (NOTE(review): assumed from commit intent — confirm heartbeat interval).
        time.sleep(17)
        gr.Info("Done!")
        # 4 chars * 90000 = 360,000 chars — well over the 65k boundary.
        return "\ta\nb" * 90000

    return gr.Interface(
        long_response,
        None,
        gr.Textbox(label="Output"),
    )

View File

@ -4,7 +4,7 @@ import pathlib
import tempfile
import time
import uuid
from concurrent.futures import CancelledError, TimeoutError
from concurrent.futures import CancelledError, TimeoutError, wait
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
@ -21,7 +21,13 @@ from huggingface_hub.utils import RepositoryNotFoundError
from gradio_client import Client
from gradio_client.client import DEFAULT_TEMP_DIR
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
from gradio_client.utils import (
Communicator,
ProgressUnit,
QueueError,
Status,
StatusUpdate,
)
HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
@ -488,6 +494,40 @@ class TestClientPredictions:
assert demo.predict(api_name="/close") == 4
assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5)
def test_long_response_time_with_gr_info_and_big_payload(
    self, long_response_with_info
):
    """A slow endpoint that emits gr.Info and returns a >64KB payload must
    deliver the complete, untruncated result to the client."""
    with connect(long_response_with_info) as demo:
        # Expected value mirrors the fixture's return: "\ta\nb" * 90000.
        assert demo.predict(api_name="/predict") == "\ta\nb" * 90000
def test_queue_full_raises_error(self):
    """Submitting more concurrent jobs than the queue's max_size should
    surface a QueueError on the client side."""
    demo = gr.Interface(lambda s: f"Hello {s}", "textbox", "textbox").queue(
        max_size=1
    )
    with connect(demo) as client:
        with pytest.raises(QueueError):
            # With max_size=1, at least one of these three submissions
            # should hit a full queue; waiting on / resolving the jobs
            # re-raises the QueueError in this thread.
            job1 = client.submit("Freddy", api_name="/predict")
            job2 = client.submit("Abubakar", api_name="/predict")
            job3 = client.submit("Pete", api_name="/predict")
            wait([job1, job2, job3])
            job1.result()
            job2.result()
            job3.result()
def test_json_parse_error(self):
    """Output containing non-ASCII characters, apostrophes, and an embedded
    newline must round-trip through the server without the client failing
    to parse the response."""
    data = (
        "Bonjour Olivier, tu as l'air bien r\u00e9veill\u00e9 ce matin. Tu veux que je te pr\u00e9pare tes petits-d\u00e9j.\n",
        None,
    )

    def return_bad():
        return data

    demo = gr.Interface(return_bad, None, ["text", "text"])
    with connect(demo) as client:
        pred = client.predict(api_name="/predict")
        # Only the first output is checked; the second is None by design.
        assert pred[0] == data[0]
class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")

View File

@ -11,6 +11,7 @@ from queue import Queue as ThreadQueue
from typing import TYPE_CHECKING
import fastapi
from gradio_client.utils import ServerMessage
from typing_extensions import Literal
from gradio import route_utils, routes
@ -144,15 +145,16 @@ class Queue:
async def push(
self, body: PredictBody, request: fastapi.Request, username: str | None
):
) -> tuple[bool, str]:
if body.session_hash is None:
raise ValueError("No session hash provided.")
return False, "No session hash provided."
if body.fn_index is None:
raise ValueError("No function index provided.")
return False, "No function index provided."
queue_len = len(self.event_queue)
if self.max_size is not None and queue_len >= self.max_size:
raise ValueError(
f"Queue is full. Max size is {self.max_size} and current size is {queue_len}."
return (
False,
f"Queue is full. Max size is {self.max_size} and size is {queue_len}.",
)
event = Event(body.session_hash, body.fn_index, request, username)
@ -168,7 +170,7 @@ class Queue:
estimation = self.get_estimation()
await self.send_estimation(event, estimation, queue_len)
return event._id
return True, event._id
def _cancel_asyncio_tasks(self):
for task in self._asyncio_tasks:
@ -286,7 +288,9 @@ class Queue:
for event in events:
if event.progress_pending and event.progress:
event.progress_pending = False
self.send_message(event, "progress", event.progress.model_dump())
self.send_message(
event, ServerMessage.progress, event.progress.model_dump()
)
await asyncio.sleep(self.progress_update_sleep_when_free)
@ -330,7 +334,7 @@ class Queue:
log=log,
level=level,
)
self.send_message(event, "log", log_message.model_dump())
self.send_message(event, ServerMessage.log, log_message.model_dump())
async def clean_events(
self, *, session_hash: str | None = None, event_id: str | None = None
@ -390,7 +394,7 @@ class Queue:
if None not in self.active_jobs:
# Add estimated amount of time for a thread to get empty
estimation.rank_eta += self.avg_concurrent_process_time
self.send_message(event, "estimation", estimation.model_dump())
self.send_message(event, ServerMessage.estimation, estimation.model_dump())
return estimation
def update_estimation(self, duration: float) -> None:
@ -484,7 +488,7 @@ class Queue:
awake_events: list[Event] = []
try:
for event in events:
self.send_message(event, "process_starts")
self.send_message(event, ServerMessage.process_starts)
awake_events.append(event)
if not awake_events:
return
@ -499,7 +503,7 @@ class Queue:
for event in awake_events:
self.send_message(
event,
"process_completed",
ServerMessage.process_completed,
{
"output": {
"error": None
@ -518,7 +522,7 @@ class Queue:
for event in awake_events:
self.send_message(
event,
"process_generating",
ServerMessage.process_generating,
{
"output": old_response,
"success": old_response is not None,
@ -540,7 +544,7 @@ class Queue:
relevant_response = old_response or old_err
self.send_message(
event,
"process_completed",
ServerMessage.process_completed,
{
"output": {"error": str(relevant_response)}
if isinstance(relevant_response, Exception)
@ -556,7 +560,7 @@ class Queue:
output["data"] = list(zip(*response.get("data")))[e]
self.send_message(
event,
"process_completed",
ServerMessage.process_completed,
{
"output": output,
"success": response is not None,

View File

@ -42,6 +42,7 @@ from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
from gradio_client.utils import ServerMessage
from jinja2.exceptions import TemplateNotFound
from multipart.multipart import parse_options_header
from starlette.background import BackgroundTask
@ -614,22 +615,25 @@ class App(FastAPI):
except EmptyQueue:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
message = {"msg": "heartbeat"}
message = {"msg": ServerMessage.heartbeat}
# Need to reset last_heartbeat with perf_counter
# otherwise only a single heartbeat msg will be sent
# and then the stream will retry leading to infinite queue 😬
last_heartbeat = time.perf_counter()
if blocks._queue.stopped:
message = {"msg": "server_stopped", "success": False}
message = {
"msg": ServerMessage.server_stopped,
"success": False,
}
if message:
yield f"data: {json.dumps(message)}\n\n"
if message["msg"] == "process_completed":
if message["msg"] == ServerMessage.process_completed:
blocks._queue.pending_event_ids_session[
session_hash
].remove(message["event_id"])
if message["msg"] == "server_stopped" or (
message["msg"] == "process_completed"
if message["msg"] == ServerMessage.server_stopped or (
message["msg"] == ServerMessage.process_completed
and (
len(
blocks._queue.pending_event_ids_session[
@ -659,7 +663,14 @@ class App(FastAPI):
if blocks._queue.server_app is None:
blocks._queue.set_server_app(app)
event_id = await blocks._queue.push(body, request, username)
success, event_id = await blocks._queue.push(body, request, username)
if not success:
status_code = (
status.HTTP_503_SERVICE_UNAVAILABLE
if "Queue is full." in event_id
else status.HTTP_400_BAD_REQUEST
)
raise HTTPException(status_code=status_code, detail=event_id)
return {"event_id": event_id}
@app.post("/component_server", dependencies=[Depends(login_check)])

View File

@ -1,4 +1,5 @@
import time
from concurrent.futures import wait
import gradio_client as grc
import pytest
@ -197,6 +198,20 @@ class TestQueueing:
"PROCESSING",
"PROCESSING",
]
wait(
[
add_job_1,
add_job_2,
add_job_3,
sub_job_1,
sub_job_2,
sub_job_3,
sub_job_3,
mul_job_1,
div_job_1,
mul_job_2,
]
)
def test_every_does_not_block_queue(self):
with gr.Blocks() as demo: