from __future__ import annotations

import asyncio
import copy
import json
import os
import random
import time
import traceback
import uuid
from collections import defaultdict
from queue import Queue as ThreadQueue
from typing import TYPE_CHECKING

import fastapi
from gradio_client.utils import ServerMessage
from typing_extensions import Literal

from gradio import route_utils, routes
from gradio.data_classes import (
    Estimation,
    LogMessage,
    PredictBody,
    Progress,
    ProgressUnit,
)
from gradio.exceptions import Error
from gradio.helpers import TrackedIterable
from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name

if TYPE_CHECKING:
    from gradio.blocks import BlockFunction


class Event:
    def __init__(
        self,
        session_hash: str,
        fn_index: int,
        request: fastapi.Request,
        username: str | None,
        concurrency_id: str,
    ):
        self.session_hash = session_hash
        self.fn_index = fn_index
        self.request = request
        self.username = username
        self.concurrency_id = concurrency_id
        self._id = uuid.uuid4().hex
        self.data: PredictBody | None = None
        self.progress: Progress | None = None
        self.progress_pending: bool = False
        self.alive = True


class EventQueue:
    def __init__(self, concurrency_id: str, concurrency_limit: int | None):
        self.queue: list[Event] = []
        self.concurrency_id = concurrency_id
        self.concurrency_limit = concurrency_limit
        self.current_concurrency = 0
        self.start_times_per_fn_index: defaultdict[int, set[float]] = defaultdict(set)


class ProcessTime:
    def __init__(self):
        self.process_time = 0
        self.count = 0
        self.avg_time = 0

    def add(self, time: float):
        self.process_time += time
        self.count += 1
        self.avg_time = self.process_time / self.count


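# A quick worked example of the running average above (values assumed):
#
#   pt = ProcessTime()
#   pt.add(2.0)   # process_time=2.0, count=1, avg_time=2.0
#   pt.add(4.0)   # process_time=6.0, count=2, avg_time=3.0
#
# `avg_time` feeds the ETA estimates broadcast to queued clients.

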
class Queue:
    def __init__(
        self,
        live_updates: bool,
        concurrency_count: int,
        update_intervals: float,
        max_size: int | None,
        block_fns: list[BlockFunction],
        default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
    ):
        self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
        self.pending_event_ids_session: dict[str, set[str]] = {}
        self.pending_message_lock = safe_get_lock()
        self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
        self.stopped = False
        self.max_thread_count = concurrency_count
        self.update_intervals = update_intervals
        self.active_jobs: list[None | list[Event]] = []
        self.delete_lock = safe_get_lock()
        self.server_app = None
        self.process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(
            ProcessTime
        )
        self.live_updates = live_updates
        self.sleep_when_free = 0.05
        self.progress_update_sleep_when_free = 0.1
        self.max_size = max_size
        self.block_fns = block_fns
        self.continuous_tasks: list[Event] = []
        self._asyncio_tasks: list[asyncio.Task] = []
        self.default_concurrency_limit = self._resolve_concurrency_limit(
            default_concurrency_limit
        )

    def start(self):
        self.active_jobs = [None] * self.max_thread_count
        self.set_event_queue_per_concurrency_id()

        run_coro_in_background(self.start_processing)
        run_coro_in_background(self.start_progress_updates)
        if not self.live_updates:
            run_coro_in_background(self.notify_clients)

    def set_event_queue_per_concurrency_id(self):
        for block_fn in self.block_fns:
            concurrency_id = block_fn.concurrency_id
            concurrency_limit: int | None
            if block_fn.concurrency_limit == "default":
                concurrency_limit = self.default_concurrency_limit
            else:
                concurrency_limit = block_fn.concurrency_limit
            if concurrency_id not in self.event_queue_per_concurrency_id:
                self.event_queue_per_concurrency_id[concurrency_id] = EventQueue(
                    concurrency_id, concurrency_limit
                )
            elif (
                concurrency_limit is not None
            ):  # Update the concurrency limit if it is lower than the existing limit
                existing_event_queue = self.event_queue_per_concurrency_id[
                    concurrency_id
                ]
                if (
                    existing_event_queue.concurrency_limit is None
                    or concurrency_limit < existing_event_queue.concurrency_limit
                ):
                    existing_event_queue.concurrency_limit = concurrency_limit

    def reload(self):
        self.set_event_queue_per_concurrency_id()

    def close(self):
        self.stopped = True

    def send_message(
        self,
        event: Event,
        message_type: str,
        data: dict | None = None,
    ):
        if not event.alive:
            return
        data = {} if data is None else data
        messages = self.pending_messages_per_session[event.session_hash]
        messages.put_nowait({"msg": message_type, "event_id": event._id, **data})

    def _resolve_concurrency_limit(
        self, default_concurrency_limit: int | None | Literal["not_set"]
    ) -> int | None:
        """
        Handles the logic of resolving the default_concurrency_limit, which can be
        specified either via the `default_concurrency_limit` parameter of `Blocks.queue()`
        or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT` environment variable. The parameter
        in `Blocks.queue()` takes precedence over the environment variable.
        Parameters:
            default_concurrency_limit: The default concurrency limit, as specified by a user in `Blocks.queue()`.
        """
        if default_concurrency_limit != "not_set":
            return default_concurrency_limit
        if default_concurrency_limit_env := os.environ.get(
            "GRADIO_DEFAULT_CONCURRENCY_LIMIT"
        ):
            if default_concurrency_limit_env.lower() == "none":
                return None
            else:
                return int(default_concurrency_limit_env)
        else:
            return 1

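    # A minimal sketch of how the resolution above plays out (illustrative
    # values; the environment variable settings are assumptions for the example):
    #
    #   queue._resolve_concurrency_limit(5)          # -> 5 (parameter wins)
    #   queue._resolve_concurrency_limit(None)       # -> None (no limit)
    #   os.environ["GRADIO_DEFAULT_CONCURRENCY_LIMIT"] = "8"
    #   queue._resolve_concurrency_limit("not_set")  # -> 8 (env var fallback)
    #   os.environ["GRADIO_DEFAULT_CONCURRENCY_LIMIT"] = "none"
    #   queue._resolve_concurrency_limit("not_set")  # -> None
    #   # With neither the parameter nor the env var set, the default is 1.
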
    def __len__(self):
        total_len = 0
        for event_queue in self.event_queue_per_concurrency_id.values():
            total_len += len(event_queue.queue)
        return total_len

    async def push(
        self, body: PredictBody, request: fastapi.Request, username: str | None
    ) -> tuple[bool, str]:
        if body.session_hash is None:
            return False, "No session hash provided."
        if body.fn_index is None:
            return False, "No function index provided."
        if self.max_size is not None and len(self) >= self.max_size:
            return (
                False,
                f"Queue is full. Max size is {self.max_size} and size is {len(self)}.",
            )

        event = Event(
            body.session_hash,
            body.fn_index,
            request,
            username,
            self.block_fns[body.fn_index].concurrency_id,
        )
        event.data = body
        async with self.pending_message_lock:
            if body.session_hash not in self.pending_messages_per_session:
                self.pending_messages_per_session[body.session_hash] = ThreadQueue()
            if body.session_hash not in self.pending_event_ids_session:
                self.pending_event_ids_session[body.session_hash] = set()
            self.pending_event_ids_session[body.session_hash].add(event._id)
        event_queue = self.event_queue_per_concurrency_id[event.concurrency_id]
        event_queue.queue.append(event)

        self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1)

        return True, event._id

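    # Sketch of how a caller might consume push()'s return value (hypothetical
    # caller code; `body` and `request` are assumed to exist):
    #
    #   ok, msg_or_id = await queue.push(body, request, username=None)
    #   if ok:
    #       event_id = msg_or_id    # uuid4 hex of the queued event
    #   else:
    #       error_message = msg_or_id    # e.g. "Queue is full. ..."
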
    def _cancel_asyncio_tasks(self):
        for task in self._asyncio_tasks:
            task.cancel()
        self._asyncio_tasks = []

    def set_server_app(self, app: routes.App):
        self.server_app = app

    def get_active_worker_count(self) -> int:
        count = 0
        for worker in self.active_jobs:
            if worker is not None:
                count += 1
        return count

    def get_events(self) -> tuple[list[Event], bool, str] | None:
        concurrency_ids = list(self.event_queue_per_concurrency_id.keys())
        random.shuffle(concurrency_ids)
        for concurrency_id in concurrency_ids:
            event_queue = self.event_queue_per_concurrency_id[concurrency_id]
            if len(event_queue.queue) and (
                event_queue.concurrency_limit is None
                or event_queue.current_concurrency < event_queue.concurrency_limit
            ):
                first_event = event_queue.queue[0]
                block_fn = self.block_fns[first_event.fn_index]
                events = [first_event]
                batch = block_fn.batch
                if batch:
                    events += [
                        event
                        for event in event_queue.queue[1:]
                        if event.fn_index == first_event.fn_index
                    ][: block_fn.max_batch_size - 1]

                for event in events:
                    event_queue.queue.remove(event)

                return events, batch, concurrency_id

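    # Worked example of the batching logic above (hypothetical queue contents).
    # Suppose the queue holds [A(fn 0), B(fn 1), C(fn 0)], fn 0 is batchable,
    # and max_batch_size == 2: the first event A is taken, then later events
    # with the same fn_index are appended up to the batch size, giving [A, C];
    # B stays queued. The method returns ([A, C], True, concurrency_id).
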
    async def start_processing(self) -> None:
        try:
            while not self.stopped:
                if len(self) == 0:
                    await asyncio.sleep(self.sleep_when_free)
                    continue

                if None not in self.active_jobs:
                    await asyncio.sleep(self.sleep_when_free)
                    continue

                # Using mutex to avoid editing a list in use
                async with self.delete_lock:
                    event_batch = self.get_events()

                if event_batch:
                    events, batch, concurrency_id = event_batch
                    self.active_jobs[self.active_jobs.index(None)] = events
                    event_queue = self.event_queue_per_concurrency_id[concurrency_id]
                    event_queue.current_concurrency += 1
                    start_time = time.time()
                    event_queue.start_times_per_fn_index[events[0].fn_index].add(
                        start_time
                    )
                    process_event_task = run_coro_in_background(
                        self.process_events, events, batch, start_time
                    )
                    set_task_name(
                        process_event_task,
                        events[0].session_hash,
                        events[0].fn_index,
                        batch,
                    )

                    self._asyncio_tasks.append(process_event_task)
                    if self.live_updates:
                        self.broadcast_estimations(concurrency_id)
                else:
                    await asyncio.sleep(self.sleep_when_free)
        finally:
            self.stopped = True
            self._cancel_asyncio_tasks()

    async def start_progress_updates(self) -> None:
        """
        Because progress updates can be very frequent, we do not necessarily want to send a message per update.
        Rather, we check for progress updates at regular intervals, and send a message if there is a pending update.
        Consecutive progress updates between sends will overwrite each other so only the most recent update will be sent.
        """
        while not self.stopped:
            events = [
                evt for job in self.active_jobs if job is not None for evt in job
            ] + self.continuous_tasks

            if len(events) == 0:
                await asyncio.sleep(self.progress_update_sleep_when_free)
                continue

            for event in events:
                if event.progress_pending and event.progress:
                    event.progress_pending = False
                    self.send_message(
                        event, ServerMessage.progress, event.progress.model_dump()
                    )

            await asyncio.sleep(self.progress_update_sleep_when_free)

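    # The loop above is a coalescing pattern: producers overwrite a single
    # "latest" slot per event, and one reader drains the slots periodically.
    # A standalone sketch under assumed names (`latest` and `send` are not
    # part of the Queue API):
    #
    #   latest: dict[str, dict] = {}
    #
    #   def write(event_id: str, update: dict) -> None:  # called very often
    #       latest[event_id] = update  # newer updates overwrite older ones
    #
    #   async def drain(interval: float = 0.1) -> None:  # one background task
    #       while True:
    #           for event_id, update in list(latest.items()):
    #               send(event_id, update)
    #               del latest[event_id]
    #           await asyncio.sleep(interval)
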
    def set_progress(
        self,
        event_id: str,
        iterables: list[TrackedIterable] | None,
    ):
        if iterables is None:
            return
        for job in self.active_jobs:
            if job is None:
                continue
            for evt in job:
                if evt._id == event_id:
                    progress_data: list[ProgressUnit] = []
                    for iterable in iterables:
                        progress_unit = ProgressUnit(
                            index=iterable.index,
                            length=iterable.length,
                            unit=iterable.unit,
                            progress=iterable.progress,
                            desc=iterable.desc,
                        )
                        progress_data.append(progress_unit)
                    evt.progress = Progress(progress_data=progress_data)
                    evt.progress_pending = True

    def log_message(
        self,
        event_id: str,
        log: str,
        level: Literal["info", "warning"],
    ):
        events = [
            evt for job in self.active_jobs if job is not None for evt in job
        ] + self.continuous_tasks
        for event in events:
            if event._id == event_id:
                log_message = LogMessage(
                    log=log,
                    level=level,
                )
                self.send_message(event, ServerMessage.log, log_message.model_dump())

    async def clean_events(
        self, *, session_hash: str | None = None, event_id: str | None = None
    ) -> None:
        for job_set in self.active_jobs:
            if job_set:
                for job in job_set:
                    if job.session_hash == session_hash or job._id == event_id:
                        job.alive = False

        async with self.delete_lock:
            events_to_remove: list[Event] = []
            for event_queue in self.event_queue_per_concurrency_id.values():
                for event in event_queue.queue:
                    if event.session_hash == session_hash or event._id == event_id:
                        events_to_remove.append(event)

            for event in events_to_remove:
                self.event_queue_per_concurrency_id[event.concurrency_id].queue.remove(
                    event
                )

    async def notify_clients(self) -> None:
        """
        Periodically notify clients about the status of events in the queue.
        """
        while not self.stopped:
            await asyncio.sleep(self.update_intervals)
            if len(self) > 0:
                for concurrency_id in self.event_queue_per_concurrency_id:
                    self.broadcast_estimations(concurrency_id)

    def broadcast_estimations(
        self, concurrency_id: str, after: int | None = None
    ) -> None:
        wait_so_far = 0
        event_queue = self.event_queue_per_concurrency_id[concurrency_id]
        time_till_available_worker: int | None = 0

        if event_queue.current_concurrency == event_queue.concurrency_limit:
            expected_end_times = []
            for fn_index, start_times in event_queue.start_times_per_fn_index.items():
                if fn_index not in self.process_time_per_fn_index:
                    time_till_available_worker = None
                    break
                process_time = self.process_time_per_fn_index[fn_index].avg_time
                expected_end_times += [
                    start_time + process_time for start_time in start_times
                ]
            if time_till_available_worker is not None and len(expected_end_times) > 0:
                time_of_first_completion = min(expected_end_times)
                time_till_available_worker = max(
                    time_of_first_completion - time.time(), 0
                )

        for rank, event in enumerate(event_queue.queue):
            process_time_for_fn = (
                self.process_time_per_fn_index[event.fn_index].avg_time
                if event.fn_index in self.process_time_per_fn_index
                else None
            )
            rank_eta = (
                process_time_for_fn + wait_so_far + time_till_available_worker
                if process_time_for_fn is not None
                and wait_so_far is not None
                and time_till_available_worker is not None
                else None
            )

            if after is None or rank >= after:
                self.send_message(
                    event,
                    ServerMessage.estimation,
                    Estimation(
                        rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
                    ).model_dump(),
                )
            if event_queue.concurrency_limit is None:
                wait_so_far = 0
            elif wait_so_far is not None and process_time_for_fn is not None:
                wait_so_far += process_time_for_fn / event_queue.concurrency_limit
            else:
                wait_so_far = None

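    # A worked example of the ETA arithmetic above (all numbers assumed):
    # say avg_time for the queued fn is 2.0s, time_till_available_worker is
    # 0.5s, and the queue's concurrency_limit is 2. Then:
    #
    #   rank 0: rank_eta = 2.0 + 0.0 + 0.5 = 2.5s; wait_so_far += 2.0 / 2 -> 1.0
    #   rank 1: rank_eta = 2.0 + 1.0 + 0.5 = 3.5s; wait_so_far -> 2.0
    #
    # If any average processing time is unknown, rank_eta falls back to None
    # rather than guessing.
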
    def get_status(self) -> Estimation:
        return Estimation(
            queue_size=len(self),
        )

    async def call_prediction(self, events: list[Event], batch: bool):
        body = events[0].data
        if body is None:
            raise ValueError("No event data")
        username = events[0].username
        body.event_id = events[0]._id if not batch else None
        try:
            body.request = events[0].request
        except ValueError:
            pass

        if batch:
            body.data = list(zip(*[event.data.data for event in events if event.data]))
            body.request = events[0].request
            body.batched = True

        app = self.server_app
        if app is None:
            raise Exception("Server app has not been set.")
        api_name = "predict"

        fn_index_inferred = route_utils.infer_fn_index(
            app=app, api_name=api_name, body=body
        )

        gr_request = route_utils.compile_gr_request(
            app=app,
            body=body,
            fn_index_inferred=fn_index_inferred,
            username=username,
            request=None,
        )

        try:
            output = await route_utils.call_process_api(
                app=app,
                body=body,
                gr_request=gr_request,
                fn_index_inferred=fn_index_inferred,
            )
        except Exception as error:
            show_error = app.get_blocks().show_error or isinstance(error, Error)
            traceback.print_exc()
            raise Exception(str(error) if show_error else None) from error

        # To emulate the HTTP response from the predict API,
        # convert the output to a JSON response string.
        # This is done automatically by FastAPI in the HTTP endpoint handlers,
        # but we need to do it manually here.
        response_class = app.router.default_response_class
        if isinstance(response_class, fastapi.datastructures.DefaultPlaceholder):
            actual_response_class = response_class.value
        else:
            actual_response_class = response_class
        http_response = actual_response_class(
            output
        )  # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264
        # Also, decode the JSON string to a Python object, emulating the behavior
        # of an HTTP client, e.g. the `json()` method of `httpx`.
        response_json = json.loads(http_response.body.decode())

        return response_json

    async def process_events(
        self, events: list[Event], batch: bool, begin_time: float
    ) -> None:
        awake_events: list[Event] = []
        fn_index = events[0].fn_index
        try:
            for event in events:
                if event.alive:
                    self.send_message(
                        event,
                        ServerMessage.process_starts,
                        {
                            "eta": self.process_time_per_fn_index[fn_index].avg_time
                            if fn_index in self.process_time_per_fn_index
                            else None
                        },
                    )
                    awake_events.append(event)
            if not awake_events:
                return
            try:
                response = await self.call_prediction(awake_events, batch)
                err = None
            except Exception as e:
                response = None
                err = e
                for event in awake_events:
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": {
                                "error": None
                                if len(e.args) and e.args[0] is None
                                else str(e)
                            },
                            "success": False,
                        },
                    )
            if response and response.get("is_generating", False):
                old_response = response
                old_err = err
                while response and response.get("is_generating", False):
                    old_response = response
                    old_err = err
                    for event in awake_events:
                        self.send_message(
                            event,
                            ServerMessage.process_generating,
                            {
                                "output": old_response,
                                "success": old_response is not None,
                            },
                        )
                    awake_events = [event for event in awake_events if event.alive]
                    if not awake_events:
                        return
                    try:
                        response = await self.call_prediction(awake_events, batch)
                        err = None
                    except Exception as e:
                        response = None
                        err = e
                for event in awake_events:
                    relevant_response = response or err or old_err
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": {"error": str(relevant_response)}
                            if isinstance(relevant_response, Exception)
                            else relevant_response,
                            "success": relevant_response
                            and not isinstance(relevant_response, Exception),
                        },
                    )
            elif response:
                output = copy.deepcopy(response)
                for e, event in enumerate(awake_events):
                    if batch and "data" in output:
                        output["data"] = list(zip(*response.get("data")))[e]
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": output,
                            "success": response is not None,
                        },
                    )
            end_time = time.time()
            if response is not None:
                self.process_time_per_fn_index[events[0].fn_index].add(
                    end_time - begin_time
                )
        except Exception:
            traceback.print_exc()
        finally:
            event_queue = self.event_queue_per_concurrency_id[events[0].concurrency_id]
            event_queue.current_concurrency -= 1
            start_times = event_queue.start_times_per_fn_index[fn_index]
            if begin_time in start_times:
                start_times.remove(begin_time)
            try:
                self.active_jobs[self.active_jobs.index(events)] = None
            except ValueError:
                # `events` can be absent from `self.active_jobs`
                # when this coroutine is called from the `join_queue` endpoint handler
                # in `routes.py` without putting the `events` into `self.active_jobs`.
                # https://github.com/gradio-app/gradio/blob/f09aea34d6bd18c1e2fef80c86ab2476a6d1dd83/gradio/routes.py#L594-L596
                pass
            for event in events:
                # Always reset the state of the iterator.
                # If the job finished successfully, this has no effect.
                # If the job was cancelled, this enables future runs
                # to start "from scratch".
                await self.reset_iterators(event._id)

    async def reset_iterators(self, event_id: str):
        # Do the same thing as the /reset route
        app = self.server_app
        if app is None:
            raise Exception("Server app has not been set.")
        if event_id not in app.iterators:
            # Failure, but don't raise an error
            return
        async with app.lock:
            del app.iterators[event_id]
            app.iterators_to_reset.add(event_id)
        return
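
# A minimal lifecycle sketch for this module (illustrative only; in practice
# `Blocks` constructs and manages the Queue, and `demo.fns` / `app` here are
# assumed names):
#
#   queue = Queue(
#       live_updates=True,
#       concurrency_count=2,
#       update_intervals=1.0,
#       max_size=None,
#       block_fns=demo.fns,
#   )
#   queue.set_server_app(app)  # app: routes.App
#   queue.start()              # spawns start_processing / start_progress_updates
#   ...                        # events arrive via push() and are processed
#   queue.close()              # sets .stopped so the background loops exit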