Improve chatbot streaming performance with diffs (#7102)

* changes

* add changeset

* Update free-moose-guess.md

---------

Co-authored-by: Ali Abid <aliabid@Alis-MacBook-Pro.local>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ali Abid <aliabid94@gmail.com>
aliabid94 2024-01-31 10:39:46 -08:00 committed by GitHub
parent 6a7e98bfef
commit 68a54a7a31
11 changed files with 312 additions and 25 deletions

View File

@@ -0,0 +1,9 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---
feat:Improve chatbot streaming performance with diffs
Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly.

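For any client reading the EventStream directly, the breaking change is that intermediate `process_generating` payloads now carry lists of diffs instead of full values, while the first generation and the closing `process_completed` message still carry full values. A minimal sketch of what such a consumer now has to do, assuming an illustrative stream URL and using the `apply_diff` helper added to `gradio_client/utils.py` later in this diff:

```python
import json

import httpx  # any SSE-capable HTTP client works; httpx is illustrative

from gradio_client.utils import apply_diff  # added in this PR


def iter_outputs(stream_url: str):
    """Yield fully reconstructed outputs from a raw sse_v2 event stream."""
    outputs = None
    with httpx.Client() as client, client.stream("GET", stream_url) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data:"):
                continue
            msg = json.loads(line[len("data:") :])
            if msg.get("msg") == "process_generating":
                data = msg["output"]["data"]
                if outputs is None:
                    outputs = list(data)  # first generation: full values
                else:
                    # later generations: one diff per output component
                    outputs = [apply_diff(o, d) for o, d in zip(outputs, data)]
                yield outputs
            elif msg.get("msg") == "process_completed":
                yield msg["output"]["data"]  # completion carries full values
                break
```
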
View File

@@ -11,7 +11,8 @@ import {
set_space_hardware,
set_space_timeout,
hardware_types,
-resolve_root
+resolve_root,
+apply_diff
} from "./utils.js";
import type {
@@ -288,6 +289,7 @@ export function api_factory(
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
let pending_diff_streams: Record<string, any[][]> = {};
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
const unclosed_events: Set<string> = new Set();
@@ -774,7 +776,8 @@ export function api_factory(
}
}
};
-} else if (protocol == "sse_v1") {
+} else if (protocol == "sse_v1" || protocol == "sse_v2") {
+// Latest API format. v2 sends diffs for the intermediate outputs of generator functions, which makes payloads lighter.
fire_event({
type: "status",
stage: "pending",
@@ -867,6 +870,9 @@
endpoint: _endpoint,
fn_index
});
if (data && protocol === "sse_v2") {
apply_diff_stream(event_id!, data);
}
}
if (data) {
fire_event({
@@ -904,6 +910,9 @@
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
}
if (event_id in pending_diff_streams) {
delete pending_diff_streams[event_id];
}
}
} catch (e) {
console.error("Unexpected client exception", e);
@@ -936,6 +945,25 @@
}
);
function apply_diff_stream(event_id: string, data: any): void {
let is_first_generation = !pending_diff_streams[event_id];
if (is_first_generation) {
pending_diff_streams[event_id] = [];
data.data.forEach((value: any, i: number) => {
pending_diff_streams[event_id][i] = value;
});
} else {
data.data.forEach((value: any, i: number) => {
let new_data = apply_diff(
pending_diff_streams[event_id][i],
value
);
pending_diff_streams[event_id][i] = new_data;
data.data[i] = new_data;
});
}
}
function fire_event<K extends EventType>(event: Event<K>): void {
const narrowed_listener_map: ListenerMap<K> = listener_map;
const listeners = narrowed_listener_map[event.type] || [];

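For readers following the client change, a hedged Python rendering of the bookkeeping `apply_diff_stream` performs above: accumulated outputs are keyed by `event_id`, and each diff payload is rewritten in place to the reconstructed full value before `fire_event` runs, so downstream listeners never see diffs.

```python
from gradio_client.utils import apply_diff  # the Python port later in this diff

pending_diff_streams: dict[str, list] = {}


def apply_diff_stream(event_id: str, data: dict) -> None:
    if event_id not in pending_diff_streams:
        # First generation for this event: remember the full output values.
        pending_diff_streams[event_id] = list(data["data"])
    else:
        # Later generations: apply each per-output diff to the accumulated
        # value and overwrite the payload so callers only see full data.
        acc = pending_diff_streams[event_id]
        for i, value in enumerate(data["data"]):
            acc[i] = apply_diff(acc[i], value)
            data["data"][i] = acc[i]
```
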
View File

@@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
-protocol?: "sse_v1" | "sse" | "ws";
+protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
}
export interface Payload {

View File

@@ -239,3 +239,62 @@ export const hardware_types = [
"a10g-large",
"a100-large"
] as const;
function apply_edit(
target: any,
path: (number | string)[],
action: string,
value: any
): any {
if (path.length === 0) {
if (action === "replace") {
return value;
} else if (action === "append") {
return target + value;
}
throw new Error(`Unsupported action: ${action}`);
}
let current = target;
for (let i = 0; i < path.length - 1; i++) {
current = current[path[i]];
}
const last_path = path[path.length - 1];
switch (action) {
case "replace":
current[last_path] = value;
break;
case "append":
current[last_path] += value;
break;
case "add":
if (Array.isArray(current)) {
current.splice(Number(last_path), 0, value);
} else {
current[last_path] = value;
}
break;
case "delete":
if (Array.isArray(current)) {
current.splice(Number(last_path), 1);
} else {
delete current[last_path];
}
break;
default:
throw new Error(`Unknown action: ${action}`);
}
return target;
}
export function apply_diff(
obj: any,
diff: [string, (number | string)[], any][]
): any {
diff.forEach(([action, path, value]) => {
obj = apply_edit(obj, path, action, value);
});
return obj;
}

View File

@@ -428,7 +428,12 @@ class Client:
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
helper = None
-if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
+if self.endpoints[inferred_fn_index].protocol in (
+"ws",
+"sse",
+"sse_v1",
+"sse_v2",
+):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
@@ -998,13 +1003,15 @@ class Endpoint:
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
-elif self.protocol == "sse_v1":
+elif self.protocol == "sse_v1" or self.protocol == "sse_v2":
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
-result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
+result = utils.synchronize_async(
+self._sse_fn_v1_v2, helper, event_id, self.protocol
+)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")
@@ -1197,13 +1204,16 @@ class Endpoint:
self.client.cookies,
)
-async def _sse_fn_v1(self, helper: Communicator, event_id: str):
-return await utils.get_pred_from_sse_v1(
+async def _sse_fn_v1_v2(
+self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
+):
+return await utils.get_pred_from_sse_v1_v2(
helper,
self.client.headers,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
protocol,
)

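From the `Client` API surface nothing changes: diffs are applied inside the SSE helpers, so jobs still yield full outputs. A usage sketch with a hypothetical Space and endpoint name:

```python
from gradio_client import Client

client = Client("user/streaming-chatbot")  # hypothetical Space
job = client.submit("Tell me a story", api_name="/chat")  # hypothetical endpoint
for output in job:
    print(output)  # always a fully reconstructed output, never a raw diff
```
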
View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import base64
import copy
import hashlib
import json
import mimetypes
@@ -17,7 +18,7 @@ from datetime import datetime
from enum import Enum
from pathlib import Path
from threading import Lock
-from typing import Any, Callable, Optional, TypedDict
+from typing import Any, Callable, Literal, Optional, TypedDict
import fsspec.asyn
import httpx
@@ -381,22 +382,19 @@ async def get_pred_from_sse_v0(
return task.result()
-async def get_pred_from_sse_v1(
+async def get_pred_from_sse_v1_v2(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(
-stream_sse_v1(
-helper,
-pending_messages_per_event,
-event_id,
-)
+stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
),
],
return_when=asyncio.FIRST_COMPLETED,
@@ -411,6 +409,9 @@ async def get_pred_from_sse_v1(
assert len(done) == 1
for task in done:
exception = task.exception()
if exception:
raise exception
return task.result()
@@ -502,13 +503,15 @@ async def stream_sse_v0(
raise
-async def stream_sse_v1(
+async def stream_sse_v1_v2(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
) -> dict[str, Any]:
try:
pending_messages = pending_messages_per_event[event_id]
pending_responses_for_diffs = None
while True:
if len(pending_messages) > 0:
@@ -540,6 +543,19 @@
log=log_message,
)
output = msg.get("output", {}).get("data", [])
if (
msg["msg"] == ServerMessage.process_generating
and protocol == "sse_v2"
):
if pending_responses_for_diffs is None:
pending_responses_for_diffs = list(output)
else:
for i, value in enumerate(output):
prev_output = pending_responses_for_diffs[i]
new_output = apply_diff(prev_output, value)
pending_responses_for_diffs[i] = new_output
output[i] = new_output
if output and status_update.code != Status.FINISHED:
try:
result = helper.prediction_processor(*output)
@@ -557,6 +573,48 @@
raise
def apply_diff(obj, diff):
obj = copy.deepcopy(obj)
def apply_edit(target, path, action, value):
if len(path) == 0:
if action == "replace":
return value
elif action == "append":
return target + value
else:
raise ValueError(f"Unsupported action: {action}")
current = target
for i in range(len(path) - 1):
current = current[path[i]]
last_path = path[-1]
if action == "replace":
current[last_path] = value
elif action == "append":
current[last_path] += value
elif action == "add":
if isinstance(current, list):
current.insert(int(last_path), value)
else:
current[last_path] = value
elif action == "delete":
if isinstance(current, list):
del current[int(last_path)]
else:
del current[last_path]
else:
raise ValueError(f"Unknown action: {action}")
return target
for action, path, value in diff:
obj = apply_edit(obj, path, action, value)
return obj
########################
# Data processing utils
########################

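To make the edit algebra concrete, a small worked example against the `apply_diff` defined above (values are illustrative):

```python
from gradio_client.utils import apply_diff

state = {"messages": ["Hel"], "done": False}
edits = [
    ("append", ["messages", 0], "lo"),  # string append: "Hel" -> "Hello"
    ("add", ["messages", 1], "world"),  # insert into a list at index 1
    ("replace", ["done"], True),        # overwrite a scalar
]
assert apply_diff(state, edits) == {"messages": ["Hello", "world"], "done": True}
```
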
View File

@@ -539,6 +539,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.enable_queue = True
self.max_threads = 40
self.pending_streams = defaultdict(dict)
self.pending_diff_streams = defaultdict(dict)
self.show_error = True
self.head = head
if css is not None and os.path.exists(css):
@@ -1483,6 +1484,38 @@
data[i] = output_data
return data
def handle_streaming_diffs(
self,
fn_index: int,
data: list,
session_hash: str | None,
run: int | None,
final: bool,
) -> list:
if session_hash is None or run is None:
return data
first_run = run not in self.pending_diff_streams[session_hash]
if first_run:
self.pending_diff_streams[session_hash][run] = [None] * len(data)
last_diffs = self.pending_diff_streams[session_hash][run]
for i in range(len(self.dependencies[fn_index]["outputs"])):
if final:
data[i] = last_diffs[i]
continue
if first_run:
last_diffs[i] = data[i]
else:
prev_chunk = last_diffs[i]
last_diffs[i] = data[i]
data[i] = utils.diff(prev_chunk, data[i])
if final:
del self.pending_diff_streams[session_hash][run]
return data
async def process_api(
self,
fn_index: int,
@@ -1565,11 +1598,19 @@
data = self.postprocess_data(fn_index, result["prediction"], state)
is_generating, iterator = result["is_generating"], result["iterator"]
if is_generating or was_generating:
+run = id(old_iterator) if was_generating else id(iterator)
data = self.handle_streaming_outputs(
fn_index,
data,
session_hash=session_hash,
-run=id(old_iterator) if was_generating else id(iterator),
+run=run,
)
data = self.handle_streaming_diffs(
fn_index,
data,
session_hash=session_hash,
run=run,
final=not is_generating,
)
block_fn.total_runtime += result["duration"]
@@ -1611,7 +1652,7 @@
"is_colab": utils.colab_check(),
"stylesheets": self.stylesheets,
"theme": self.theme.name,
-"protocol": "sse_v1",
+"protocol": "sse_v2",
"body_css": {
"body_background_fill": self.theme._get_computed_value(
"body_background_fill"

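Putting `handle_streaming_diffs` together with the protocol bump: for a hypothetical three-yield generator with a single output component, the wire sequence becomes full data on the first `process_generating` event, diffs on the intermediate ones, and full data again on completion.

```python
from gradio.utils import diff  # the diff generator added in this PR

yields = ["He", "Hello", "Hello!"]  # illustrative generator outputs
wire = [yields[0]]                  # first run: full value
for prev, cur in zip(yields, yields[1:]):
    wire.append(diff(prev, cur))    # intermediate runs: diffs only
wire.append(yields[-1])             # final run: full value again
assert wire == ["He", [("append", [], "llo")], [("append", [], "!")], "Hello!"]
```
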
View File

@@ -6,9 +6,7 @@ from typing import Any, Callable, Literal
from gradio_client.documentation import document, set_documentation_group
-from gradio.components.base import (
-FormComponent,
-)
+from gradio.components.base import FormComponent
from gradio.events import Events
set_documentation_group("component")

View File

@@ -584,10 +584,7 @@ class Queue:
response = None
err = e
for event in awake_events:
-if response is None:
-relevant_response = err
-else:
-relevant_response = old_response or old_err
+relevant_response = response or err or old_err
self.send_message(
event,
ServerMessage.process_completed,

View File

@@ -1043,3 +1043,47 @@ class LRUCache(OrderedDict, Generic[K, V]):
def get_cache_folder() -> Path:
return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))
def diff(old, new):
def compare_objects(obj1, obj2, path=None):
if path is None:
path = []
edits = []
if obj1 == obj2:
return edits
if type(obj1) != type(obj2):
edits.append(("replace", path, obj2))
return edits
if isinstance(obj1, str) and obj2.startswith(obj1):
edits.append(("append", path, obj2[len(obj1) :]))
return edits
if isinstance(obj1, list):
common_length = min(len(obj1), len(obj2))
for i in range(common_length):
edits.extend(compare_objects(obj1[i], obj2[i], path + [i]))
for i in reversed(range(common_length, len(obj1))):
# delete trailing items from the end first, so earlier indices stay valid while the edits are applied sequentially
edits.append(("delete", path + [i], None))
for i in range(common_length, len(obj2)):
edits.append(("add", path + [i], obj2[i]))
return edits
if isinstance(obj1, dict):
for key in obj1:
if key in obj2:
edits.extend(compare_objects(obj1[key], obj2[key], path + [key]))
else:
edits.append(("delete", path + [key], None))
for key in obj2:
if key not in obj1:
edits.append(("add", path + [key], obj2[key]))
return edits
edits.append(("replace", path, obj2))
return edits
return compare_objects(old, new)

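The invariant the whole scheme relies on, stated as a quick property check with illustrative chatbot-style data: applying `diff(old, new)` to `old` must reconstruct `new`.

```python
from gradio.utils import diff
from gradio_client.utils import apply_diff

old = {"history": [["hi", "Once upon a"]]}
new = {"history": [["hi", "Once upon a time"]]}
edits = diff(old, new)
assert edits == [("append", ["history", 0, 1], " time")]
assert apply_diff(old, edits) == new  # round-trip reconstructs the new value
```
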
View File

@@ -265,6 +265,49 @@
assert outputs == ["a", "b", "c"]
demo.queue().launch(prevent_thread_lock=True)
def test_varying_output_forms_with_generators(self, connect):
generations = [
{"a": 1},
{"a": 1, "b": [1, 3]},
{"b": [1, 3, 2]},
1,
2,
3,
[1, 2, {"x": 4, "y": 6}],
{"data": [1, 2, {"x": 4, "y": 6}]},
None,
1.2,
]
def generator():
yield from generations
def generator_random():
indices = list(range(len(generations)))
random.shuffle(indices)
for i in indices:
time.sleep(random.random() / 5)
yield generations[i]
with gr.Blocks() as demo:
btn1 = gr.Button()
btn2 = gr.Button()
output_json = gr.JSON()
btn1.click(generator, None, output_json, api_name="generator")
btn2.click(generator_random, None, output_json, api_name="generator_random")
with connect(demo) as client:
outputs = []
for output in client.submit(api_name="/generator"):
outputs.append(output)
assert outputs == generations
outputs = []
for output in client.submit(api_name="/generator_random"):
outputs.append(output)
for generation in generations:
assert generation in outputs
def test_socket_reuse(self):
try:
io = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())