Refactor queue so that there are separate queues for each concurrency id (#6814)

* change

* changes

* add changeset

* add changeset

* changes

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <ubuntu@ip-172-31-25-241.us-west-2.compute.internal>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
aliabid94 2023-12-19 11:42:56 -08:00 committed by GitHub
parent 73268ee2e3
commit 828fb9e6ce
7 changed files with 257 additions and 186 deletions

View File

@@ -0,0 +1,7 @@
---
"@gradio/client": patch
"@gradio/statustracker": patch
"gradio": patch
---
feat:Refactor queue so that there are separate queues for each concurrency id
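
As context for the diffs below: after this refactor, event listeners that share a `concurrency_id` share one queue and one worker limit. A hedged usage sketch of how that surfaces in the Python API (the component wiring and names here are invented for illustration; `concurrency_id` and `concurrency_limit` are the event-listener parameters this change builds on):

```python
import gradio as gr

def run_on_gpu(prompt):
    # Stand-in for an expensive, resource-bound function.
    return prompt[::-1]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    out_a = gr.Textbox(label="A")
    out_b = gr.Textbox(label="B")
    btn_a = gr.Button("Run A")
    btn_b = gr.Button("Run B")

    # Both listeners use concurrency_id="gpu", so after this refactor their
    # events land in the same EventQueue and share a single limit of 2
    # concurrent workers (the limit is shared, not per-listener).
    btn_a.click(run_on_gpu, prompt, out_a,
                concurrency_id="gpu", concurrency_limit=2)
    btn_b.click(run_on_gpu, prompt, out_b,
                concurrency_id="gpu", concurrency_limit=2)

demo.queue().launch()
```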

View File

@@ -287,6 +287,7 @@ export function api_factory(
const session_hash = Math.random().toString(36).substring(2);
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
let config: Config;
@@ -908,8 +909,8 @@ export function api_factory(
}
if (
status.stage === "complete" ||
status.stage === "error"
status?.stage === "complete" ||
status?.stage === "error"
) {
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
@@ -932,6 +933,12 @@ export function api_factory(
close_stream();
}
};
if (event_id in pending_stream_messages) {
pending_stream_messages[event_id].forEach((msg) =>
callback(msg)
);
delete pending_stream_messages[event_id];
}
event_callbacks[event_id] = callback;
if (!stream_open) {
open_stream();
@@ -1051,15 +1058,21 @@ export function api_factory(
event_stream = new EventSource(url);
event_stream.onmessage = async function (event) {
let _data = JSON.parse(event.data);
if (!("event_id" in _data)) {
const event_id = _data.event_id;
if (!event_id) {
await Promise.all(
Object.keys(event_callbacks).map((event_id) =>
event_callbacks[event_id](_data)
)
);
return;
} else if (event_callbacks[event_id]) {
await event_callbacks[event_id](_data);
} else {
if (!pending_stream_messages[event_id]) {
pending_stream_messages[event_id] = [];
}
pending_stream_messages[event_id].push(_data);
}
await event_callbacks[_data.event_id](_data);
};
}
@@ -1701,8 +1714,7 @@ function handle_message(
message: !data.success ? data.output.error : undefined,
stage: data.success ? "complete" : "error",
code: data.code,
progress_data: data.progress_data,
eta: data.output.average_duration
progress_data: data.progress_data
},
data: data.success ? data.output : null
};
@@ -1716,7 +1728,8 @@ function handle_message(
code: data.code,
size: data.rank,
position: 0,
success: data.success
success: data.success,
eta: data.eta
}
};
}
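
The race condition called out in the comment near the top of this file's diff (SSE messages arriving before the POST that registers the event's callback completes) is resolved by parking early messages and flushing them on registration. Reduced to a minimal, non-authoritative sketch (Python with invented names; the actual fix is the TypeScript above):

```python
from collections import defaultdict
from typing import Any, Callable

# Messages that arrive before their callback exists are parked here.
pending_messages: dict[str, list[dict[str, Any]]] = defaultdict(list)
callbacks: dict[str, Callable[[dict[str, Any]], None]] = {}

def on_stream_message(event_id: str, message: dict[str, Any]) -> None:
    """Called for every message coming off the event stream."""
    if event_id in callbacks:
        callbacks[event_id](message)
    else:
        # No callback registered yet: buffer until register_callback() runs.
        pending_messages[event_id].append(message)

def register_callback(event_id: str, cb: Callable[[dict[str, Any]], None]) -> None:
    """Called once the POST that created the event completes."""
    # Flush anything that raced ahead of the registration, in arrival order.
    for message in pending_messages.pop(event_id, []):
        cb(message)
    callbacks[event_id] = cb
```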

View File

@@ -376,7 +376,7 @@ class BlockFunction:
self.preprocess = preprocess
self.postprocess = postprocess
self.tracks_progress = tracks_progress
self.concurrency_limit = concurrency_limit
self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
self.concurrency_id = concurrency_id or str(id(fn))
self.batch = batch
self.max_batch_size = max_batch_size

View File

@@ -101,10 +101,7 @@ class InterfaceTypes(Enum):
class Estimation(BaseModel):
rank: Optional[int] = None
queue_size: int
avg_event_process_time: Optional[float] = None
avg_event_concurrent_process_time: Optional[float] = None
rank_eta: Optional[float] = None
queue_eta: float
class ProgressUnit(BaseModel):
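
With the queue-wide averages removed, an `Estimation` update now carries only the event's rank, its `rank_eta`, and the queue size. The message a waiting client receives looks roughly like the dict below (values invented; the `msg`/`event_id` envelope comes from `Queue.send_message`, which merges them into the model dump):

```python
# Illustrative payload pushed to a client waiting in a concurrency queue.
estimation_message = {
    "msg": "estimation",    # ServerMessage.estimation
    "event_id": "8a1f...",  # Event._id, a uuid4 hex string (placeholder)
    "rank": 3,              # position within this concurrency_id's queue
    "rank_eta": 12.7,       # seconds; None until timing history exists
    "queue_size": 8,        # events waiting in the same EventQueue
}
```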

View File

@@ -4,9 +4,11 @@ import asyncio
import copy
import json
import os
import random
import time
import traceback
import uuid
from collections import defaultdict
from queue import Queue as ThreadQueue
from typing import TYPE_CHECKING
@@ -37,11 +39,13 @@ class Event:
fn_index: int,
request: fastapi.Request,
username: str | None,
concurrency_id: str,
):
self.session_hash = session_hash
self.fn_index = fn_index
self.request = request
self.username = username
self.concurrency_id = concurrency_id
self._id = uuid.uuid4().hex
self.data: PredictBody | None = None
self.progress: Progress | None = None
@@ -49,6 +53,27 @@
self.alive = True
class EventQueue:
def __init__(self, concurrency_id: str, concurrency_limit: int | None):
self.queue: list[Event] = []
self.concurrency_id = concurrency_id
self.concurrency_limit = concurrency_limit
self.current_concurrency = 0
self.start_times_per_fn_index: defaultdict[int, set[float]] = defaultdict(set)
class ProcessTime:
def __init__(self):
self.process_time = 0
self.count = 0
self.avg_time = 0
def add(self, time: float):
self.process_time += time
self.count += 1
self.avg_time = self.process_time / self.count
class Queue:
def __init__(
self,
@@ -62,19 +87,16 @@ class Queue:
self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
self.pending_event_ids_session: dict[str, set[str]] = {}
self.pending_message_lock = safe_get_lock()
self.event_queue: list[Event] = []
self.awaiting_data_events: dict[str, Event] = {}
self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
self.stopped = False
self.max_thread_count = concurrency_count
self.update_intervals = update_intervals
self.active_jobs: list[None | list[Event]] = []
self.delete_lock = safe_get_lock()
self.server_app = None
self.duration_history_total = 0
self.duration_history_count = 0
self.avg_process_time = 0
self.avg_concurrent_process_time = None
self.queue_duration = 1
self.process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(
ProcessTime
)
self.live_updates = live_updates
self.sleep_when_free = 0.05
self.progress_update_sleep_when_free = 0.1
@@ -85,25 +107,31 @@ class Queue:
self.default_concurrency_limit = self._resolve_concurrency_limit(
default_concurrency_limit
)
self.concurrency_limit_per_concurrency_id = {}
def start(self):
self.active_jobs = [None] * self.max_thread_count
for block_fn in self.block_fns:
concurrency_limit = (
self.default_concurrency_limit
if block_fn.concurrency_limit == "default"
else block_fn.concurrency_limit
)
if concurrency_limit is not None:
self.concurrency_limit_per_concurrency_id[
block_fn.concurrency_id
] = min(
self.concurrency_limit_per_concurrency_id.get(
block_fn.concurrency_id, concurrency_limit
),
concurrency_limit,
concurrency_id = block_fn.concurrency_id
concurrency_limit: int | None
if block_fn.concurrency_limit == "default":
concurrency_limit = self.default_concurrency_limit
else:
concurrency_limit = block_fn.concurrency_limit
if concurrency_id not in self.event_queue_per_concurrency_id:
self.event_queue_per_concurrency_id[concurrency_id] = EventQueue(
concurrency_id, concurrency_limit
)
elif (
concurrency_limit is not None
): # Update concurrency limit if it is lower than existing limit
existing_event_queue = self.event_queue_per_concurrency_id[
concurrency_id
]
if (
existing_event_queue.concurrency_limit is None
or concurrency_limit < existing_event_queue.concurrency_limit
):
existing_event_queue.concurrency_limit = concurrency_limit
run_coro_in_background(self.start_processing)
run_coro_in_background(self.start_progress_updates)
@@ -119,11 +147,15 @@
message_type: str,
data: dict | None = None,
):
if not event.alive:
return
data = {} if data is None else data
messages = self.pending_messages_per_session[event.session_hash]
messages.put_nowait({"msg": message_type, "event_id": event._id, **data})
def _resolve_concurrency_limit(self, default_concurrency_limit):
def _resolve_concurrency_limit(
self, default_concurrency_limit: int | None | Literal["not_set"]
) -> int | None:
"""
Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination
of the `default_concurrency_limit` parameter of the `Blocks.queue()` or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT`
@@ -143,6 +175,12 @@
else:
return 1
def __len__(self):
total_len = 0
for event_queue in self.event_queue_per_concurrency_id.values():
total_len += len(event_queue.queue)
return total_len
async def push(
self, body: PredictBody, request: fastapi.Request, username: str | None
) -> tuple[bool, str]:
@@ -150,14 +188,19 @@
return False, "No session hash provided."
if body.fn_index is None:
return False, "No function index provided."
queue_len = len(self.event_queue)
if self.max_size is not None and queue_len >= self.max_size:
if self.max_size is not None and len(self) >= self.max_size:
return (
False,
f"Queue is full. Max size is {self.max_size} and size is {queue_len}.",
f"Queue is full. Max size is {self.max_size} and size is {len(self)}.",
)
event = Event(body.session_hash, body.fn_index, request, username)
event = Event(
body.session_hash,
body.fn_index,
request,
username,
self.block_fns[body.fn_index].concurrency_id,
)
event.data = body
async with self.pending_message_lock:
if body.session_hash not in self.pending_messages_per_session:
@@ -165,10 +208,10 @@
if body.session_hash not in self.pending_event_ids_session:
self.pending_event_ids_session[body.session_hash] = set()
self.pending_event_ids_session[body.session_hash].add(event._id)
self.event_queue.append(event)
event_queue = self.event_queue_per_concurrency_id[event.concurrency_id]
event_queue.queue.append(event)
estimation = self.get_estimation()
await self.send_estimation(event, estimation, queue_len)
self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1)
return True, event._id
@@ -187,88 +230,73 @@
count += 1
return count
def get_events_in_batch(self) -> tuple[list[Event] | None, bool]:
if not self.event_queue:
return None, False
worker_count_per_concurrency_id = {}
for job in self.active_jobs:
if job is not None:
for event in job:
concurrency_id = self.block_fns[event.fn_index].concurrency_id
worker_count_per_concurrency_id[concurrency_id] = (
worker_count_per_concurrency_id.get(concurrency_id, 0) + 1
)
events = []
batch = False
for index, event in enumerate(self.event_queue):
block_fn = self.block_fns[event.fn_index]
concurrency_id = block_fn.concurrency_id
concurrency_limit = self.concurrency_limit_per_concurrency_id.get(
concurrency_id, None
)
existing_worker_count = worker_count_per_concurrency_id.get(
concurrency_id, 0
)
if concurrency_limit is None or existing_worker_count < concurrency_limit:
def get_events(self) -> tuple[list[Event], bool, str] | None:
concurrency_ids = list(self.event_queue_per_concurrency_id.keys())
random.shuffle(concurrency_ids)
for concurrency_id in concurrency_ids:
event_queue = self.event_queue_per_concurrency_id[concurrency_id]
if len(event_queue.queue) and (
event_queue.concurrency_limit is None
or event_queue.current_concurrency < event_queue.concurrency_limit
):
first_event = event_queue.queue[0]
block_fn = self.block_fns[first_event.fn_index]
events = [first_event]
batch = block_fn.batch
if batch:
batch_size = block_fn.max_batch_size
if concurrency_limit is None:
remaining_worker_count = batch_size - 1
else:
remaining_worker_count = (
concurrency_limit - existing_worker_count
)
rest_of_batch = [
events += [
event
for event in self.event_queue[index:]
if event.fn_index == event.fn_index
][: min(batch_size - 1, remaining_worker_count)]
events = [event] + rest_of_batch
else:
events = [event]
break
for event in event_queue.queue[1:]
if event.fn_index == first_event.fn_index
][: block_fn.max_batch_size - 1]
for event in events:
self.event_queue.remove(event)
for event in events:
event_queue.queue.remove(event)
return events, batch
return events, batch, concurrency_id
async def start_processing(self) -> None:
while not self.stopped:
if not self.event_queue:
await asyncio.sleep(self.sleep_when_free)
continue
try:
while not self.stopped:
if len(self) == 0:
await asyncio.sleep(self.sleep_when_free)
continue
if None not in self.active_jobs:
await asyncio.sleep(self.sleep_when_free)
continue
# Using mutex to avoid editing a list in use
async with self.delete_lock:
events, batch = self.get_events_in_batch()
if None not in self.active_jobs:
await asyncio.sleep(self.sleep_when_free)
continue
if events:
self.active_jobs[self.active_jobs.index(None)] = events
process_event_task = run_coro_in_background(
self.process_events, events, batch
)
set_task_name(
process_event_task,
events[0].session_hash,
events[0].fn_index,
batch,
)
# Using mutex to avoid editing a list in use
async with self.delete_lock:
event_batch = self.get_events()
self._asyncio_tasks.append(process_event_task)
if self.live_updates:
broadcast_live_estimations_task = run_coro_in_background(
self.broadcast_estimations
if event_batch:
events, batch, concurrency_id = event_batch
self.active_jobs[self.active_jobs.index(None)] = events
event_queue = self.event_queue_per_concurrency_id[concurrency_id]
event_queue.current_concurrency += 1
start_time = time.time()
event_queue.start_times_per_fn_index[events[0].fn_index].add(
start_time
)
self._asyncio_tasks.append(broadcast_live_estimations_task)
else:
await asyncio.sleep(self.sleep_when_free)
process_event_task = run_coro_in_background(
self.process_events, events, batch, start_time
)
set_task_name(
process_event_task,
events[0].session_hash,
events[0].fn_index,
batch,
)
self._asyncio_tasks.append(process_event_task)
if self.live_updates:
self.broadcast_estimations(concurrency_id)
else:
await asyncio.sleep(self.sleep_when_free)
finally:
self.stopped = True
self._cancel_asyncio_tasks()
async def start_progress_updates(self) -> None:
"""
@@ -345,14 +373,17 @@ class Queue:
if job.session_hash == session_hash or job._id == event_id:
job.alive = False
events_to_remove = []
for event in self.event_queue:
if event.session_hash == session_hash or event._id == event_id:
events_to_remove.append(event)
async with self.delete_lock:
events_to_remove: list[Event] = []
for event_queue in self.event_queue_per_concurrency_id.values():
for event in event_queue.queue:
if event.session_hash == session_hash or event._id == event_id:
events_to_remove.append(event)
for event in events_to_remove:
async with self.delete_lock:
self.event_queue.remove(event)
for event in events_to_remove:
self.event_queue_per_concurrency_id[event.concurrency_id].queue.remove(
event
)
async def notify_clients(self) -> None:
"""
@@ -360,66 +391,65 @@ class Queue:
"""
while not self.stopped:
await asyncio.sleep(self.update_intervals)
if self.event_queue:
await self.broadcast_estimations()
if len(self) > 0:
for concurrency_id in self.event_queue_per_concurrency_id:
self.broadcast_estimations(concurrency_id)
async def broadcast_estimations(self) -> None:
estimation = self.get_estimation()
# Send all messages concurrently
await asyncio.gather(
*[
self.send_estimation(event, estimation, rank)
for rank, event in enumerate(self.event_queue)
]
)
def broadcast_estimations(
self, concurrency_id: str, after: int | None = None
) -> None:
wait_so_far = 0
event_queue = self.event_queue_per_concurrency_id[concurrency_id]
time_till_available_worker: int | None = 0
async def send_estimation(
self, event: Event, estimation: Estimation, rank: int
) -> Estimation:
"""
Send estimation about ETA to the client.
if event_queue.current_concurrency == event_queue.concurrency_limit:
expected_end_times = []
for fn_index, start_times in event_queue.start_times_per_fn_index.items():
if fn_index not in self.process_time_per_fn_index:
time_till_available_worker = None
break
process_time = self.process_time_per_fn_index[fn_index].avg_time
expected_end_times += [
start_time + process_time for start_time in start_times
]
if time_till_available_worker is not None and len(expected_end_times) > 0:
time_of_first_completion = min(expected_end_times)
time_till_available_worker = max(
time_of_first_completion - time.time(), 0
)
Parameters:
event:
estimation:
rank:
"""
estimation.rank = rank
if self.avg_concurrent_process_time is not None:
estimation.rank_eta = (
estimation.rank * self.avg_concurrent_process_time
+ self.avg_process_time
for rank, event in enumerate(event_queue.queue):
process_time_for_fn = (
self.process_time_per_fn_index[event.fn_index].avg_time
if event.fn_index in self.process_time_per_fn_index
else None
)
rank_eta = (
process_time_for_fn + wait_so_far + time_till_available_worker
if process_time_for_fn is not None
and wait_so_far is not None
and time_till_available_worker is not None
else None
)
if None not in self.active_jobs:
# Add estimated amount of time for a thread to get empty
estimation.rank_eta += self.avg_concurrent_process_time
self.send_message(event, ServerMessage.estimation, estimation.model_dump())
return estimation
def update_estimation(self, duration: float) -> None:
"""
Update estimation by last x element's average duration.
if after is None or rank >= after:
self.send_message(
event,
ServerMessage.estimation,
Estimation(
rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
).model_dump(),
)
if event_queue.concurrency_limit is None:
wait_so_far = 0
elif wait_so_far is not None and process_time_for_fn is not None:
wait_so_far += process_time_for_fn / event_queue.concurrency_limit
else:
wait_so_far = None
Parameters:
duration:
"""
self.duration_history_total += duration
self.duration_history_count += 1
self.avg_process_time = (
self.duration_history_total / self.duration_history_count
)
self.avg_concurrent_process_time = self.avg_process_time / min(
self.max_thread_count, self.duration_history_count
)
self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)
def get_estimation(self) -> Estimation:
def get_status(self) -> Estimation:
return Estimation(
queue_size=len(self.event_queue),
avg_event_process_time=self.avg_process_time,
avg_event_concurrent_process_time=self.avg_concurrent_process_time,
queue_eta=self.queue_duration,
queue_size=len(self),
)
async def call_prediction(self, events: list[Event], batch: bool):
@@ -484,20 +514,30 @@ class Queue:
return response_json
async def process_events(self, events: list[Event], batch: bool) -> None:
async def process_events(
self, events: list[Event], batch: bool, begin_time: float
) -> None:
awake_events: list[Event] = []
fn_index = events[0].fn_index
try:
for event in events:
self.send_message(event, ServerMessage.process_starts)
awake_events.append(event)
if event.alive:
self.send_message(
event,
ServerMessage.process_starts,
{
"eta": self.process_time_per_fn_index[fn_index].avg_time
if fn_index in self.process_time_per_fn_index
else None
},
)
awake_events.append(event)
if not awake_events:
return
begin_time = time.time()
try:
response = await self.call_prediction(awake_events, batch)
err = None
except Exception as e:
traceback.print_exc()
response = None
err = e
for event in awake_events:
@@ -568,10 +608,17 @@
)
end_time = time.time()
if response is not None:
self.update_estimation(end_time - begin_time)
self.process_time_per_fn_index[events[0].fn_index].add(
end_time - begin_time
)
except Exception as e:
traceback.print_exc()
finally:
event_queue = self.event_queue_per_concurrency_id[events[0].concurrency_id]
event_queue.current_concurrency -= 1
start_times = event_queue.start_times_per_fn_index[fn_index]
if begin_time in start_times:
start_times.remove(begin_time)
try:
self.active_jobs[self.active_jobs.index(events)] = None
except ValueError:
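
Stepping back from the hunks above: queue.py replaces the single global `event_queue` list with one `EventQueue` per `concurrency_id`, tracks `current_concurrency` per queue instead of recounting workers on every pass, and records per-function process times for ETAs. A compressed, non-authoritative sketch of that shape (names mirror the diff, but this class is invented for illustration and omits most details):

```python
from __future__ import annotations

import random
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class EventQueue:
    concurrency_id: str
    concurrency_limit: int | None  # None = unlimited workers for this id
    queue: list = field(default_factory=list)
    current_concurrency: int = 0


class MiniQueue:
    """Toy model of the per-concurrency-id scheduling introduced here."""

    def __init__(self) -> None:
        self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
        # fn_index -> [total process time, count], feeding ETA estimates.
        self.process_time = defaultdict(lambda: [0.0, 0.0])

    def add_listener(self, concurrency_id: str, limit: int | None) -> None:
        # Mirrors Queue.start(): keep the lowest explicit limit per id.
        eq = self.event_queue_per_concurrency_id.setdefault(
            concurrency_id, EventQueue(concurrency_id, limit)
        )
        if limit is not None and (
            eq.concurrency_limit is None or limit < eq.concurrency_limit
        ):
            eq.concurrency_limit = limit

    def push(self, concurrency_id: str, event) -> None:
        self.event_queue_per_concurrency_id[concurrency_id].queue.append(event)

    def get_event(self):
        # Shuffle so no concurrency_id is systematically favored.
        ids = list(self.event_queue_per_concurrency_id)
        random.shuffle(ids)
        for cid in ids:
            eq = self.event_queue_per_concurrency_id[cid]
            has_capacity = (
                eq.concurrency_limit is None
                or eq.current_concurrency < eq.concurrency_limit
            )
            if eq.queue and has_capacity:
                eq.current_concurrency += 1
                return cid, eq.queue.pop(0)
        return None  # nothing runnable right now

    def finish(self, cid: str, fn_index: int, duration: float) -> None:
        self.event_queue_per_concurrency_id[cid].current_concurrency -= 1
        self.process_time[fn_index][0] += duration
        self.process_time[fn_index][1] += 1

    def avg_time(self, fn_index: int) -> float | None:
        total, count = self.process_time[fn_index]
        return total / count if count else None
```

The ETA that `broadcast_estimations` sends builds on these per-function averages: a queued event's `rank_eta` is its own expected run time plus the wait contributed by every event ahead of it (each adds `avg_time / concurrency_limit`; e.g. with `concurrency_limit=2` and a 6-second average, each preceding event adds 3 seconds), plus the time until a worker frees up when the queue is already saturated.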

View File

@@ -673,6 +673,12 @@ class App(FastAPI):
if blocks._queue.server_app is None:
blocks._queue.set_server_app(app)
if blocks._queue.stopped:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Queue is stopped.",
)
success, event_id = await blocks._queue.push(body, request, username)
if not success:
status_code = (
@@ -702,7 +708,7 @@ class App(FastAPI):
response_model=Estimation,
)
async def get_queue_status():
return app.get_blocks()._queue.get_estimation()
return app.get_blocks()._queue.get_status()
@app.get("/upload_progress")
def get_upload_progress(upload_id: str, request: fastapi.Request):

View File

@@ -53,7 +53,6 @@
export let i18n: I18nFormatter;
export let eta: number | null = null;
export let queue = false;
export let queue_position: number | null;
export let queue_size: number | null;
export let status: "complete" | "pending" | "error" | "generating";
@@ -75,6 +74,7 @@
let timer_start = 0;
let timer_diff = 0;
let old_eta: number | null = null;
let eta_from_start: number | null = null;
let message_visible = false;
let eta_level: number | null = 0;
let progress_level: (number | undefined)[] | null = null;
@@ -83,9 +83,9 @@
let show_eta_bar = true;
$: eta_level =
eta === null || eta <= 0 || !timer_diff
eta_from_start === null || eta_from_start <= 0 || !timer_diff
? null
: Math.min(timer_diff / eta, 1);
: Math.min(timer_diff / eta_from_start, 1);
$: if (progress != null) {
show_eta_bar = false;
}
@@ -119,6 +119,7 @@
}
const start_timer = (): void => {
eta = old_eta = formatted_eta = null;
timer_start = performance.now();
timer_diff = 0;
_timer = true;
@@ -134,6 +135,7 @@
function stop_timer(): void {
timer_diff = 0;
eta = old_eta = formatted_eta = null;
if (!_timer) return;
_timer = false;
@@ -160,11 +162,10 @@
$: {
if (eta === null) {
eta = old_eta;
} else if (queue) {
eta = (performance.now() - timer_start) / 1000 + eta;
}
if (eta != null) {
formatted_eta = eta.toFixed(1);
if (eta != null && old_eta !== eta) {
eta_from_start = (performance.now() - timer_start) / 1000 + eta;
formatted_eta = eta_from_start.toFixed(1);
old_eta = eta;
}
}
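
The StatusTracker change above stops treating the server-sent `eta` as an absolute value: the server now sends the function's average process time (with `process_starts`), and the client converts it into `eta_from_start`, measured from when its own timer began. A small worked example of that arithmetic (numbers invented):

```python
# The client's timer started at t = 0 ms; 3.2 s later a "process_starts"
# message arrives carrying eta = 5.0 s (the function's average run time).
timer_start = 0.0  # ms, as returned by performance.now()
now = 3200.0       # ms
eta = 5.0          # s, sent by the server

# Total expected duration measured from the timer's start:
eta_from_start = (now - timer_start) / 1000 + eta  # 8.2 s

# The progress bar compares elapsed time against that total:
timer_diff = 3.4                                   # s, a moment later
eta_level = min(timer_diff / eta_from_start, 1)    # ~0.41
```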