Mirror of https://github.com/gradio-app/gradio.git (synced 2025-04-06 12:30:29 +08:00)
Refactor queue so that there are separate queues for each concurrency id (#6814)
* change
* changes
* add changeset
* add changeset
* changes
* changes
* changes
* changes

---------

Co-authored-by: Ali Abid <ubuntu@ip-172-31-25-241.us-west-2.compute.internal>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent 73268ee2e3
commit 828fb9e6ce
.changeset/eighty-teeth-greet.md (new file, 7 lines)
@@ -0,0 +1,7 @@
+---
+"@gradio/client": patch
+"@gradio/statustracker": patch
+"gradio": patch
+---
+
+feat:Refactor queue so that there are separate queues for each concurrency id
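Note on the change itself: pending events are no longer kept in one global list but in one FIFO per concurrency_id, each with its own limit and a count of currently running jobs. A minimal standalone sketch of that shape (hypothetical MiniEventQueue and event names, not the actual gradio classes):

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class MiniEventQueue:
    # One FIFO of pending events per concurrency_id, plus its own limit and usage counter.
    concurrency_limit: Optional[int] = None
    current_concurrency: int = 0
    queue: List[str] = field(default_factory=list)


def has_capacity(q: MiniEventQueue) -> bool:
    # A worker may only pick from a queue that still has spare capacity.
    return q.concurrency_limit is None or q.current_concurrency < q.concurrency_limit


queues = defaultdict(MiniEventQueue)

# Two listeners share the "gpu" id (limit 1); a third runs under "cpu" (limit 4).
queues["gpu"].concurrency_limit = 1
queues["cpu"].concurrency_limit = 4
queues["gpu"].queue.extend(["event-a", "event-b"])
queues["cpu"].queue.append("event-c")

queues["gpu"].current_concurrency = 1  # one "gpu" job already running
ready = [cid for cid, q in queues.items() if q.queue and has_capacity(q)]
print(ready)  # ['cpu'] -- the saturated "gpu" queue is skipped
```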
@@ -287,6 +287,7 @@ export function api_factory(
 	const session_hash = Math.random().toString(36).substring(2);
 	const last_status: Record<string, Status["stage"]> = {};
 	let stream_open = false;
+	let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
 	let event_stream: EventSource | null = null;
 	const event_callbacks: Record<string, () => Promise<void>> = {};
 	let config: Config;
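The comment on pending_stream_messages above describes a race: SSE messages can arrive before the POST that returns the event_id has resolved, so they are buffered per event_id and replayed once the callback is registered. A minimal sketch of that buffering pattern in plain Python (hypothetical function names; the real client does this in TypeScript):

```python
from typing import Callable, Dict, List

pending_stream_messages: Dict[str, List[dict]] = {}
event_callbacks: Dict[str, Callable[[dict], None]] = {}


def on_stream_message(event_id: str, msg: dict) -> None:
    # SSE message arrives: dispatch if a callback is registered, otherwise park it.
    callback = event_callbacks.get(event_id)
    if callback is not None:
        callback(msg)
    else:
        pending_stream_messages.setdefault(event_id, []).append(msg)


def register_callback(event_id: str, callback: Callable[[dict], None]) -> None:
    # POST response arrives: drain anything that raced ahead, then register.
    for msg in pending_stream_messages.pop(event_id, []):
        callback(msg)
    event_callbacks[event_id] = callback


# Example: a message for "abc" arrives before the POST for "abc" completes.
on_stream_message("abc", {"msg": "progress"})
register_callback("abc", lambda m: print("handled", m))  # prints the buffered message
```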
@@ -908,8 +909,8 @@ export function api_factory(
 				}

 				if (
-					status.stage === "complete" ||
-					status.stage === "error"
+					status?.stage === "complete" ||
+					status?.stage === "error"
 				) {
 					if (event_callbacks[event_id]) {
 						delete event_callbacks[event_id];
@@ -932,6 +933,12 @@ export function api_factory(
 					close_stream();
 				}
 			};
+			if (event_id in pending_stream_messages) {
+				pending_stream_messages[event_id].forEach((msg) =>
+					callback(msg)
+				);
+				delete pending_stream_messages[event_id];
+			}
 			event_callbacks[event_id] = callback;
 			if (!stream_open) {
 				open_stream();
@@ -1051,15 +1058,21 @@ export function api_factory(
 		event_stream = new EventSource(url);
 		event_stream.onmessage = async function (event) {
 			let _data = JSON.parse(event.data);
-			if (!("event_id" in _data)) {
+			const event_id = _data.event_id;
+			if (!event_id) {
 				await Promise.all(
 					Object.keys(event_callbacks).map((event_id) =>
 						event_callbacks[event_id](_data)
 					)
 				);
 				return;
+			} else if (event_callbacks[event_id]) {
+				await event_callbacks[event_id](_data);
+			} else {
+				if (!pending_stream_messages[event_id]) {
+					pending_stream_messages[event_id] = [];
+				}
+				pending_stream_messages[event_id].push(_data);
 			}
-			await event_callbacks[_data.event_id](_data);
 		};
 	}
@@ -1701,8 +1714,7 @@ function handle_message(
 				message: !data.success ? data.output.error : undefined,
 				stage: data.success ? "complete" : "error",
 				code: data.code,
-				progress_data: data.progress_data,
-				eta: data.output.average_duration
+				progress_data: data.progress_data
 			},
 			data: data.success ? data.output : null
 		};
@@ -1716,7 +1728,8 @@ function handle_message(
 				code: data.code,
 				size: data.rank,
 				position: 0,
-				success: data.success
+				success: data.success,
+				eta: data.eta
 			}
 		};
 	}
@@ -376,7 +376,7 @@ class BlockFunction:
         self.preprocess = preprocess
         self.postprocess = postprocess
         self.tracks_progress = tracks_progress
-        self.concurrency_limit = concurrency_limit
+        self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
         self.concurrency_id = concurrency_id or str(id(fn))
         self.batch = batch
         self.max_batch_size = max_batch_size
@@ -101,10 +101,7 @@ class InterfaceTypes(Enum):
 class Estimation(BaseModel):
     rank: Optional[int] = None
     queue_size: int
-    avg_event_process_time: Optional[float] = None
-    avg_event_concurrent_process_time: Optional[float] = None
     rank_eta: Optional[float] = None
-    queue_eta: float


 class ProgressUnit(BaseModel):
@@ -4,9 +4,11 @@ import asyncio
 import copy
 import json
 import os
+import random
 import time
 import traceback
 import uuid
+from collections import defaultdict
 from queue import Queue as ThreadQueue
 from typing import TYPE_CHECKING

@@ -37,11 +39,13 @@ class Event:
         fn_index: int,
         request: fastapi.Request,
         username: str | None,
+        concurrency_id: str,
     ):
         self.session_hash = session_hash
         self.fn_index = fn_index
         self.request = request
         self.username = username
+        self.concurrency_id = concurrency_id
         self._id = uuid.uuid4().hex
         self.data: PredictBody | None = None
         self.progress: Progress | None = None
@@ -49,6 +53,27 @@ class Event:
         self.alive = True


+class EventQueue:
+    def __init__(self, concurrency_id: str, concurrency_limit: int | None):
+        self.queue: list[Event] = []
+        self.concurrency_id = concurrency_id
+        self.concurrency_limit = concurrency_limit
+        self.current_concurrency = 0
+        self.start_times_per_fn_index: defaultdict[int, set[float]] = defaultdict(set)
+
+
+class ProcessTime:
+    def __init__(self):
+        self.process_time = 0
+        self.count = 0
+        self.avg_time = 0
+
+    def add(self, time: float):
+        self.process_time += time
+        self.count += 1
+        self.avg_time = self.process_time / self.count
+
+
 class Queue:
     def __init__(
         self,
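The new ProcessTime helper is just an incremental mean of observed durations, kept per fn_index elsewhere in the queue. A tiny standalone usage sketch:

```python
class ProcessTime:
    # Standalone copy of the helper above: an incremental (running) average.
    def __init__(self):
        self.process_time = 0.0
        self.count = 0
        self.avg_time = 0.0

    def add(self, time: float) -> None:
        self.process_time += time
        self.count += 1
        self.avg_time = self.process_time / self.count


pt = ProcessTime()
for duration in (2.0, 4.0, 6.0):
    pt.add(duration)
print(pt.avg_time)  # 4.0 -- later reused as the per-function ETA estimate
```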
@@ -62,19 +87,16 @@ class Queue:
         self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
         self.pending_event_ids_session: dict[str, set[str]] = {}
         self.pending_message_lock = safe_get_lock()
-        self.event_queue: list[Event] = []
-        self.awaiting_data_events: dict[str, Event] = {}
+        self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
         self.stopped = False
         self.max_thread_count = concurrency_count
         self.update_intervals = update_intervals
         self.active_jobs: list[None | list[Event]] = []
         self.delete_lock = safe_get_lock()
         self.server_app = None
-        self.duration_history_total = 0
-        self.duration_history_count = 0
-        self.avg_process_time = 0
-        self.avg_concurrent_process_time = None
-        self.queue_duration = 1
+        self.process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(
+            ProcessTime
+        )
         self.live_updates = live_updates
         self.sleep_when_free = 0.05
         self.progress_update_sleep_when_free = 0.1
@@ -85,25 +107,31 @@ class Queue:
         self.default_concurrency_limit = self._resolve_concurrency_limit(
             default_concurrency_limit
         )
-        self.concurrency_limit_per_concurrency_id = {}

     def start(self):
         self.active_jobs = [None] * self.max_thread_count
         for block_fn in self.block_fns:
-            concurrency_limit = (
-                self.default_concurrency_limit
-                if block_fn.concurrency_limit == "default"
-                else block_fn.concurrency_limit
-            )
-            if concurrency_limit is not None:
-                self.concurrency_limit_per_concurrency_id[
-                    block_fn.concurrency_id
-                ] = min(
-                    self.concurrency_limit_per_concurrency_id.get(
-                        block_fn.concurrency_id, concurrency_limit
-                    ),
-                    concurrency_limit,
-                )
+            concurrency_id = block_fn.concurrency_id
+            concurrency_limit: int | None
+            if block_fn.concurrency_limit == "default":
+                concurrency_limit = self.default_concurrency_limit
+            else:
+                concurrency_limit = block_fn.concurrency_limit
+            if concurrency_id not in self.event_queue_per_concurrency_id:
+                self.event_queue_per_concurrency_id[concurrency_id] = EventQueue(
+                    concurrency_id, concurrency_limit
+                )
+            elif (
+                concurrency_limit is not None
+            ):  # Update concurrency limit if it is lower than existing limit
+                existing_event_queue = self.event_queue_per_concurrency_id[
+                    concurrency_id
+                ]
+                if (
+                    existing_event_queue.concurrency_limit is None
+                    or concurrency_limit < existing_event_queue.concurrency_limit
+                ):
+                    existing_event_queue.concurrency_limit = concurrency_limit

         run_coro_in_background(self.start_processing)
         run_coro_in_background(self.start_progress_updates)
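When several listeners share a concurrency_id, start() keeps the most restrictive explicit limit for the shared EventQueue, and an explicit limit always overrides None (see the comment about lowering the existing limit above). A condensed sketch of that merge rule with hypothetical inputs:

```python
from typing import Optional


def merge_limit(existing: Optional[int], new: Optional[int]) -> Optional[int]:
    # None means "no limit": an explicit limit always tightens it,
    # and between two explicit limits the smaller one wins.
    if new is None:
        return existing
    if existing is None:
        return new
    return min(existing, new)


# Hypothetical example: three listeners registered under the same concurrency_id.
limits = [None, 4, 2]
merged = None
for limit in limits:
    merged = merge_limit(merged, limit)
print(merged)  # 2 -- the shared EventQueue ends up with the most restrictive limit
```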
@@ -119,11 +147,15 @@ class Queue:
         message_type: str,
         data: dict | None = None,
     ):
+        if not event.alive:
+            return
         data = {} if data is None else data
         messages = self.pending_messages_per_session[event.session_hash]
         messages.put_nowait({"msg": message_type, "event_id": event._id, **data})

-    def _resolve_concurrency_limit(self, default_concurrency_limit):
+    def _resolve_concurrency_limit(
+        self, default_concurrency_limit: int | None | Literal["not_set"]
+    ) -> int | None:
         """
         Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination
         of the `default_concurrency_limit` parameter of the `Blocks.queue()` or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT`
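The docstring above says the default limit can be set via the default_concurrency_limit parameter of Blocks.queue() or the GRADIO_DEFAULT_CONCURRENCY_LIMIT environment variable, and the next hunk shows a final `return 1` fallback. The full body is not part of this diff; the sketch below is only a guess at the precedence, not the verbatim gradio implementation:

```python
import os
from typing import Optional, Union


def resolve_default_concurrency_limit(
    default_concurrency_limit: Union[int, None, str] = "not_set",
) -> Optional[int]:
    # Assumed precedence (illustrative, not the actual gradio code):
    # 1. an explicit value passed to Blocks.queue() wins,
    # 2. otherwise the GRADIO_DEFAULT_CONCURRENCY_LIMIT environment variable,
    # 3. otherwise fall back to 1.
    if default_concurrency_limit != "not_set":
        return default_concurrency_limit  # may be an int or None (= unlimited)
    env_value = os.environ.get("GRADIO_DEFAULT_CONCURRENCY_LIMIT")
    if env_value is not None:
        return None if env_value.lower() == "none" else int(env_value)
    return 1
```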
@@ -143,6 +175,12 @@ class Queue:
         else:
             return 1

+    def __len__(self):
+        total_len = 0
+        for event_queue in self.event_queue_per_concurrency_id.values():
+            total_len += len(event_queue.queue)
+        return total_len
+
     async def push(
         self, body: PredictBody, request: fastapi.Request, username: str | None
     ) -> tuple[bool, str]:
@@ -150,14 +188,19 @@ class Queue:
             return False, "No session hash provided."
         if body.fn_index is None:
             return False, "No function index provided."
-        queue_len = len(self.event_queue)
-        if self.max_size is not None and queue_len >= self.max_size:
+        if self.max_size is not None and len(self) >= self.max_size:
             return (
                 False,
-                f"Queue is full. Max size is {self.max_size} and size is {queue_len}.",
+                f"Queue is full. Max size is {self.max_size} and size is {len(self)}.",
             )

-        event = Event(body.session_hash, body.fn_index, request, username)
+        event = Event(
+            body.session_hash,
+            body.fn_index,
+            request,
+            username,
+            self.block_fns[body.fn_index].concurrency_id,
+        )
         event.data = body
         async with self.pending_message_lock:
             if body.session_hash not in self.pending_messages_per_session:
@@ -165,10 +208,10 @@ class Queue:
             if body.session_hash not in self.pending_event_ids_session:
                 self.pending_event_ids_session[body.session_hash] = set()
             self.pending_event_ids_session[body.session_hash].add(event._id)
-        self.event_queue.append(event)
+        event_queue = self.event_queue_per_concurrency_id[event.concurrency_id]
+        event_queue.queue.append(event)

-        estimation = self.get_estimation()
-        await self.send_estimation(event, estimation, queue_len)
+        self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1)

         return True, event._id
@@ -187,88 +230,73 @@ class Queue:
                 count += 1
         return count

-    def get_events_in_batch(self) -> tuple[list[Event] | None, bool]:
-        if not self.event_queue:
-            return None, False
-
-        worker_count_per_concurrency_id = {}
-        for job in self.active_jobs:
-            if job is not None:
-                for event in job:
-                    concurrency_id = self.block_fns[event.fn_index].concurrency_id
-                    worker_count_per_concurrency_id[concurrency_id] = (
-                        worker_count_per_concurrency_id.get(concurrency_id, 0) + 1
-                    )
-
-        events = []
-        batch = False
-        for index, event in enumerate(self.event_queue):
-            block_fn = self.block_fns[event.fn_index]
-            concurrency_id = block_fn.concurrency_id
-            concurrency_limit = self.concurrency_limit_per_concurrency_id.get(
-                concurrency_id, None
-            )
-            existing_worker_count = worker_count_per_concurrency_id.get(
-                concurrency_id, 0
-            )
-            if concurrency_limit is None or existing_worker_count < concurrency_limit:
-                batch = block_fn.batch
-                if batch:
-                    batch_size = block_fn.max_batch_size
-                    if concurrency_limit is None:
-                        remaining_worker_count = batch_size - 1
-                    else:
-                        remaining_worker_count = (
-                            concurrency_limit - existing_worker_count
-                        )
-                    rest_of_batch = [
-                        event
-                        for event in self.event_queue[index:]
-                        if event.fn_index == event.fn_index
-                    ][: min(batch_size - 1, remaining_worker_count)]
-                    events = [event] + rest_of_batch
-                else:
-                    events = [event]
-                break
-
-        for event in events:
-            self.event_queue.remove(event)
-
-        return events, batch
+    def get_events(self) -> tuple[list[Event], bool, str] | None:
+        concurrency_ids = list(self.event_queue_per_concurrency_id.keys())
+        random.shuffle(concurrency_ids)
+        for concurrency_id in concurrency_ids:
+            event_queue = self.event_queue_per_concurrency_id[concurrency_id]
+            if len(event_queue.queue) and (
+                event_queue.concurrency_limit is None
+                or event_queue.current_concurrency < event_queue.concurrency_limit
+            ):
+                first_event = event_queue.queue[0]
+                block_fn = self.block_fns[first_event.fn_index]
+                events = [first_event]
+                batch = block_fn.batch
+                if batch:
+                    events += [
+                        event
+                        for event in event_queue.queue[1:]
+                        if event.fn_index == first_event.fn_index
+                    ][: block_fn.max_batch_size - 1]
+
+                for event in events:
+                    event_queue.queue.remove(event)
+
+                return events, batch, concurrency_id

     async def start_processing(self) -> None:
-        while not self.stopped:
-            if not self.event_queue:
-                await asyncio.sleep(self.sleep_when_free)
-                continue
-
-            if None not in self.active_jobs:
-                await asyncio.sleep(self.sleep_when_free)
-                continue
-
-            # Using mutex to avoid editing a list in use
-            async with self.delete_lock:
-                events, batch = self.get_events_in_batch()
-
-            if events:
-                self.active_jobs[self.active_jobs.index(None)] = events
-                process_event_task = run_coro_in_background(
-                    self.process_events, events, batch
-                )
-                set_task_name(
-                    process_event_task,
-                    events[0].session_hash,
-                    events[0].fn_index,
-                    batch,
-                )
-
-                self._asyncio_tasks.append(process_event_task)
-                if self.live_updates:
-                    broadcast_live_estimations_task = run_coro_in_background(
-                        self.broadcast_estimations
-                    )
-                    self._asyncio_tasks.append(broadcast_live_estimations_task)
-            else:
-                await asyncio.sleep(self.sleep_when_free)
+        try:
+            while not self.stopped:
+                if len(self) == 0:
+                    await asyncio.sleep(self.sleep_when_free)
+                    continue
+
+                if None not in self.active_jobs:
+                    await asyncio.sleep(self.sleep_when_free)
+                    continue
+
+                # Using mutex to avoid editing a list in use
+                async with self.delete_lock:
+                    event_batch = self.get_events()
+
+                if event_batch:
+                    events, batch, concurrency_id = event_batch
+                    self.active_jobs[self.active_jobs.index(None)] = events
+                    event_queue = self.event_queue_per_concurrency_id[concurrency_id]
+                    event_queue.current_concurrency += 1
+                    start_time = time.time()
+                    event_queue.start_times_per_fn_index[events[0].fn_index].add(
+                        start_time
+                    )
+                    process_event_task = run_coro_in_background(
+                        self.process_events, events, batch, start_time
+                    )
+                    set_task_name(
+                        process_event_task,
+                        events[0].session_hash,
+                        events[0].fn_index,
+                        batch,
+                    )
+
+                    self._asyncio_tasks.append(process_event_task)
+                    if self.live_updates:
+                        self.broadcast_estimations(concurrency_id)
+                else:
+                    await asyncio.sleep(self.sleep_when_free)
+        finally:
+            self.stopped = True
+            self._cancel_asyncio_tasks()

     async def start_progress_updates(self) -> None:
         """
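The new get_events replaces the single scan over a global list: shuffle the concurrency ids (presumably so one queue is not always drained first), pick the first queue with pending events and spare capacity, take its head event, and for batch functions pull up to max_batch_size - 1 further events with the same fn_index. A toy rerun of that selection outside gradio (hypothetical dict-based queue state):

```python
import random
from typing import Dict, List, Optional, Tuple


def pick_next(queues: Dict[str, dict]) -> Optional[Tuple[str, List[int]]]:
    ids = list(queues)
    random.shuffle(ids)  # avoid always draining the same concurrency_id first
    for cid in ids:
        q = queues[cid]
        if q["pending"] and (q["limit"] is None or q["running"] < q["limit"]):
            first_fn = q["pending"][0]
            max_batch_size = 2  # stand-in for block_fn.max_batch_size
            # head event plus up to max_batch_size - 1 queued events with the same fn_index
            batch = [first_fn] + [
                fn for fn in q["pending"][1:] if fn == first_fn
            ][: max_batch_size - 1]
            return cid, batch
    return None


queues = {
    "gpu": {"limit": 1, "running": 1, "pending": [3, 3]},     # saturated, skipped
    "cpu": {"limit": 4, "running": 1, "pending": [7, 7, 9]},  # has capacity
}
print(pick_next(queues))  # ('cpu', [7, 7]) -- fn_index 9 waits for the next pick
```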
@@ -345,14 +373,17 @@ class Queue:
             if job.session_hash == session_hash or job._id == event_id:
                 job.alive = False

-        events_to_remove = []
-        for event in self.event_queue:
-            if event.session_hash == session_hash or event._id == event_id:
-                events_to_remove.append(event)
+        async with self.delete_lock:
+            events_to_remove: list[Event] = []
+            for event_queue in self.event_queue_per_concurrency_id.values():
+                for event in event_queue.queue:
+                    if event.session_hash == session_hash or event._id == event_id:
+                        events_to_remove.append(event)

-        for event in events_to_remove:
-            async with self.delete_lock:
-                self.event_queue.remove(event)
+            for event in events_to_remove:
+                self.event_queue_per_concurrency_id[event.concurrency_id].queue.remove(
+                    event
+                )

     async def notify_clients(self) -> None:
         """
@@ -360,66 +391,65 @@ class Queue:
         """
         while not self.stopped:
             await asyncio.sleep(self.update_intervals)
-            if self.event_queue:
-                await self.broadcast_estimations()
+            if len(self) > 0:
+                for concurrency_id in self.event_queue_per_concurrency_id:
+                    self.broadcast_estimations(concurrency_id)

-    async def broadcast_estimations(self) -> None:
-        estimation = self.get_estimation()
-        # Send all messages concurrently
-        await asyncio.gather(
-            *[
-                self.send_estimation(event, estimation, rank)
-                for rank, event in enumerate(self.event_queue)
-            ]
-        )
+    def broadcast_estimations(
+        self, concurrency_id: str, after: int | None = None
+    ) -> None:
+        wait_so_far = 0
+        event_queue = self.event_queue_per_concurrency_id[concurrency_id]
+        time_till_available_worker: int | None = 0

-    async def send_estimation(
-        self, event: Event, estimation: Estimation, rank: int
-    ) -> Estimation:
-        """
-        Send estimation about ETA to the client.
-
-        Parameters:
-            event:
-            estimation:
-            rank:
-        """
-        estimation.rank = rank
-
-        if self.avg_concurrent_process_time is not None:
-            estimation.rank_eta = (
-                estimation.rank * self.avg_concurrent_process_time
-                + self.avg_process_time
-            )
-            if None not in self.active_jobs:
-                # Add estimated amount of time for a thread to get empty
-                estimation.rank_eta += self.avg_concurrent_process_time
-        self.send_message(event, ServerMessage.estimation, estimation.model_dump())
-        return estimation
-
-    def update_estimation(self, duration: float) -> None:
-        """
-        Update estimation by last x element's average duration.
-
-        Parameters:
-            duration:
-        """
-        self.duration_history_total += duration
-        self.duration_history_count += 1
-        self.avg_process_time = (
-            self.duration_history_total / self.duration_history_count
-        )
-        self.avg_concurrent_process_time = self.avg_process_time / min(
-            self.max_thread_count, self.duration_history_count
-        )
-        self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)
+        if event_queue.current_concurrency == event_queue.concurrency_limit:
+            expected_end_times = []
+            for fn_index, start_times in event_queue.start_times_per_fn_index.items():
+                if fn_index not in self.process_time_per_fn_index:
+                    time_till_available_worker = None
+                    break
+                process_time = self.process_time_per_fn_index[fn_index].avg_time
+                expected_end_times += [
+                    start_time + process_time for start_time in start_times
+                ]
+            if time_till_available_worker is not None and len(expected_end_times) > 0:
+                time_of_first_completion = min(expected_end_times)
+                time_till_available_worker = max(
+                    time_of_first_completion - time.time(), 0
+                )

-    def get_estimation(self) -> Estimation:
+        for rank, event in enumerate(event_queue.queue):
+            process_time_for_fn = (
+                self.process_time_per_fn_index[event.fn_index].avg_time
+                if event.fn_index in self.process_time_per_fn_index
+                else None
+            )
+            rank_eta = (
+                process_time_for_fn + wait_so_far + time_till_available_worker
+                if process_time_for_fn is not None
+                and wait_so_far is not None
+                and time_till_available_worker is not None
+                else None
+            )
+
+            if after is None or rank >= after:
+                self.send_message(
+                    event,
+                    ServerMessage.estimation,
+                    Estimation(
+                        rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
+                    ).model_dump(),
+                )
+            if event_queue.concurrency_limit is None:
+                wait_so_far = 0
+            elif wait_so_far is not None and process_time_for_fn is not None:
+                wait_so_far += process_time_for_fn / event_queue.concurrency_limit
+            else:
+                wait_so_far = None
+
+    def get_status(self) -> Estimation:
         return Estimation(
-            queue_size=len(self.event_queue),
-            avg_event_process_time=self.avg_process_time,
-            avg_event_concurrent_process_time=self.avg_concurrent_process_time,
-            queue_eta=self.queue_duration,
+            queue_size=len(self),
         )

     async def call_prediction(self, events: list[Event], batch: bool):
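The ETA arithmetic in the new broadcast_estimations is easier to follow with numbers: each waiting event gets rank_eta = avg_time + wait_so_far + time_till_available_worker, and wait_so_far grows by avg_time / concurrency_limit per event ahead of it. A small sketch of just that calculation with made-up durations:

```python
# Assumed inputs for one per-concurrency-id queue (illustrative numbers only):
avg_time = 10.0                    # ProcessTime.avg_time for this fn_index
concurrency_limit = 2              # event_queue.concurrency_limit, fully busy
time_till_available_worker = 4.0   # min(expected_end_times) - time.time(), clamped at 0
queue_len = 3                      # events waiting in event_queue.queue

wait_so_far = 0.0
for rank in range(queue_len):
    rank_eta = avg_time + wait_so_far + time_till_available_worker
    print(f"rank {rank}: eta ~ {rank_eta:.1f}s")
    # each queued event ahead adds avg_time / concurrency_limit of extra wait
    wait_so_far += avg_time / concurrency_limit

# rank 0: 14.0s, rank 1: 19.0s, rank 2: 24.0s
```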
@@ -484,20 +514,30 @@ class Queue:

         return response_json

-    async def process_events(self, events: list[Event], batch: bool) -> None:
+    async def process_events(
+        self, events: list[Event], batch: bool, begin_time: float
+    ) -> None:
         awake_events: list[Event] = []
+        fn_index = events[0].fn_index
         try:
             for event in events:
-                self.send_message(event, ServerMessage.process_starts)
-                awake_events.append(event)
+                if event.alive:
+                    self.send_message(
+                        event,
+                        ServerMessage.process_starts,
+                        {
+                            "eta": self.process_time_per_fn_index[fn_index].avg_time
+                            if fn_index in self.process_time_per_fn_index
+                            else None
+                        },
+                    )
+                    awake_events.append(event)
             if not awake_events:
                 return
-            begin_time = time.time()
             try:
                 response = await self.call_prediction(awake_events, batch)
                 err = None
             except Exception as e:
                 traceback.print_exc()
                 response = None
                 err = e
             for event in awake_events:
@@ -568,10 +608,17 @@ class Queue:
                 )
             end_time = time.time()
             if response is not None:
-                self.update_estimation(end_time - begin_time)
+                self.process_time_per_fn_index[events[0].fn_index].add(
+                    end_time - begin_time
+                )
         except Exception as e:
             traceback.print_exc()
         finally:
+            event_queue = self.event_queue_per_concurrency_id[events[0].concurrency_id]
+            event_queue.current_concurrency -= 1
+            start_times = event_queue.start_times_per_fn_index[fn_index]
+            if begin_time in start_times:
+                start_times.remove(begin_time)
             try:
                 self.active_jobs[self.active_jobs.index(events)] = None
             except ValueError:
@@ -673,6 +673,12 @@ class App(FastAPI):
         if blocks._queue.server_app is None:
             blocks._queue.set_server_app(app)

+        if blocks._queue.stopped:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Queue is stopped.",
+            )
+
         success, event_id = await blocks._queue.push(body, request, username)
         if not success:
             status_code = (
@@ -702,7 +708,7 @@ class App(FastAPI):
         response_model=Estimation,
     )
     async def get_queue_status():
-        return app.get_blocks()._queue.get_estimation()
+        return app.get_blocks()._queue.get_status()

     @app.get("/upload_progress")
     def get_upload_progress(upload_id: str, request: fastapi.Request):
@@ -53,7 +53,6 @@

 	export let i18n: I18nFormatter;
 	export let eta: number | null = null;
-	export let queue = false;
 	export let queue_position: number | null;
 	export let queue_size: number | null;
 	export let status: "complete" | "pending" | "error" | "generating";
@@ -75,6 +74,7 @@
 	let timer_start = 0;
 	let timer_diff = 0;
 	let old_eta: number | null = null;
+	let eta_from_start: number | null = null;
 	let message_visible = false;
 	let eta_level: number | null = 0;
 	let progress_level: (number | undefined)[] | null = null;
@@ -83,9 +83,9 @@
 	let show_eta_bar = true;

 	$: eta_level =
-		eta === null || eta <= 0 || !timer_diff
+		eta_from_start === null || eta_from_start <= 0 || !timer_diff
 			? null
-			: Math.min(timer_diff / eta, 1);
+			: Math.min(timer_diff / eta_from_start, 1);
 	$: if (progress != null) {
 		show_eta_bar = false;
 	}
@@ -119,6 +119,7 @@
 	}

 	const start_timer = (): void => {
+		eta = old_eta = formatted_eta = null;
 		timer_start = performance.now();
 		timer_diff = 0;
 		_timer = true;
@@ -134,6 +135,7 @@

 	function stop_timer(): void {
 		timer_diff = 0;
+		eta = old_eta = formatted_eta = null;

 		if (!_timer) return;
 		_timer = false;
@@ -160,11 +162,10 @@
 	$: {
 		if (eta === null) {
 			eta = old_eta;
-		} else if (queue) {
-			eta = (performance.now() - timer_start) / 1000 + eta;
 		}
-		if (eta != null) {
-			formatted_eta = eta.toFixed(1);
+		if (eta != null && old_eta !== eta) {
+			eta_from_start = (performance.now() - timer_start) / 1000 + eta;
+			formatted_eta = eta_from_start.toFixed(1);
 			old_eta = eta;
 		}
 	}