Support simplified event api for event only updates (#7407)
* changes * add changeset * changes * changes * changes * add changeset * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes --------- Co-authored-by: Ali Abid <aliabid94@gmail.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent a57e34ef87
commit 375bfd28d2

.changeset/silver-hairs-roll.md (new file, 6 lines)
@@ -0,0 +1,6 @@
+---
+"gradio": minor
+"gradio_client": minor
+---
+
+feat:Fix server_messages.py to use the patched BaseModel class for Wasm env

@@ -115,6 +115,7 @@ class ServerMessage(str, Enum):
     progress = "progress"
     heartbeat = "heartbeat"
+    server_stopped = "server_stopped"
     unexpected_error = "unexpected_error"


 class Status(Enum):

@@ -1529,6 +1529,7 @@ Received outputs:
         session_hash: str | None,
         run: int | None,
         final: bool,
+        simple_format: bool = False,
     ) -> list:
         if session_hash is None or run is None:
             return data

@@ -1547,7 +1548,8 @@ Received outputs:
         else:
             prev_chunk = last_diffs[i]
             last_diffs[i] = data[i]
-            data[i] = utils.diff(prev_chunk, data[i])
+            if not simple_format:
+                data[i] = utils.diff(prev_chunk, data[i])

         if final:
             del self.pending_diff_streams[session_hash][run]

@@ -1574,6 +1576,7 @@ Received outputs:
         event_id: str | None = None,
         event_data: EventData | None = None,
         in_event_listener: bool = True,
+        simple_format: bool = False,
         explicit_call: bool = False,
     ) -> dict[str, Any]:
         """

@@ -1684,6 +1687,7 @@ Received outputs:
                 session_hash=session_hash,
                 run=run,
                 final=not is_generating,
+                simple_format=simple_format,
             )

         block_fn.total_runtime += result["duration"]

@@ -3,7 +3,6 @@
 from __future__ import annotations

 import tempfile
 from pathlib import Path
 from typing import Callable, Literal

 from gradio_client.documentation import document

@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union

 from fastapi import Request
 from gradio_client.utils import traverse
-from typing_extensions import Literal

 from . import wasm_utils

@@ -75,6 +74,11 @@ else:
     RootModel.model_json_schema = RootModel.schema  # type: ignore


+class SimplePredictBody(BaseModel):
+    data: List[Any]
+    session_hash: Optional[str] = None
+
+
 class PredictBody(BaseModel):
     model_config = {"arbitrary_types_allowed": True}

@@ -84,6 +88,7 @@ class PredictBody(BaseModel):
     event_data: Optional[Any] = None
     fn_index: Optional[int] = None
     trigger_id: Optional[int] = None
+    simple_format: bool = False
     batched: Optional[
         bool
     ] = False  # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)

@@ -110,29 +115,6 @@ class InterfaceTypes(Enum):
     UNIFIED = auto()


-class Estimation(BaseModel):
-    rank: Optional[int] = None
-    queue_size: int
-    rank_eta: Optional[float] = None
-
-
-class ProgressUnit(BaseModel):
-    index: Optional[int] = None
-    length: Optional[int] = None
-    unit: Optional[str] = None
-    progress: Optional[float] = None
-    desc: Optional[str] = None
-
-
-class Progress(BaseModel):
-    progress_data: List[ProgressUnit] = []
-
-
-class LogMessage(BaseModel):
-    log: str
-    level: Literal["info", "warning"]
-
-
 class GradioBaseModel(ABC):
     def copy_to_dir(self, dir: str | pathlib.Path) -> GradioDataModel:
         if not isinstance(self, (BaseModel, RootModel)):

@@ -13,19 +13,24 @@ from queue import Queue as ThreadQueue
 from typing import TYPE_CHECKING

 import fastapi
-from gradio_client.utils import ServerMessage
 from typing_extensions import Literal

 from gradio import route_utils, routes
 from gradio.data_classes import (
-    Estimation,
-    LogMessage,
     PredictBody,
-    Progress,
-    ProgressUnit,
 )
 from gradio.exceptions import Error
 from gradio.helpers import TrackedIterable
+from gradio.server_messages import (
+    EstimationMessage,
+    EventMessage,
+    LogMessage,
+    ProcessCompletedMessage,
+    ProcessGeneratingMessage,
+    ProcessStartsMessage,
+    ProgressMessage,
+    ProgressUnit,
+)
 from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name

 if TYPE_CHECKING:

@@ -35,20 +40,20 @@ if TYPE_CHECKING:
 class Event:
     def __init__(
         self,
-        session_hash: str,
+        session_hash: str | None,
         fn_index: int,
         request: fastapi.Request,
         username: str | None,
         concurrency_id: str,
     ):
-        self.session_hash = session_hash
+        self._id = uuid.uuid4().hex
+        self.session_hash: str = session_hash or self._id
         self.fn_index = fn_index
         self.request = request
         self.username = username
         self.concurrency_id = concurrency_id
-        self._id = uuid.uuid4().hex
         self.data: PredictBody | None = None
-        self.progress: Progress | None = None
+        self.progress: ProgressMessage | None = None
         self.progress_pending: bool = False
         self.alive = True

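Note on the `Event` change above: a client of the simplified API never supplies a `session_hash`, so the event now mints its own id first and falls back to it as the session key. A minimal sketch of that fallback (values hypothetical, names abbreviated from the diff):

import uuid

# Hypothetical bare call to the simplified API: no session_hash in the body.
provided_session_hash = None

_id = uuid.uuid4().hex                       # always assigned first now
session_hash = provided_session_hash or _id  # falls back to the event's own id
assert session_hash == _id
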
@@ -84,7 +89,9 @@ class Queue:
         block_fns: list[BlockFunction],
         default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
     ):
-        self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
+        self.pending_messages_per_session: LRUCache[
+            str, ThreadQueue[EventMessage]
+        ] = LRUCache(2000)
         self.pending_event_ids_session: dict[str, set[str]] = {}
         self.pending_message_lock = safe_get_lock()
         self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}

@@ -150,14 +157,13 @@ class Queue:
     def send_message(
         self,
         event: Event,
-        message_type: str,
-        data: dict | None = None,
+        event_message: EventMessage,
     ):
         if not event.alive:
             return
-        data = {} if data is None else data
+        event_message.event_id = event._id
         messages = self.pending_messages_per_session[event.session_hash]
-        messages.put_nowait({"msg": message_type, "event_id": event._id, **data})
+        messages.put_nowait(event_message)

     def _resolve_concurrency_limit(
         self, default_concurrency_limit: int | None | Literal["not_set"]

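With `send_message` now taking a typed `EventMessage` instead of a bare `message_type` string plus payload dict, a malformed message fails at construction time on the server rather than surfacing on the client. A small sketch of the difference (values hypothetical; assumes standard Pydantic v2 `model_dump` outside the Wasm path):

from gradio.server_messages import LogMessage

# Old style: an ad-hoc dict; nothing validates the keys or the level value.
old = {"msg": "log", "event_id": "abc123", "log": "warming up", "level": "info"}

# New style: construction validates the fields; send_message() stamps event_id.
new = LogMessage(log="warming up", level="info")
new.event_id = "abc123"
assert new.model_dump(mode="json") == old
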
@@ -190,8 +196,6 @@ class Queue:
     async def push(
         self, body: PredictBody, request: fastapi.Request, username: str | None
     ) -> tuple[bool, str]:
-        if body.session_hash is None:
-            return False, "No session hash provided."
         if body.fn_index is None:
             return False, "No function index provided."
         if self.max_size is not None and len(self) >= self.max_size:

@@ -208,6 +212,8 @@ class Queue:
             self.block_fns[body.fn_index].concurrency_id,
         )
         event.data = body
+        if body.session_hash is None:
+            body.session_hash = event.session_hash
         async with self.pending_message_lock:
             if body.session_hash not in self.pending_messages_per_session:
                 self.pending_messages_per_session[body.session_hash] = ThreadQueue()

@@ -322,9 +328,7 @@ class Queue:
             for event in events:
                 if event.progress_pending and event.progress:
                     event.progress_pending = False
-                    self.send_message(
-                        event, ServerMessage.progress, event.progress.model_dump()
-                    )
+                    self.send_message(event, event.progress)

             await asyncio.sleep(self.progress_update_sleep_when_free)

@@ -350,7 +354,7 @@ class Queue:
                 desc=iterable.desc,
             )
             progress_data.append(progress_unit)
-        evt.progress = Progress(progress_data=progress_data)
+        evt.progress = ProgressMessage(progress_data=progress_data)
         evt.progress_pending = True

     def log_message(

@@ -368,7 +372,7 @@ class Queue:
             log=log,
             level=level,
         )
-        self.send_message(event, ServerMessage.log, log_message.model_dump())
+        self.send_message(event, log_message)

     async def clean_events(
         self, *, session_hash: str | None = None, event_id: str | None = None

@@ -441,10 +445,9 @@ class Queue:
             if after is None or rank >= after:
                 self.send_message(
                     event,
-                    ServerMessage.estimation,
-                    Estimation(
+                    EstimationMessage(
                         rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
-                    ).model_dump(),
+                    ),
                 )
         if event_queue.concurrency_limit is None:
             wait_so_far = 0

@@ -453,12 +456,12 @@ class Queue:
         else:
             wait_so_far = None

-    def get_status(self) -> Estimation:
-        return Estimation(
+    def get_status(self) -> EstimationMessage:
+        return EstimationMessage(
             queue_size=len(self),
         )

-    async def call_prediction(self, events: list[Event], batch: bool):
+    async def call_prediction(self, events: list[Event], batch: bool) -> dict:
         body = events[0].data
         if body is None:
             raise ValueError("No event data")

@@ -517,6 +520,8 @@ class Queue:
         )  # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264
         # Also, decode the JSON string to a Python object, emulating the HTTP client behavior e.g. the `json()` method of `httpx`.
        response_json = json.loads(http_response.body.decode())
+        if not isinstance(response_json, dict):
+            raise ValueError("Unexpected object.")

         return response_json

@@ -530,12 +535,11 @@ class Queue:
             if event.alive:
                 self.send_message(
                     event,
-                    ServerMessage.process_starts,
-                    {
-                        "eta": self.process_time_per_fn_index[fn_index].avg_time
+                    ProcessStartsMessage(
+                        eta=self.process_time_per_fn_index[fn_index].avg_time
                         if fn_index in self.process_time_per_fn_index
                         else None
-                    },
+                    ),
                 )
                 awake_events.append(event)
             if not awake_events:

@@ -549,15 +553,14 @@ class Queue:
             for event in awake_events:
                 self.send_message(
                     event,
-                    ServerMessage.process_completed,
-                    {
-                        "output": {
+                    ProcessCompletedMessage(
+                        output={
                             "error": None
                             if len(e.args) and e.args[0] is None
                             else str(e)
                         },
-                        "success": False,
-                    },
+                        success=False,
+                    ),
                 )
         if response and response.get("is_generating", False):
             old_response = response

@@ -568,11 +571,10 @@ class Queue:
                 for event in awake_events:
                     self.send_message(
                         event,
-                        ServerMessage.process_generating,
-                        {
-                            "output": old_response,
-                            "success": old_response is not None,
-                        },
+                        ProcessGeneratingMessage(
+                            output=old_response,
+                            success=old_response is not None,
+                        ),
                     )
             awake_events = [event for event in awake_events if event.alive]
             if not awake_events:

@@ -587,14 +589,15 @@ class Queue:
                 relevant_response = response or err or old_err
                 self.send_message(
                     event,
-                    ServerMessage.process_completed,
-                    {
-                        "output": {"error": str(relevant_response)}
+                    ProcessCompletedMessage(
+                        output={"error": str(relevant_response)}
                         if isinstance(relevant_response, Exception)
-                        else relevant_response,
-                        "success": relevant_response
-                        and not isinstance(relevant_response, Exception),
-                    },
+                        else relevant_response or {},
+                        success=(
+                            relevant_response is not None
+                            and not isinstance(relevant_response, Exception)
+                        ),
+                    ),
                 )
         elif response:
             output = copy.deepcopy(response)

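The `success` computation above is a behavior fix as much as a refactor: the old dict form evaluated `relevant_response and not isinstance(...)`, which short-circuits to the response object itself (or to a falsy empty value) rather than to a bool. Roughly, with a hypothetical empty-payload response:

# Both expressions are taken verbatim from the old and new branches above.
relevant_response = {}  # hypothetical: a completed run with an empty payload

old_success = relevant_response and not isinstance(relevant_response, Exception)
new_success = relevant_response is not None and not isinstance(
    relevant_response, Exception
)

print(old_success)  # {} -- falsy, would have read as a failure
print(new_success)  # True
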
@@ -603,11 +606,10 @@ class Queue:
                     output["data"] = list(zip(*response.get("data")))[e]
                 self.send_message(
                     event,
-                    ServerMessage.process_completed,
-                    {
-                        "output": output,
-                        "success": response is not None,
-                    },
+                    ProcessCompletedMessage(
+                        output=output,
+                        success=response is not None,
+                    ),
                 )
             end_time = time.time()
             if response is not None:

@@ -258,6 +258,7 @@ async def call_process_api(
         event_id=event_id,
         event_data=event_data,
         in_event_listener=True,
+        simple_format=body.simple_format,
     )
     iterator = output.pop("iterator", None)
     if event_id is not None:

gradio/routes.py (190 lines changed)
@@ -22,7 +22,16 @@ import time
 import traceback
 from pathlib import Path
 from queue import Empty as EmptyQueue
-from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Type
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Type,
+)

 import fastapi
 import httpx

|
||||
import gradio
|
||||
from gradio import ranged_response, route_utils, utils, wasm_utils
|
||||
from gradio.context import Context
|
||||
from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
|
||||
from gradio.data_classes import (
|
||||
ComponentServerBody,
|
||||
PredictBody,
|
||||
ResetBody,
|
||||
SimplePredictBody,
|
||||
)
|
||||
from gradio.exceptions import Error
|
||||
from gradio.oauth import attach_oauth
|
||||
from gradio.processing_utils import add_root_url
|
||||
from gradio.queueing import Estimation
|
||||
from gradio.route_utils import ( # noqa: F401
|
||||
CustomCORSMiddleware,
|
||||
FileUploadProgress,
|
||||
@@ -66,6 +79,14 @@ from gradio.route_utils import (  # noqa: F401
     create_lifespan_handler,
     move_uploaded_files_to_cache,
 )
+from gradio.server_messages import (
+    EstimationMessage,
+    EventMessage,
+    HeartbeatMessage,
+    ProcessCompletedMessage,
+    ProcessGeneratingMessage,
+    UnexpectedErrorMessage,
+)
 from gradio.state_holder import StateHolder
 from gradio.utils import get_package_version, get_upload_folder

@@ -596,10 +617,98 @@ class App(FastAPI):
             output = add_root_url(output, root_path, None)
             return output

+        @app.post("/call/{api_name}", dependencies=[Depends(login_check)])
+        @app.post("/call/{api_name}/", dependencies=[Depends(login_check)])
+        async def simple_predict_post(
+            api_name: str,
+            body: SimplePredictBody,
+            request: fastapi.Request,
+            username: str = Depends(get_current_user),
+        ):
+            full_body = PredictBody(**body.model_dump(), simple_format=True)
+            inferred_fn_index = route_utils.infer_fn_index(
+                app=app, api_name=api_name, body=full_body
+            )
+            full_body.fn_index = inferred_fn_index
+            return await queue_join_helper(full_body, request, username)
+
+        @app.post("/queue/join", dependencies=[Depends(login_check)])
+        async def queue_join(
+            body: PredictBody,
+            request: fastapi.Request,
+            username: str = Depends(get_current_user),
+        ):
+            if body.session_hash is None:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail="Session hash not found.",
+                )
+            return await queue_join_helper(body, request, username)
+
+        async def queue_join_helper(
+            body: PredictBody,
+            request: fastapi.Request,
+            username: str,
+        ):
+            blocks = app.get_blocks()
+
+            if blocks._queue.server_app is None:
+                blocks._queue.set_server_app(app)
+
+            if blocks._queue.stopped:
+                raise HTTPException(
+                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                    detail="Queue is stopped.",
+                )
+
+            success, event_id = await blocks._queue.push(body, request, username)
+            if not success:
+                status_code = (
+                    status.HTTP_503_SERVICE_UNAVAILABLE
+                    if "Queue is full." in event_id
+                    else status.HTTP_400_BAD_REQUEST
+                )
+                raise HTTPException(status_code=status_code, detail=event_id)
+            return {"event_id": event_id}
+
+        @app.get("/call/{api_name}/{event_id}", dependencies=[Depends(login_check)])
+        async def simple_predict_get(
+            request: fastapi.Request,
+            event_id: str,
+        ):
+            def process_msg(message: EventMessage) -> str | None:
+                if isinstance(message, ProcessCompletedMessage):
+                    event = "complete" if message.success else "error"
+                    data = message.output.get("data")
+                elif isinstance(message, ProcessGeneratingMessage):
+                    event = "generating" if message.success else "error"
+                    data = message.output.get("data")
+                elif isinstance(message, HeartbeatMessage):
+                    event = "heartbeat"
+                    data = None
+                elif isinstance(message, UnexpectedErrorMessage):
+                    event = "error"
+                    data = message.message
+                else:
+                    return None
+                return f"event: {event}\ndata: {json.dumps(data)}\n\n"
+
+            return await queue_data_helper(request, event_id, process_msg)
+
         @app.get("/queue/data", dependencies=[Depends(login_check)])
         async def queue_data(
             request: fastapi.Request,
             session_hash: str,
         ):
+            def process_msg(message: EventMessage) -> str:
+                return f"data: {json.dumps(message.model_dump())}\n\n"
+
+            return await queue_data_helper(request, session_hash, process_msg)
+
+        async def queue_data_helper(
+            request: fastapi.Request,
+            session_hash: str,
+            process_msg: Callable[[EventMessage], str | None],
+        ):
             blocks = app.get_blocks()
             root_path = route_utils.get_root_url(

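Taken together, the two `/call` routes give the event-only flow its simplified shape: POST the inputs to get an `event_id`, then GET a server-sent-event stream for that event. A usage sketch against a locally running demo (the URL and `api_name` are assumptions; the payloads mirror the tests at the bottom of this diff):

import requests

base = "http://127.0.0.1:7860/"  # hypothetical local demo with api_name="fn1"

# Step 1: submit inputs; only {"data": [...]} is required, no session bookkeeping.
event_id = requests.post(f"{base}call/fn1", json={"data": ["world"]}).json()["event_id"]

# Step 2: stream the result as server-sent events keyed by the event id.
with requests.get(f"{base}call/fn1/{event_id}", stream=True) as stream:
    for line in stream.iter_lines():
        if line:
            print(line.decode("utf-8"))
# Expected output, per the tests below:
# event: complete
# data: ["Hello, world!"]
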
@@ -635,29 +744,35 @@ class App(FastAPI):
                    await asyncio.sleep(check_rate)
                    if time.perf_counter() - last_heartbeat > heartbeat_rate:
-                        # Fix this
-                        message = {
-                            "msg": ServerMessage.heartbeat,
-                        }
+                        message = HeartbeatMessage()
                        # Need to reset last_heartbeat with perf_counter
                        # otherwise only a single hearbeat msg will be sent
                        # and then the stream will retry leading to infinite queue 😬
                        last_heartbeat = time.perf_counter()

                    if blocks._queue.stopped:
-                        message = {
-                            "msg": "unexpected_error",
-                            "message": "Server stopped unexpectedly.",
-                            "success": False,
-                        }
+                        message = UnexpectedErrorMessage(
+                            message="Server stopped unexpectedly.",
+                            success=False,
+                        )
                    if message:
-                        add_root_url(message, root_path, None)
-                        yield f"data: {json.dumps(message)}\n\n"
-                        if message["msg"] == ServerMessage.process_completed:
+                        if isinstance(
+                            message,
+                            (ProcessGeneratingMessage, ProcessCompletedMessage),
+                        ):
+                            add_root_url(message.output, root_path, None)
+                        response = process_msg(message)
+                        if response is not None:
+                            yield response
+                        if (
+                            isinstance(message, ProcessCompletedMessage)
+                            and message.event_id
+                        ):
                            blocks._queue.pending_event_ids_session[
                                session_hash
-                            ].remove(message["event_id"])
-                            if message["msg"] == ServerMessage.server_stopped or (
-                                message["msg"] == ServerMessage.process_completed
+                            ].remove(message.event_id)
+                        if message.msg == ServerMessage.server_stopped or (
+                            message.msg == ServerMessage.process_completed
                            and (
                                len(
                                    blocks._queue.pending_event_ids_session[

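For contrast with the simplified `/call` stream, the full `/queue/data` stream serializes every `EventMessage` as a JSON `data:` line, heartbeats included. A hypothetical consumer that skips heartbeats (URL and session are assumptions; a real client joins via `/queue/join` first):

import json
import requests

url = "http://127.0.0.1:7860/queue/data"  # hypothetical local demo
with requests.get(url, params={"session_hash": "my-session"}, stream=True) as stream:
    for raw in stream.iter_lines():
        if not raw or not raw.startswith(b"data: "):
            continue
        payload = json.loads(raw[len(b"data: "):])
        if payload["msg"] == "heartbeat":   # keep-alive only, nothing to render
            continue
        print(payload["msg"], payload)
        if payload["msg"] == "process_completed":
            break
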
@@ -669,12 +784,12 @@ class App(FastAPI):
                ):
                    return
            except BaseException as e:
-                message = {
-                    "msg": "unexpected_error",
-                    "success": False,
-                    "message": str(e),
-                }
-                yield f"data: {json.dumps(message)}\n\n"
+                message = UnexpectedErrorMessage(
+                    message=str(e),
+                )
+                response = process_msg(message)
+                if response is not None:
+                    yield response
                if isinstance(e, asyncio.CancelledError):
                    del blocks._queue.pending_messages_per_session[session_hash]
                    await blocks._queue.clean_events(session_hash=session_hash)

@@ -685,33 +800,6 @@ class App(FastAPI):
             media_type="text/event-stream",
         )

-        @app.post("/queue/join", dependencies=[Depends(login_check)])
-        async def queue_join(
-            body: PredictBody,
-            request: fastapi.Request,
-            username: str = Depends(get_current_user),
-        ):
-            blocks = app.get_blocks()
-
-            if blocks._queue.server_app is None:
-                blocks._queue.set_server_app(app)
-
-            if blocks._queue.stopped:
-                raise HTTPException(
-                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                    detail="Queue is stopped.",
-                )
-
-            success, event_id = await blocks._queue.push(body, request, username)
-            if not success:
-                status_code = (
-                    status.HTTP_503_SERVICE_UNAVAILABLE
-                    if "Queue is full." in event_id
-                    else status.HTTP_400_BAD_REQUEST
-                )
-                raise HTTPException(status_code=status_code, detail=event_id)
-            return {"event_id": event_id}
-
         @app.post("/component_server", dependencies=[Depends(login_check)])
         @app.post("/component_server/", dependencies=[Depends(login_check)])
         def component_server(body: ComponentServerBody):

@@ -733,7 +821,7 @@ class App(FastAPI):
         @app.get(
             "/queue/status",
             dependencies=[Depends(login_check)],
-            response_model=Estimation,
+            response_model=EstimationMessage,
         )
         async def get_queue_status():
             return app.get_blocks()._queue.get_status()

gradio/server_messages.py (new file, 79 lines)
@@ -0,0 +1,79 @@
+from typing import List, Literal, Optional, Union
+
+from gradio_client.utils import ServerMessage
+
+# In the data_classes.py file, the BaseModel class is patched
+# so that it can be used in both Pyodide and non-Pyodide environments.
+# So we import it from the data_classes.py file
+# instead of the original Pydantic BaseModel class.
+from gradio.data_classes import BaseModel
+
+
+class BaseMessage(BaseModel):
+    msg: ServerMessage
+    event_id: Optional[str] = None
+
+
+class ProgressUnit(BaseModel):
+    index: Optional[int] = None
+    length: Optional[int] = None
+    unit: Optional[str] = None
+    progress: Optional[float] = None
+    desc: Optional[str] = None
+
+
+class ProgressMessage(BaseMessage):
+    msg: Literal[ServerMessage.progress] = ServerMessage.progress
+    progress_data: List[ProgressUnit] = []
+
+
+class LogMessage(BaseMessage):
+    msg: Literal[ServerMessage.log] = ServerMessage.log
+    log: str
+    level: Literal["info", "warning"]
+
+
+class EstimationMessage(BaseMessage):
+    msg: Literal[ServerMessage.estimation] = ServerMessage.estimation
+    rank: Optional[int] = None
+    queue_size: int
+    rank_eta: Optional[float] = None
+
+
+class ProcessStartsMessage(BaseMessage):
+    msg: Literal[ServerMessage.process_starts] = ServerMessage.process_starts
+    eta: Optional[float] = None
+
+
+class ProcessCompletedMessage(BaseMessage):
+    msg: Literal[ServerMessage.process_completed] = ServerMessage.process_completed
+    output: dict
+    success: bool
+
+
+class ProcessGeneratingMessage(BaseMessage):
+    msg: Literal[ServerMessage.process_generating] = ServerMessage.process_generating
+    output: dict
+    success: bool
+
+
+class HeartbeatMessage(BaseModel):
+    msg: Literal[ServerMessage.heartbeat] = ServerMessage.heartbeat
+
+
+class UnexpectedErrorMessage(BaseModel):
+    msg: Literal[ServerMessage.unexpected_error] = ServerMessage.unexpected_error
+    message: str
+    success: Literal[False] = False
+
+
+EventMessage = Union[
+    ProgressMessage,
+    LogMessage,
+    EstimationMessage,
+    ProcessStartsMessage,
+    ProcessCompletedMessage,
+    ProcessGeneratingMessage,
+    HeartbeatMessage,
+    UnexpectedErrorMessage,
+]

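Every concrete message pins `msg` to a `Literal` member of `ServerMessage`, so the `EventMessage` union can be resolved back from raw JSON. A sketch of parsing a payload into the right class (assumes standard Pydantic v2 outside the Wasm path; the payload values are hypothetical):

from pydantic import TypeAdapter

from gradio.server_messages import EventMessage, ProcessCompletedMessage

adapter = TypeAdapter(EventMessage)
msg = adapter.validate_python(
    {"msg": "process_completed", "output": {"data": ["ok"]}, "success": True}
)
assert isinstance(msg, ProcessCompletedMessage)
assert msg.output["data"] == ["ok"]
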
@@ -1036,6 +1036,9 @@ class LRUCache(OrderedDict, Generic[K, V]):
             self.popitem(last=False)
         super().__setitem__(key, value)

+    def __getitem__(self, key: K) -> V:
+        return super().__getitem__(key)
+

 def get_cache_folder() -> Path:
     return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))

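The `__getitem__` override is a no-op at runtime; it exists so type checkers resolve lookups on `LRUCache[K, V]` (now parameterized as `LRUCache[str, ThreadQueue[EventMessage]]` in the queue) to `V` instead of `Any`. A self-contained sketch of the pattern, with the eviction logic paraphrased from the context lines above:

from collections import OrderedDict
from typing import Generic, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class LRUCache(OrderedDict, Generic[K, V]):
    def __init__(self, max_size: int = 100):
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key: K, value: V) -> None:
        if len(self) >= self.max_size:
            self.popitem(last=False)  # evict the least-recently-inserted entry
        super().__setitem__(key, value)

    def __getitem__(self, key: K) -> V:
        # Runtime pass-through; narrows the value type for static analysis.
        return super().__getitem__(key)


cache: "LRUCache[str, int]" = LRUCache(max_size=2)
cache["a"], cache["b"] = 1, 2
cache["c"] = 3  # "a" is evicted
assert "a" not in cache and cache["c"] == 3
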
@@ -2,6 +2,7 @@
 import functools
 import os
 import tempfile
+import time
 from contextlib import asynccontextmanager, closing
 from unittest.mock import patch

@@ -9,6 +10,7 @@ import gradio_client as grc
 import numpy as np
 import pandas as pd
 import pytest
+import requests
 import starlette.routing
 from fastapi import FastAPI, Request
 from fastapi.testclient import TestClient

|
||||
assert get_root_url(request, route_path, root_path) == expected_root_url
|
||||
|
||||
|
||||
class TestSimpleAPIRoutes:
|
||||
def get_demo(self):
|
||||
with Blocks() as demo:
|
||||
input = Textbox()
|
||||
output = Textbox()
|
||||
output2 = Textbox()
|
||||
|
||||
def fn_1(x):
|
||||
return f"Hello, {x}!"
|
||||
|
||||
def fn_2(x):
|
||||
for i in range(len(x)):
|
||||
time.sleep(0.5)
|
||||
yield f"Hello, {x[:i+1]}!"
|
||||
if len(x) < 3:
|
||||
raise ValueError("Small input")
|
||||
|
||||
def fn_3():
|
||||
return "a", "b"
|
||||
|
||||
btn1, btn2, btn3 = Button(), Button(), Button()
|
||||
btn1.click(fn_1, input, output, api_name="fn1")
|
||||
btn2.click(fn_2, input, output2, api_name="fn2")
|
||||
btn3.click(fn_3, None, [output, output2], api_name="fn3")
|
||||
return demo
|
||||
|
||||
def test_successful_simple_route(self):
|
||||
demo = self.get_demo()
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
|
||||
response = requests.post(f"{demo.local_url}call/fn1", json={"data": ["world"]})
|
||||
|
||||
assert response.status_code == 200, "Failed to call fn1"
|
||||
response = response.json()
|
||||
event_id = response["event_id"]
|
||||
|
||||
output = []
|
||||
response = requests.get(f"{demo.local_url}call/fn1/{event_id}", stream=True)
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
output.append(line.decode("utf-8"))
|
||||
|
||||
assert output == ["event: complete", 'data: ["Hello, world!"]']
|
||||
|
||||
response = requests.post(f"{demo.local_url}call/fn3", json={"data": []})
|
||||
|
||||
assert response.status_code == 200, "Failed to call fn3"
|
||||
response = response.json()
|
||||
event_id = response["event_id"]
|
||||
|
||||
output = []
|
||||
response = requests.get(f"{demo.local_url}call/fn3/{event_id}", stream=True)
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
output.append(line.decode("utf-8"))
|
||||
|
||||
assert output == ["event: complete", 'data: ["a", "b"]']
|
||||
|
||||
def test_generative_simple_route(self):
|
||||
demo = self.get_demo()
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
|
||||
response = requests.post(f"{demo.local_url}call/fn2", json={"data": ["world"]})
|
||||
|
||||
assert response.status_code == 200, "Failed to call fn2"
|
||||
response = response.json()
|
||||
event_id = response["event_id"]
|
||||
|
||||
output = []
|
||||
response = requests.get(f"{demo.local_url}call/fn2/{event_id}", stream=True)
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
output.append(line.decode("utf-8"))
|
||||
|
||||
assert output == [
|
||||
"event: generating",
|
||||
'data: ["Hello, w!"]',
|
||||
"event: generating",
|
||||
'data: ["Hello, wo!"]',
|
||||
"event: generating",
|
||||
'data: ["Hello, wor!"]',
|
||||
"event: generating",
|
||||
'data: ["Hello, worl!"]',
|
||||
"event: generating",
|
||||
'data: ["Hello, world!"]',
|
||||
"event: complete",
|
||||
'data: ["Hello, world!"]',
|
||||
]
|
||||
|
||||
response = requests.post(f"{demo.local_url}call/fn2", json={"data": ["w"]})
|
||||
|
||||
assert response.status_code == 200, "Failed to call fn2"
|
||||
response = response.json()
|
||||
event_id = response["event_id"]
|
||||
|
||||
output = []
|
||||
response = requests.get(f"{demo.local_url}call/fn2/{event_id}", stream=True)
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
output.append(line.decode("utf-8"))
|
||||
|
||||
assert output == [
|
||||
"event: generating",
|
||||
'data: ["Hello, w!"]',
|
||||
"event: error",
|
||||
"data: null",
|
||||
]
|
||||
|
||||
|
||||
def test_compare_passwords_securely():
|
||||
password1 = "password"
|
||||
password2 = "pässword"
|
||||
|