Support simplified event api for event only updates (#7407)

* changes

* add changeset

* changes

* changes

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
aliabid94 2024-03-05 14:06:53 +05:00 committed by GitHub
parent a57e34ef87
commit 375bfd28d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 410 additions and 130 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---
feat:Fix server_messages.py to use the patched BaseModel class for Wasm env

View File

@ -115,6 +115,7 @@ class ServerMessage(str, Enum):
progress = "progress"
heartbeat = "heartbeat"
server_stopped = "server_stopped"
unexpected_error = "unexpected_error"
class Status(Enum):

View File

@ -1529,6 +1529,7 @@ Received outputs:
session_hash: str | None,
run: int | None,
final: bool,
simple_format: bool = False,
) -> list:
if session_hash is None or run is None:
return data
@ -1547,7 +1548,8 @@ Received outputs:
else:
prev_chunk = last_diffs[i]
last_diffs[i] = data[i]
data[i] = utils.diff(prev_chunk, data[i])
if not simple_format:
data[i] = utils.diff(prev_chunk, data[i])
if final:
del self.pending_diff_streams[session_hash][run]
@ -1574,6 +1576,7 @@ Received outputs:
event_id: str | None = None,
event_data: EventData | None = None,
in_event_listener: bool = True,
simple_format: bool = False,
explicit_call: bool = False,
) -> dict[str, Any]:
"""
@ -1684,6 +1687,7 @@ Received outputs:
session_hash=session_hash,
run=run,
final=not is_generating,
simple_format=simple_format,
)
block_fn.total_runtime += result["duration"]

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import tempfile
from pathlib import Path
from typing import Callable, Literal
from gradio_client.documentation import document

View File

@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
from fastapi import Request
from gradio_client.utils import traverse
from typing_extensions import Literal
from . import wasm_utils
@ -75,6 +74,11 @@ else:
RootModel.model_json_schema = RootModel.schema # type: ignore
class SimplePredictBody(BaseModel):
    # Request body for the simplified /call/{api_name} API: only the input
    # payload and an optional session hash (no fn_index/trigger/event fields —
    # those are inferred server-side).
    data: List[Any]
    session_hash: Optional[str] = None
class PredictBody(BaseModel):
model_config = {"arbitrary_types_allowed": True}
@ -84,6 +88,7 @@ class PredictBody(BaseModel):
event_data: Optional[Any] = None
fn_index: Optional[int] = None
trigger_id: Optional[int] = None
simple_format: bool = False
batched: Optional[
bool
] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
@ -110,29 +115,6 @@ class InterfaceTypes(Enum):
UNIFIED = auto()
class Estimation(BaseModel):
rank: Optional[int] = None
queue_size: int
rank_eta: Optional[float] = None
class ProgressUnit(BaseModel):
index: Optional[int] = None
length: Optional[int] = None
unit: Optional[str] = None
progress: Optional[float] = None
desc: Optional[str] = None
class Progress(BaseModel):
progress_data: List[ProgressUnit] = []
class LogMessage(BaseModel):
log: str
level: Literal["info", "warning"]
class GradioBaseModel(ABC):
def copy_to_dir(self, dir: str | pathlib.Path) -> GradioDataModel:
if not isinstance(self, (BaseModel, RootModel)):

View File

@ -13,19 +13,24 @@ from queue import Queue as ThreadQueue
from typing import TYPE_CHECKING
import fastapi
from gradio_client.utils import ServerMessage
from typing_extensions import Literal
from gradio import route_utils, routes
from gradio.data_classes import (
Estimation,
LogMessage,
PredictBody,
Progress,
ProgressUnit,
)
from gradio.exceptions import Error
from gradio.helpers import TrackedIterable
from gradio.server_messages import (
EstimationMessage,
EventMessage,
LogMessage,
ProcessCompletedMessage,
ProcessGeneratingMessage,
ProcessStartsMessage,
ProgressMessage,
ProgressUnit,
)
from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name
if TYPE_CHECKING:
@ -35,20 +40,20 @@ if TYPE_CHECKING:
class Event:
def __init__(
self,
session_hash: str,
session_hash: str | None,
fn_index: int,
request: fastapi.Request,
username: str | None,
concurrency_id: str,
):
self.session_hash = session_hash
self._id = uuid.uuid4().hex
self.session_hash: str = session_hash or self._id
self.fn_index = fn_index
self.request = request
self.username = username
self.concurrency_id = concurrency_id
self._id = uuid.uuid4().hex
self.data: PredictBody | None = None
self.progress: Progress | None = None
self.progress: ProgressMessage | None = None
self.progress_pending: bool = False
self.alive = True
@ -84,7 +89,9 @@ class Queue:
block_fns: list[BlockFunction],
default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
):
self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
self.pending_messages_per_session: LRUCache[
str, ThreadQueue[EventMessage]
] = LRUCache(2000)
self.pending_event_ids_session: dict[str, set[str]] = {}
self.pending_message_lock = safe_get_lock()
self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
@ -150,14 +157,13 @@ class Queue:
def send_message(
self,
event: Event,
message_type: str,
data: dict | None = None,
event_message: EventMessage,
):
if not event.alive:
return
data = {} if data is None else data
event_message.event_id = event._id
messages = self.pending_messages_per_session[event.session_hash]
messages.put_nowait({"msg": message_type, "event_id": event._id, **data})
messages.put_nowait(event_message)
def _resolve_concurrency_limit(
self, default_concurrency_limit: int | None | Literal["not_set"]
@ -190,8 +196,6 @@ class Queue:
async def push(
self, body: PredictBody, request: fastapi.Request, username: str | None
) -> tuple[bool, str]:
if body.session_hash is None:
return False, "No session hash provided."
if body.fn_index is None:
return False, "No function index provided."
if self.max_size is not None and len(self) >= self.max_size:
@ -208,6 +212,8 @@ class Queue:
self.block_fns[body.fn_index].concurrency_id,
)
event.data = body
if body.session_hash is None:
body.session_hash = event.session_hash
async with self.pending_message_lock:
if body.session_hash not in self.pending_messages_per_session:
self.pending_messages_per_session[body.session_hash] = ThreadQueue()
@ -322,9 +328,7 @@ class Queue:
for event in events:
if event.progress_pending and event.progress:
event.progress_pending = False
self.send_message(
event, ServerMessage.progress, event.progress.model_dump()
)
self.send_message(event, event.progress)
await asyncio.sleep(self.progress_update_sleep_when_free)
@ -350,7 +354,7 @@ class Queue:
desc=iterable.desc,
)
progress_data.append(progress_unit)
evt.progress = Progress(progress_data=progress_data)
evt.progress = ProgressMessage(progress_data=progress_data)
evt.progress_pending = True
def log_message(
@ -368,7 +372,7 @@ class Queue:
log=log,
level=level,
)
self.send_message(event, ServerMessage.log, log_message.model_dump())
self.send_message(event, log_message)
async def clean_events(
self, *, session_hash: str | None = None, event_id: str | None = None
@ -441,10 +445,9 @@ class Queue:
if after is None or rank >= after:
self.send_message(
event,
ServerMessage.estimation,
Estimation(
EstimationMessage(
rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
).model_dump(),
),
)
if event_queue.concurrency_limit is None:
wait_so_far = 0
@ -453,12 +456,12 @@ class Queue:
else:
wait_so_far = None
def get_status(self) -> Estimation:
return Estimation(
def get_status(self) -> EstimationMessage:
return EstimationMessage(
queue_size=len(self),
)
async def call_prediction(self, events: list[Event], batch: bool):
async def call_prediction(self, events: list[Event], batch: bool) -> dict:
body = events[0].data
if body is None:
raise ValueError("No event data")
@ -517,6 +520,8 @@ class Queue:
) # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264
# Also, decode the JSON string to a Python object, emulating the HTTP client behavior e.g. the `json()` method of `httpx`.
response_json = json.loads(http_response.body.decode())
if not isinstance(response_json, dict):
raise ValueError("Unexpected object.")
return response_json
@ -530,12 +535,11 @@ class Queue:
if event.alive:
self.send_message(
event,
ServerMessage.process_starts,
{
"eta": self.process_time_per_fn_index[fn_index].avg_time
ProcessStartsMessage(
eta=self.process_time_per_fn_index[fn_index].avg_time
if fn_index in self.process_time_per_fn_index
else None
},
),
)
awake_events.append(event)
if not awake_events:
@ -549,15 +553,14 @@ class Queue:
for event in awake_events:
self.send_message(
event,
ServerMessage.process_completed,
{
"output": {
ProcessCompletedMessage(
output={
"error": None
if len(e.args) and e.args[0] is None
else str(e)
},
"success": False,
},
success=False,
),
)
if response and response.get("is_generating", False):
old_response = response
@ -568,11 +571,10 @@ class Queue:
for event in awake_events:
self.send_message(
event,
ServerMessage.process_generating,
{
"output": old_response,
"success": old_response is not None,
},
ProcessGeneratingMessage(
output=old_response,
success=old_response is not None,
),
)
awake_events = [event for event in awake_events if event.alive]
if not awake_events:
@ -587,14 +589,15 @@ class Queue:
relevant_response = response or err or old_err
self.send_message(
event,
ServerMessage.process_completed,
{
"output": {"error": str(relevant_response)}
ProcessCompletedMessage(
output={"error": str(relevant_response)}
if isinstance(relevant_response, Exception)
else relevant_response,
"success": relevant_response
and not isinstance(relevant_response, Exception),
},
else relevant_response or {},
success=(
relevant_response is not None
and not isinstance(relevant_response, Exception)
),
),
)
elif response:
output = copy.deepcopy(response)
@ -603,11 +606,10 @@ class Queue:
output["data"] = list(zip(*response.get("data")))[e]
self.send_message(
event,
ServerMessage.process_completed,
{
"output": output,
"success": response is not None,
},
ProcessCompletedMessage(
output=output,
success=response is not None,
),
)
end_time = time.time()
if response is not None:

View File

@ -258,6 +258,7 @@ async def call_process_api(
event_id=event_id,
event_data=event_data,
in_event_listener=True,
simple_format=body.simple_format,
)
iterator = output.pop("iterator", None)
if event_id is not None:

View File

@ -22,7 +22,16 @@ import time
import traceback
from pathlib import Path
from queue import Empty as EmptyQueue
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Type
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Type,
)
import fastapi
import httpx
@ -48,11 +57,15 @@ from starlette.responses import RedirectResponse, StreamingResponse
import gradio
from gradio import ranged_response, route_utils, utils, wasm_utils
from gradio.context import Context
from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
from gradio.data_classes import (
ComponentServerBody,
PredictBody,
ResetBody,
SimplePredictBody,
)
from gradio.exceptions import Error
from gradio.oauth import attach_oauth
from gradio.processing_utils import add_root_url
from gradio.queueing import Estimation
from gradio.route_utils import ( # noqa: F401
CustomCORSMiddleware,
FileUploadProgress,
@ -66,6 +79,14 @@ from gradio.route_utils import ( # noqa: F401
create_lifespan_handler,
move_uploaded_files_to_cache,
)
from gradio.server_messages import (
EstimationMessage,
EventMessage,
HeartbeatMessage,
ProcessCompletedMessage,
ProcessGeneratingMessage,
UnexpectedErrorMessage,
)
from gradio.state_holder import StateHolder
from gradio.utils import get_package_version, get_upload_folder
@ -596,10 +617,98 @@ class App(FastAPI):
output = add_root_url(output, root_path, None)
return output
@app.post("/call/{api_name}", dependencies=[Depends(login_check)])
@app.post("/call/{api_name}/", dependencies=[Depends(login_check)])
async def simple_predict_post(
api_name: str,
body: SimplePredictBody,
request: fastapi.Request,
username: str = Depends(get_current_user),
):
full_body = PredictBody(**body.model_dump(), simple_format=True)
inferred_fn_index = route_utils.infer_fn_index(
app=app, api_name=api_name, body=full_body
)
full_body.fn_index = inferred_fn_index
return await queue_join_helper(full_body, request, username)
@app.post("/queue/join", dependencies=[Depends(login_check)])
async def queue_join(
body: PredictBody,
request: fastapi.Request,
username: str = Depends(get_current_user),
):
if body.session_hash is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Session hash not found.",
)
return await queue_join_helper(body, request, username)
async def queue_join_helper(
    body: PredictBody,
    request: fastapi.Request,
    username: str,
):
    """Shared enqueue logic for /queue/join and the simplified /call routes.

    Pushes the request onto the app's queue and returns {"event_id": ...},
    raising an HTTPException when the queue is stopped, full, or rejects
    the request.
    """
    queue = app.get_blocks()._queue
    if queue.server_app is None:
        queue.set_server_app(app)
    if queue.stopped:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Queue is stopped.",
        )
    pushed, event_id = await queue.push(body, request, username)
    if pushed:
        return {"event_id": event_id}
    # On failure, `event_id` carries the error message from the queue.
    if "Queue is full." in event_id:
        failure_code = status.HTTP_503_SERVICE_UNAVAILABLE
    else:
        failure_code = status.HTTP_400_BAD_REQUEST
    raise HTTPException(status_code=failure_code, detail=event_id)
@app.get("/call/{api_name}/{event_id}", dependencies=[Depends(login_check)])
async def simple_predict_get(
request: fastapi.Request,
event_id: str,
):
def process_msg(message: EventMessage) -> str | None:
if isinstance(message, ProcessCompletedMessage):
event = "complete" if message.success else "error"
data = message.output.get("data")
elif isinstance(message, ProcessGeneratingMessage):
event = "generating" if message.success else "error"
data = message.output.get("data")
elif isinstance(message, HeartbeatMessage):
event = "heartbeat"
data = None
elif isinstance(message, UnexpectedErrorMessage):
event = "error"
data = message.message
else:
return None
return f"event: {event}\ndata: {json.dumps(data)}\n\n"
return await queue_data_helper(request, event_id, process_msg)
@app.get("/queue/data", dependencies=[Depends(login_check)])
async def queue_data(
request: fastapi.Request,
session_hash: str,
):
def process_msg(message: EventMessage) -> str:
return f"data: {json.dumps(message.model_dump())}\n\n"
return await queue_data_helper(request, session_hash, process_msg)
async def queue_data_helper(
request: fastapi.Request,
session_hash: str,
process_msg: Callable[[EventMessage], str | None],
):
blocks = app.get_blocks()
root_path = route_utils.get_root_url(
@ -635,29 +744,35 @@ class App(FastAPI):
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
# Fix this
message = {
"msg": ServerMessage.heartbeat,
}
message = HeartbeatMessage()
# Need to reset last_heartbeat with perf_counter
# otherwise only a single heartbeat msg will be sent
# and then the stream will retry leading to infinite queue 😬
last_heartbeat = time.perf_counter()
if blocks._queue.stopped:
message = {
"msg": "unexpected_error",
"message": "Server stopped unexpectedly.",
"success": False,
}
message = UnexpectedErrorMessage(
message="Server stopped unexpectedly.",
success=False,
)
if message:
add_root_url(message, root_path, None)
yield f"data: {json.dumps(message)}\n\n"
if message["msg"] == ServerMessage.process_completed:
if isinstance(
message,
(ProcessGeneratingMessage, ProcessCompletedMessage),
):
add_root_url(message.output, root_path, None)
response = process_msg(message)
if response is not None:
yield response
if (
isinstance(message, ProcessCompletedMessage)
and message.event_id
):
blocks._queue.pending_event_ids_session[
session_hash
].remove(message["event_id"])
if message["msg"] == ServerMessage.server_stopped or (
message["msg"] == ServerMessage.process_completed
].remove(message.event_id)
if message.msg == ServerMessage.server_stopped or (
message.msg == ServerMessage.process_completed
and (
len(
blocks._queue.pending_event_ids_session[
@ -669,12 +784,12 @@ class App(FastAPI):
):
return
except BaseException as e:
message = {
"msg": "unexpected_error",
"success": False,
"message": str(e),
}
yield f"data: {json.dumps(message)}\n\n"
message = UnexpectedErrorMessage(
message=str(e),
)
response = process_msg(message)
if response is not None:
yield response
if isinstance(e, asyncio.CancelledError):
del blocks._queue.pending_messages_per_session[session_hash]
await blocks._queue.clean_events(session_hash=session_hash)
@ -685,33 +800,6 @@ class App(FastAPI):
media_type="text/event-stream",
)
@app.post("/queue/join", dependencies=[Depends(login_check)])
async def queue_join(
body: PredictBody,
request: fastapi.Request,
username: str = Depends(get_current_user),
):
blocks = app.get_blocks()
if blocks._queue.server_app is None:
blocks._queue.set_server_app(app)
if blocks._queue.stopped:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Queue is stopped.",
)
success, event_id = await blocks._queue.push(body, request, username)
if not success:
status_code = (
status.HTTP_503_SERVICE_UNAVAILABLE
if "Queue is full." in event_id
else status.HTTP_400_BAD_REQUEST
)
raise HTTPException(status_code=status_code, detail=event_id)
return {"event_id": event_id}
@app.post("/component_server", dependencies=[Depends(login_check)])
@app.post("/component_server/", dependencies=[Depends(login_check)])
def component_server(body: ComponentServerBody):
@ -733,7 +821,7 @@ class App(FastAPI):
@app.get(
"/queue/status",
dependencies=[Depends(login_check)],
response_model=Estimation,
response_model=EstimationMessage,
)
async def get_queue_status():
return app.get_blocks()._queue.get_status()

79
gradio/server_messages.py Normal file
View File

@ -0,0 +1,79 @@
from typing import List, Literal, Optional, Union
from gradio_client.utils import ServerMessage
# In the data_classes.py file, the BaseModel class is patched
# so that it can be used in both Pyodide and non-Pyodide environments.
# So we import it from the data_classes.py file
# instead of the original Pydantic BaseModel class.
from gradio.data_classes import BaseModel
class BaseMessage(BaseModel):
    """Base shape of every queue event-stream message: the message type tag
    plus the id of the event it belongs to (assigned by Queue.send_message
    when the message is dispatched)."""

    msg: ServerMessage
    event_id: Optional[str] = None
class ProgressUnit(BaseModel):
    """Progress of one tracked iterable: current index out of `length`, with
    an optional unit label, numeric progress value, and description."""

    index: Optional[int] = None
    length: Optional[int] = None
    unit: Optional[str] = None
    progress: Optional[float] = None
    desc: Optional[str] = None
class ProgressMessage(BaseMessage):
    """Progress update for an in-flight event, one ProgressUnit per tracked
    iterable."""

    msg: Literal[ServerMessage.progress] = ServerMessage.progress
    progress_data: List[ProgressUnit] = []
class LogMessage(BaseMessage):
    """A user log line (text plus severity level) forwarded to the client."""

    msg: Literal[ServerMessage.log] = ServerMessage.log
    log: str
    level: Literal["info", "warning"]
class EstimationMessage(BaseMessage):
    """Queue position estimate: the event's rank, the total queue size, and
    an ETA for that rank when one can be computed."""

    msg: Literal[ServerMessage.estimation] = ServerMessage.estimation
    rank: Optional[int] = None
    queue_size: int
    rank_eta: Optional[float] = None
class ProcessStartsMessage(BaseMessage):
    """Sent when an event begins processing; `eta` is the function's recorded
    average runtime, or None if no timing has been recorded yet."""

    msg: Literal[ServerMessage.process_starts] = ServerMessage.process_starts
    eta: Optional[float] = None
class ProcessCompletedMessage(BaseMessage):
    """Final result of an event: the prediction output (or a dict with an
    `error` key on failure) and whether the run succeeded."""

    msg: Literal[ServerMessage.process_completed] = ServerMessage.process_completed
    output: dict
    success: bool
class ProcessGeneratingMessage(BaseMessage):
    """Intermediate output from a still-running generator event."""

    msg: Literal[ServerMessage.process_generating] = ServerMessage.process_generating
    output: dict
    success: bool
class HeartbeatMessage(BaseMessage):
    """Periodic keep-alive sent on the event stream so clients know the
    connection is still open."""

    # Inherit from BaseMessage (like the other EventMessage variants) so the
    # optional `event_id` field exists: Queue.send_message unconditionally
    # assigns `event_message.event_id`, and a plain BaseModel subclass would
    # reject that as an unknown attribute. The extra field defaults to None,
    # so existing consumers are unaffected.
    msg: Literal[ServerMessage.heartbeat] = ServerMessage.heartbeat
class UnexpectedErrorMessage(BaseMessage):
    """Reports a server-side failure (e.g. the queue stopped or the stream
    raised) to the client; `success` is always False."""

    # Inherit from BaseMessage (like the other EventMessage variants) so the
    # optional `event_id` field exists: Queue.send_message unconditionally
    # assigns `event_message.event_id`, and a plain BaseModel subclass would
    # reject that as an unknown attribute. The extra field defaults to None,
    # so existing consumers are unaffected.
    msg: Literal[ServerMessage.unexpected_error] = ServerMessage.unexpected_error
    message: str
    success: Literal[False] = False
# Union of every message type that can be delivered over a session's event
# stream; consumers dispatch on the `msg` tag or the concrete message class.
EventMessage = Union[
    ProgressMessage,
    LogMessage,
    EstimationMessage,
    ProcessStartsMessage,
    ProcessCompletedMessage,
    ProcessGeneratingMessage,
    HeartbeatMessage,
    UnexpectedErrorMessage,
]

View File

@ -1036,6 +1036,9 @@ class LRUCache(OrderedDict, Generic[K, V]):
self.popitem(last=False)
super().__setitem__(key, value)
def __getitem__(self, key: K) -> V:
    # Typed pass-through: exists only to narrow OrderedDict.__getitem__'s
    # return type to V for type checkers.
    # NOTE(review): access does not call move_to_end, so reads do not refresh
    # an entry's recency — eviction follows insertion order. Confirm intended.
    return super().__getitem__(key)
def get_cache_folder() -> Path:
return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))

View File

@ -2,6 +2,7 @@
import functools
import os
import tempfile
import time
from contextlib import asynccontextmanager, closing
from unittest.mock import patch
@ -9,6 +10,7 @@ import gradio_client as grc
import numpy as np
import pandas as pd
import pytest
import requests
import starlette.routing
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
@ -1008,6 +1010,119 @@ def test_get_root_url(request_url, route_path, root_path, expected_root_url):
assert get_root_url(request, route_path, root_path) == expected_root_url
class TestSimpleAPIRoutes:
    """Integration tests for the simplified /call/{api_name} API routes."""

    def get_demo(self):
        """Build a demo with three endpoints: a plain function (fn1), a
        generator that raises on inputs shorter than 3 chars (fn2), and a
        no-input function with two outputs (fn3)."""
        with Blocks() as demo:
            input = Textbox()
            output = Textbox()
            output2 = Textbox()

            def fn_1(x):
                return f"Hello, {x}!"

            def fn_2(x):
                # Stream one chunk per character, then fail on short input.
                for i in range(len(x)):
                    time.sleep(0.5)
                    yield f"Hello, {x[:i+1]}!"
                if len(x) < 3:
                    raise ValueError("Small input")

            def fn_3():
                return "a", "b"

            btn1, btn2, btn3 = Button(), Button(), Button()
            btn1.click(fn_1, input, output, api_name="fn1")
            btn2.click(fn_2, input, output2, api_name="fn2")
            btn3.click(fn_3, None, [output, output2], api_name="fn3")
        return demo

    def _submit(self, demo, api_name, data):
        """POST to /call/{api_name} and return the new event id."""
        response = requests.post(
            f"{demo.local_url}call/{api_name}", json={"data": data}
        )
        assert response.status_code == 200, f"Failed to call {api_name}"
        return response.json()["event_id"]

    def _stream(self, demo, api_name, event_id):
        """Read the SSE stream for an event and return its non-empty lines."""
        response = requests.get(
            f"{demo.local_url}call/{api_name}/{event_id}", stream=True
        )
        return [line.decode("utf-8") for line in response.iter_lines() if line]

    def test_successful_simple_route(self):
        demo = self.get_demo()
        demo.launch(prevent_thread_lock=True)
        try:
            event_id = self._submit(demo, "fn1", ["world"])
            output = self._stream(demo, "fn1", event_id)
            assert output == ["event: complete", 'data: ["Hello, world!"]']

            event_id = self._submit(demo, "fn3", [])
            output = self._stream(demo, "fn3", event_id)
            assert output == ["event: complete", 'data: ["a", "b"]']
        finally:
            # Shut the server down so it does not leak into other tests.
            demo.close()

    def test_generative_simple_route(self):
        demo = self.get_demo()
        demo.launch(prevent_thread_lock=True)
        try:
            event_id = self._submit(demo, "fn2", ["world"])
            output = self._stream(demo, "fn2", event_id)
            assert output == [
                "event: generating",
                'data: ["Hello, w!"]',
                "event: generating",
                'data: ["Hello, wo!"]',
                "event: generating",
                'data: ["Hello, wor!"]',
                "event: generating",
                'data: ["Hello, worl!"]',
                "event: generating",
                'data: ["Hello, world!"]',
                "event: complete",
                'data: ["Hello, world!"]',
            ]

            # Short input yields once, then the generator raises, so the
            # stream ends with an error event carrying null data.
            event_id = self._submit(demo, "fn2", ["w"])
            output = self._stream(demo, "fn2", event_id)
            assert output == [
                "event: generating",
                'data: ["Hello, w!"]',
                "event: error",
                "data: null",
            ]
        finally:
            demo.close()
def test_compare_passwords_securely():
password1 = "password"
password2 = "pässword"