Improve chatbot streaming performance with diffs (#7102)
* changes
* add changeset
* changes
* add changeset
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* changes
* Update free-moose-guess.md
* changes

---------

Co-authored-by: Ali Abid <aliabid@Alis-MacBook-Pro.local>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ali Abid <aliabid94@gmail.com>
This commit is contained in:
parent 6a7e98bfef
commit 68a54a7a31

.changeset/free-moose-guess.md (new file, 9 lines added)
@@ -0,0 +1,9 @@
+---
+"@gradio/client": minor
+"gradio": minor
+"gradio_client": minor
+---
+
+feat:Improve chatbot streaming performance with diffs
+
+Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly
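For clients that read the EventStream directly, the breaking change is in the shape of `process_generating` messages: `sse_v1` sends the full output data on every generation, while `sse_v2` sends it only for the first generation and sends per-output lists of `[action, path, value]` edits afterwards. A minimal sketch of the two shapes, with hypothetical values (only the `msg` and `output.data` fields are taken from the handlers in this diff):

# sse_v1: every intermediate message carries the full payload.
msg_first = {"msg": "process_generating", "output": {"data": ["Hel"]}}
msg_next = {"msg": "process_generating", "output": {"data": ["Hello"]}}

# sse_v2: the first generation is full; later ones carry edit lists per output.
msg_first = {"msg": "process_generating", "output": {"data": ["Hel"]}}
msg_next = {"msg": "process_generating", "output": {"data": [[["append", [], "lo"]]]}}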
@@ -11,7 +11,8 @@ import {
 	set_space_hardware,
 	set_space_timeout,
 	hardware_types,
-	resolve_root
+	resolve_root,
+	apply_diff
 } from "./utils.js";
 
 import type {
@@ -288,6 +289,7 @@ export function api_factory(
 	const last_status: Record<string, Status["stage"]> = {};
 	let stream_open = false;
 	let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
+	let pending_diff_streams: Record<string, any[][]> = {};
 	let event_stream: EventSource | null = null;
 	const event_callbacks: Record<string, () => Promise<void>> = {};
 	const unclosed_events: Set<string> = new Set();
@@ -774,7 +776,8 @@ export function api_factory(
 				}
 			}
 		};
-	} else if (protocol == "sse_v1") {
+	} else if (protocol == "sse_v1" || protocol == "sse_v2") {
+		// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
 		fire_event({
 			type: "status",
 			stage: "pending",
@@ -867,6 +870,9 @@ export function api_factory(
 					endpoint: _endpoint,
 					fn_index
 				});
+				if (data && protocol === "sse_v2") {
+					apply_diff_stream(event_id!, data);
+				}
 			}
 			if (data) {
 				fire_event({
@@ -904,6 +910,9 @@ export function api_factory(
 				if (event_callbacks[event_id]) {
 					delete event_callbacks[event_id];
 				}
+				if (event_id in pending_diff_streams) {
+					delete pending_diff_streams[event_id];
+				}
 			}
 		} catch (e) {
 			console.error("Unexpected client exception", e);
@@ -936,6 +945,25 @@ export function api_factory(
 			}
 		);
 
+		function apply_diff_stream(event_id: string, data: any): void {
+			let is_first_generation = !pending_diff_streams[event_id];
+			if (is_first_generation) {
+				pending_diff_streams[event_id] = [];
+				data.data.forEach((value: any, i: number) => {
+					pending_diff_streams[event_id][i] = value;
+				});
+			} else {
+				data.data.forEach((value: any, i: number) => {
+					let new_data = apply_diff(
+						pending_diff_streams[event_id][i],
+						value
+					);
+					pending_diff_streams[event_id][i] = new_data;
+					data.data[i] = new_data;
+				});
+			}
+		}
+
 		function fire_event<K extends EventType>(event: Event<K>): void {
 			const narrowed_listener_map: ListenerMap<K> = listener_map;
 			const listeners = narrowed_listener_map[event.type] || [];
@@ -20,7 +20,7 @@ export interface Config {
 	show_api: boolean;
 	stylesheets: string[];
 	path: string;
-	protocol?: "sse_v1" | "sse" | "ws";
+	protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
 }
 
 export interface Payload {
@@ -239,3 +239,62 @@ export const hardware_types = [
 	"a10g-large",
 	"a100-large"
 ] as const;
+
+function apply_edit(
+	target: any,
+	path: (number | string)[],
+	action: string,
+	value: any
+): any {
+	if (path.length === 0) {
+		if (action === "replace") {
+			return value;
+		} else if (action === "append") {
+			return target + value;
+		}
+		throw new Error(`Unsupported action: ${action}`);
+	}
+
+	let current = target;
+	for (let i = 0; i < path.length - 1; i++) {
+		current = current[path[i]];
+	}
+
+	const last_path = path[path.length - 1];
+	switch (action) {
+		case "replace":
+			current[last_path] = value;
+			break;
+		case "append":
+			current[last_path] += value;
+			break;
+		case "add":
+			if (Array.isArray(current)) {
+				current.splice(Number(last_path), 0, value);
+			} else {
+				current[last_path] = value;
+			}
+			break;
+		case "delete":
+			if (Array.isArray(current)) {
+				current.splice(Number(last_path), 1);
+			} else {
+				delete current[last_path];
+			}
+			break;
+		default:
+			throw new Error(`Unknown action: ${action}`);
+	}
+	return target;
+}
+
+export function apply_diff(
+	obj: any,
+	diff: [string, (number | string)[], any][]
+): any {
+	diff.forEach(([action, path, value]) => {
+		obj = apply_edit(obj, path, action, value);
+	});
+
+	return obj;
+}
@@ -428,7 +428,12 @@ class Client:
         inferred_fn_index = self._infer_fn_index(api_name, fn_index)
 
         helper = None
-        if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
+        if self.endpoints[inferred_fn_index].protocol in (
+            "ws",
+            "sse",
+            "sse_v1",
+            "sse_v2",
+        ):
             helper = self.new_helper(inferred_fn_index)
         end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
         future = self.executor.submit(end_to_end_fn, *args)
@@ -998,13 +1003,15 @@ class Endpoint:
             result = utils.synchronize_async(
                 self._sse_fn_v0, data, hash_data, helper
             )
-        elif self.protocol == "sse_v1":
+        elif self.protocol == "sse_v1" or self.protocol == "sse_v2":
             event_id = utils.synchronize_async(
                 self.client.send_data, data, hash_data
             )
             self.client.pending_event_ids.add(event_id)
             self.client.pending_messages_per_event[event_id] = []
-            result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
+            result = utils.synchronize_async(
+                self._sse_fn_v1_v2, helper, event_id, self.protocol
+            )
         else:
             raise ValueError(f"Unsupported protocol: {self.protocol}")
 
@@ -1197,13 +1204,16 @@ class Endpoint:
             self.client.cookies,
         )
 
-    async def _sse_fn_v1(self, helper: Communicator, event_id: str):
-        return await utils.get_pred_from_sse_v1(
+    async def _sse_fn_v1_v2(
+        self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
+    ):
+        return await utils.get_pred_from_sse_v1_v2(
             helper,
             self.client.headers,
             self.client.cookies,
             self.client.pending_messages_per_event,
             event_id,
+            protocol,
         )
 
 
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import base64
+import copy
 import hashlib
 import json
 import mimetypes
@@ -17,7 +18,7 @@ from datetime import datetime
 from enum import Enum
 from pathlib import Path
 from threading import Lock
-from typing import Any, Callable, Optional, TypedDict
+from typing import Any, Callable, Literal, Optional, TypedDict
 
 import fsspec.asyn
 import httpx
@@ -381,22 +382,19 @@ async def get_pred_from_sse_v0(
     return task.result()
 
 
-async def get_pred_from_sse_v1(
+async def get_pred_from_sse_v1_v2(
     helper: Communicator,
     headers: dict[str, str],
     cookies: dict[str, str] | None,
     pending_messages_per_event: dict[str, list[Message | None]],
     event_id: str,
+    protocol: Literal["sse_v1", "sse_v2"],
 ) -> dict[str, Any] | None:
     done, pending = await asyncio.wait(
         [
             asyncio.create_task(check_for_cancel(helper, headers, cookies)),
             asyncio.create_task(
-                stream_sse_v1(
-                    helper,
-                    pending_messages_per_event,
-                    event_id,
-                )
+                stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
             ),
         ],
         return_when=asyncio.FIRST_COMPLETED,
@@ -411,6 +409,9 @@ async def get_pred_from_sse_v1(
 
     assert len(done) == 1
     for task in done:
+        exception = task.exception()
+        if exception:
+            raise exception
         return task.result()
 
 
@@ -502,13 +503,15 @@ async def stream_sse_v0(
         raise
 
 
-async def stream_sse_v1(
+async def stream_sse_v1_v2(
     helper: Communicator,
     pending_messages_per_event: dict[str, list[Message | None]],
     event_id: str,
+    protocol: Literal["sse_v1", "sse_v2"],
 ) -> dict[str, Any]:
     try:
         pending_messages = pending_messages_per_event[event_id]
+        pending_responses_for_diffs = None
 
         while True:
             if len(pending_messages) > 0:
@@ -540,6 +543,19 @@ async def stream_sse_v1(
                         log=log_message,
                     )
                     output = msg.get("output", {}).get("data", [])
+                    if (
+                        msg["msg"] == ServerMessage.process_generating
+                        and protocol == "sse_v2"
+                    ):
+                        if pending_responses_for_diffs is None:
+                            pending_responses_for_diffs = list(output)
+                        else:
+                            for i, value in enumerate(output):
+                                prev_output = pending_responses_for_diffs[i]
+                                new_output = apply_diff(prev_output, value)
+                                pending_responses_for_diffs[i] = new_output
+                                output[i] = new_output
+
                     if output and status_update.code != Status.FINISHED:
                         try:
                             result = helper.prediction_processor(*output)
@@ -557,6 +573,48 @@ async def stream_sse_v1(
         raise
 
 
+def apply_diff(obj, diff):
+    obj = copy.deepcopy(obj)
+
+    def apply_edit(target, path, action, value):
+        if len(path) == 0:
+            if action == "replace":
+                return value
+            elif action == "append":
+                return target + value
+            else:
+                raise ValueError(f"Unsupported action: {action}")
+
+        current = target
+        for i in range(len(path) - 1):
+            current = current[path[i]]
+
+        last_path = path[-1]
+        if action == "replace":
+            current[last_path] = value
+        elif action == "append":
+            current[last_path] += value
+        elif action == "add":
+            if isinstance(current, list):
+                current.insert(int(last_path), value)
+            else:
+                current[last_path] = value
+        elif action == "delete":
+            if isinstance(current, list):
+                del current[int(last_path)]
+            else:
+                del current[last_path]
+        else:
+            raise ValueError(f"Unknown action: {action}")
+
+        return target
+
+    for action, path, value in diff:
+        obj = apply_edit(obj, path, action, value)
+
+    return obj
+
+
 ########################
 # Data processing utils
 ########################
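To make the edit semantics concrete, here is a minimal usage sketch of this apply_diff; the values are made up, paths index into nested lists and dicts, and an empty path targets the object itself:

state = {"text": "Hel", "nums": [1, 3]}
edits = [
    ("append", ["text"], "lo"),     # "Hel" -> "Hello"
    ("add", ["nums", 1], 2),        # insert 2 at index 1 -> [1, 2, 3]
    ("delete", ["nums", 2], None),  # drop the trailing 3 -> [1, 2]
]
assert apply_diff(state, edits) == {"text": "Hello", "nums": [1, 2]}
assert state == {"text": "Hel", "nums": [1, 3]}  # the deepcopy leaves the input untouched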
@@ -539,6 +539,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
         self.enable_queue = True
         self.max_threads = 40
         self.pending_streams = defaultdict(dict)
+        self.pending_diff_streams = defaultdict(dict)
         self.show_error = True
         self.head = head
         if css is not None and os.path.exists(css):
@@ -1483,6 +1484,38 @@ Received outputs:
                 data[i] = output_data
         return data
 
+    def handle_streaming_diffs(
+        self,
+        fn_index: int,
+        data: list,
+        session_hash: str | None,
+        run: int | None,
+        final: bool,
+    ) -> list:
+        if session_hash is None or run is None:
+            return data
+        first_run = run not in self.pending_diff_streams[session_hash]
+        if first_run:
+            self.pending_diff_streams[session_hash][run] = [None] * len(data)
+        last_diffs = self.pending_diff_streams[session_hash][run]
+
+        for i in range(len(self.dependencies[fn_index]["outputs"])):
+            if final:
+                data[i] = last_diffs[i]
+                continue
+
+            if first_run:
+                last_diffs[i] = data[i]
+            else:
+                prev_chunk = last_diffs[i]
+                last_diffs[i] = data[i]
+                data[i] = utils.diff(prev_chunk, data[i])
+
+        if final:
+            del self.pending_diff_streams[session_hash][run]
+
+        return data
+
     async def process_api(
         self,
         fn_index: int,
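A worked trace of handle_streaming_diffs for a single output that streams "Hel", then "Hello" (values hypothetical):

# chunk 1, first_run=True:  data == ["Hel"]   -> sent whole; last_diffs == ["Hel"]
# chunk 2, first_run=False: data == ["Hello"] -> sent as [[("append", [], "lo")]]
#                           via utils.diff("Hel", "Hello"); last_diffs == ["Hello"]
# final call:               data[i] = last_diffs[i] == "Hello" (full value resent),
#                           then the per-run stream state is deleted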
@@ -1565,11 +1598,19 @@ Received outputs:
             data = self.postprocess_data(fn_index, result["prediction"], state)
             is_generating, iterator = result["is_generating"], result["iterator"]
             if is_generating or was_generating:
+                run = id(old_iterator) if was_generating else id(iterator)
                 data = self.handle_streaming_outputs(
                     fn_index,
                     data,
                     session_hash=session_hash,
-                    run=id(old_iterator) if was_generating else id(iterator),
+                    run=run,
                 )
+                data = self.handle_streaming_diffs(
+                    fn_index,
+                    data,
+                    session_hash=session_hash,
+                    run=run,
+                    final=not is_generating,
+                )
 
         block_fn.total_runtime += result["duration"]
@@ -1611,7 +1652,7 @@ Received outputs:
             "is_colab": utils.colab_check(),
             "stylesheets": self.stylesheets,
             "theme": self.theme.name,
-            "protocol": "sse_v1",
+            "protocol": "sse_v2",
             "body_css": {
                 "body_background_fill": self.theme._get_computed_value(
                     "body_background_fill"
@@ -6,9 +6,7 @@ from typing import Any, Callable, Literal
 
 from gradio_client.documentation import document, set_documentation_group
 
-from gradio.components.base import (
-    FormComponent,
-)
+from gradio.components.base import FormComponent
 from gradio.events import Events
 
 set_documentation_group("component")
@@ -584,10 +584,7 @@ class Queue:
                 response = None
                 err = e
             for event in awake_events:
-                if response is None:
-                    relevant_response = err
-                else:
-                    relevant_response = old_response or old_err
+                relevant_response = response or err or old_err
                 self.send_message(
                     event,
                     ServerMessage.process_completed,
@@ -1043,3 +1043,47 @@ class LRUCache(OrderedDict, Generic[K, V]):
 
 def get_cache_folder() -> Path:
     return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))
+
+
+def diff(old, new):
+    def compare_objects(obj1, obj2, path=None):
+        if path is None:
+            path = []
+        edits = []
+
+        if obj1 == obj2:
+            return edits
+
+        if type(obj1) != type(obj2):
+            edits.append(("replace", path, obj2))
+            return edits
+
+        if isinstance(obj1, str) and obj2.startswith(obj1):
+            edits.append(("append", path, obj2[len(obj1) :]))
+            return edits
+
+        if isinstance(obj1, list):
+            common_length = min(len(obj1), len(obj2))
+            for i in range(common_length):
+                edits.extend(compare_objects(obj1[i], obj2[i], path + [i]))
+            for i in range(common_length, len(obj1)):
+                edits.append(("delete", path + [i], None))
+            for i in range(common_length, len(obj2)):
+                edits.append(("add", path + [i], obj2[i]))
+            return edits
+
+        if isinstance(obj1, dict):
+            for key in obj1:
+                if key in obj2:
+                    edits.extend(compare_objects(obj1[key], obj2[key], path + [key]))
+                else:
+                    edits.append(("delete", path + [key], None))
+            for key in obj2:
+                if key not in obj1:
+                    edits.append(("add", path + [key], obj2[key]))
+            return edits
+
+        edits.append(("replace", path, obj2))
+        return edits
+
+    return compare_objects(old, new)
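The intended invariant, implied by the paired client-side logic earlier in this diff, is that applying the edits produced by diff(old, new) to old reproduces new. A minimal round-trip sketch, assuming gradio.utils.diff above and the Python client's apply_diff added earlier (the module paths are assumptions):

from gradio.utils import diff
from gradio_client.utils import apply_diff  # assumed import path for the client helper

old = {"a": 1, "b": [1, 3]}
new = {"a": 1, "b": [1, 3, 2]}
edits = diff(old, new)                # [("add", ["b", 2], 2)] -- only the delta is sent
assert apply_diff(old, edits) == new  # the receiver reconstructs the full output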
@@ -265,6 +265,49 @@ class TestBlocksMethods:
         assert outputs == ["a", "b", "c"]
         demo.queue().launch(prevent_thread_lock=True)
 
+    def test_varying_output_forms_with_generators(self, connect):
+        generations = [
+            {"a": 1},
+            {"a": 1, "b": [1, 3]},
+            {"b": [1, 3, 2]},
+            1,
+            2,
+            3,
+            [1, 2, {"x": 4, "y": 6}],
+            {"data": [1, 2, {"x": 4, "y": 6}]},
+            None,
+            1.2,
+        ]
+
+        def generator():
+            yield from generations
+
+        def generator_random():
+            indices = list(range(len(generations)))
+            random.shuffle(indices)
+            for i in indices:
+                time.sleep(random.random() / 5)
+                yield generations[i]
+
+        with gr.Blocks() as demo:
+            btn1 = gr.Button()
+            btn2 = gr.Button()
+            output_json = gr.JSON()
+            btn1.click(generator, None, output_json, api_name="generator")
+            btn2.click(generator_random, None, output_json, api_name="generator_random")
+
+        with connect(demo) as client:
+            outputs = []
+            for output in client.submit(api_name="/generator"):
+                outputs.append(output)
+            assert outputs == generations
+
+            outputs = []
+            for output in client.submit(api_name="/generator_random"):
+                outputs.append(output)
+            for generation in generations:
+                assert generation in outputs
+
     def test_socket_reuse(self):
         try:
             io = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())