From 68a54a7a310d8d7072fdae930bf1cfdf12c45a7f Mon Sep 17 00:00:00 2001 From: aliabid94 Date: Wed, 31 Jan 2024 10:39:46 -0800 Subject: [PATCH] Improve chatbot streaming performance with diffs (#7102) * changes * add changeset * changes * add changeset * changes * channges * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * canges * changes * changes * changes * Update free-moose-guess.md * changes --------- Co-authored-by: Ali Abid Co-authored-by: gradio-pr-bot Co-authored-by: Ali Abid --- .changeset/free-moose-guess.md | 9 ++++ client/js/src/client.ts | 32 +++++++++++- client/js/src/types.ts | 2 +- client/js/src/utils.ts | 59 +++++++++++++++++++++ client/python/gradio_client/client.py | 20 ++++++-- client/python/gradio_client/utils.py | 74 ++++++++++++++++++++++++--- gradio/blocks.py | 45 +++++++++++++++- gradio/components/textbox.py | 4 +- gradio/queueing.py | 5 +- gradio/utils.py | 44 ++++++++++++++++ test/test_blocks.py | 43 ++++++++++++++++ 11 files changed, 312 insertions(+), 25 deletions(-) create mode 100644 .changeset/free-moose-guess.md diff --git a/.changeset/free-moose-guess.md b/.changeset/free-moose-guess.md new file mode 100644 index 0000000000..78c6673402 --- /dev/null +++ b/.changeset/free-moose-guess.md @@ -0,0 +1,9 @@ +--- +"@gradio/client": minor +"gradio": minor +"gradio_client": minor +--- + +feat:Improve chatbot streaming performance with diffs + +Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 6355d5a04c..00ae24048e 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -11,7 +11,8 @@ import { set_space_hardware, set_space_timeout, hardware_types, - resolve_root + resolve_root, + apply_diff } from "./utils.js"; import type { @@ -288,6 +289,7 @@ export function api_factory( const last_status: Record = {}; let stream_open = false; let pending_stream_messages: Record = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete. + let pending_diff_streams: Record = {}; let event_stream: EventSource | null = null; const event_callbacks: Record Promise> = {}; const unclosed_events: Set = new Set(); @@ -774,7 +776,8 @@ export function api_factory( } } }; - } else if (protocol == "sse_v1") { + } else if (protocol == "sse_v1" || protocol == "sse_v2") { + // latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter. fire_event({ type: "status", stage: "pending", @@ -867,6 +870,9 @@ export function api_factory( endpoint: _endpoint, fn_index }); + if (data && protocol === "sse_v2") { + apply_diff_stream(event_id!, data); + } } if (data) { fire_event({ @@ -904,6 +910,9 @@ export function api_factory( if (event_callbacks[event_id]) { delete event_callbacks[event_id]; } + if (event_id in pending_diff_streams) { + delete pending_diff_streams[event_id]; + } } } catch (e) { console.error("Unexpected client exception", e); @@ -936,6 +945,25 @@ export function api_factory( } ); + function apply_diff_stream(event_id: string, data: any): void { + let is_first_generation = !pending_diff_streams[event_id]; + if (is_first_generation) { + pending_diff_streams[event_id] = []; + data.data.forEach((value: any, i: number) => { + pending_diff_streams[event_id][i] = value; + }); + } else { + data.data.forEach((value: any, i: number) => { + let new_data = apply_diff( + pending_diff_streams[event_id][i], + value + ); + pending_diff_streams[event_id][i] = new_data; + data.data[i] = new_data; + }); + } + } + function fire_event(event: Event): void { const narrowed_listener_map: ListenerMap = listener_map; const listeners = narrowed_listener_map[event.type] || []; diff --git a/client/js/src/types.ts b/client/js/src/types.ts index 2b1869855e..4e93a762b5 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -20,7 +20,7 @@ export interface Config { show_api: boolean; stylesheets: string[]; path: string; - protocol?: "sse_v1" | "sse" | "ws"; + protocol?: "sse_v2" | "sse_v1" | "sse" | "ws"; } export interface Payload { diff --git a/client/js/src/utils.ts b/client/js/src/utils.ts index 5883cfe0b1..3683911356 100644 --- a/client/js/src/utils.ts +++ b/client/js/src/utils.ts @@ -239,3 +239,62 @@ export const hardware_types = [ "a10g-large", "a100-large" ] as const; + +function apply_edit( + target: any, + path: (number | string)[], + action: string, + value: any +): any { + if (path.length === 0) { + if (action === "replace") { + return value; + } else if (action === "append") { + return target + value; + } + throw new Error(`Unsupported action: ${action}`); + } + + let current = target; + for (let i = 0; i < path.length - 1; i++) { + current = current[path[i]]; + } + + const last_path = path[path.length - 1]; + switch (action) { + case "replace": + current[last_path] = value; + break; + case "append": + current[last_path] += value; + break; + case "add": + if (Array.isArray(current)) { + current.splice(Number(last_path), 0, value); + } else { + current[last_path] = value; + } + break; + case "delete": + if (Array.isArray(current)) { + current.splice(Number(last_path), 1); + } else { + delete current[last_path]; + } + break; + default: + throw new Error(`Unknown action: ${action}`); + } + return target; +} + +export function apply_diff( + obj: any, + diff: [string, (number | string)[], any][] +): any { + diff.forEach(([action, path, value]) => { + obj = apply_edit(obj, path, action, value); + }); + + return obj; +} diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 19a9cdb741..0748e7cc45 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -428,7 +428,12 @@ class Client: inferred_fn_index = self._infer_fn_index(api_name, fn_index) helper = None - if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"): + if self.endpoints[inferred_fn_index].protocol in ( + "ws", + "sse", + "sse_v1", + "sse_v2", + ): helper = self.new_helper(inferred_fn_index) end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper) future = self.executor.submit(end_to_end_fn, *args) @@ -998,13 +1003,15 @@ class Endpoint: result = utils.synchronize_async( self._sse_fn_v0, data, hash_data, helper ) - elif self.protocol == "sse_v1": + elif self.protocol == "sse_v1" or self.protocol == "sse_v2": event_id = utils.synchronize_async( self.client.send_data, data, hash_data ) self.client.pending_event_ids.add(event_id) self.client.pending_messages_per_event[event_id] = [] - result = utils.synchronize_async(self._sse_fn_v1, helper, event_id) + result = utils.synchronize_async( + self._sse_fn_v1_v2, helper, event_id, self.protocol + ) else: raise ValueError(f"Unsupported protocol: {self.protocol}") @@ -1197,13 +1204,16 @@ class Endpoint: self.client.cookies, ) - async def _sse_fn_v1(self, helper: Communicator, event_id: str): - return await utils.get_pred_from_sse_v1( + async def _sse_fn_v1_v2( + self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"] + ): + return await utils.get_pred_from_sse_v1_v2( helper, self.client.headers, self.client.cookies, self.client.pending_messages_per_event, event_id, + protocol, ) diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 630f16c2aa..d520f556e7 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import base64 +import copy import hashlib import json import mimetypes @@ -17,7 +18,7 @@ from datetime import datetime from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Callable, Optional, TypedDict +from typing import Any, Callable, Literal, Optional, TypedDict import fsspec.asyn import httpx @@ -381,22 +382,19 @@ async def get_pred_from_sse_v0( return task.result() -async def get_pred_from_sse_v1( +async def get_pred_from_sse_v1_v2( helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None, pending_messages_per_event: dict[str, list[Message | None]], event_id: str, + protocol: Literal["sse_v1", "sse_v2"], ) -> dict[str, Any] | None: done, pending = await asyncio.wait( [ asyncio.create_task(check_for_cancel(helper, headers, cookies)), asyncio.create_task( - stream_sse_v1( - helper, - pending_messages_per_event, - event_id, - ) + stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol) ), ], return_when=asyncio.FIRST_COMPLETED, @@ -411,6 +409,9 @@ async def get_pred_from_sse_v1( assert len(done) == 1 for task in done: + exception = task.exception() + if exception: + raise exception return task.result() @@ -502,13 +503,15 @@ async def stream_sse_v0( raise -async def stream_sse_v1( +async def stream_sse_v1_v2( helper: Communicator, pending_messages_per_event: dict[str, list[Message | None]], event_id: str, + protocol: Literal["sse_v1", "sse_v2"], ) -> dict[str, Any]: try: pending_messages = pending_messages_per_event[event_id] + pending_responses_for_diffs = None while True: if len(pending_messages) > 0: @@ -540,6 +543,19 @@ async def stream_sse_v1( log=log_message, ) output = msg.get("output", {}).get("data", []) + if ( + msg["msg"] == ServerMessage.process_generating + and protocol == "sse_v2" + ): + if pending_responses_for_diffs is None: + pending_responses_for_diffs = list(output) + else: + for i, value in enumerate(output): + prev_output = pending_responses_for_diffs[i] + new_output = apply_diff(prev_output, value) + pending_responses_for_diffs[i] = new_output + output[i] = new_output + if output and status_update.code != Status.FINISHED: try: result = helper.prediction_processor(*output) @@ -557,6 +573,48 @@ async def stream_sse_v1( raise +def apply_diff(obj, diff): + obj = copy.deepcopy(obj) + + def apply_edit(target, path, action, value): + if len(path) == 0: + if action == "replace": + return value + elif action == "append": + return target + value + else: + raise ValueError(f"Unsupported action: {action}") + + current = target + for i in range(len(path) - 1): + current = current[path[i]] + + last_path = path[-1] + if action == "replace": + current[last_path] = value + elif action == "append": + current[last_path] += value + elif action == "add": + if isinstance(current, list): + current.insert(int(last_path), value) + else: + current[last_path] = value + elif action == "delete": + if isinstance(current, list): + del current[int(last_path)] + else: + del current[last_path] + else: + raise ValueError(f"Unknown action: {action}") + + return target + + for action, path, value in diff: + obj = apply_edit(obj, path, action, value) + + return obj + + ######################## # Data processing utils ######################## diff --git a/gradio/blocks.py b/gradio/blocks.py index b431e2306c..f9661086d5 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -539,6 +539,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): self.enable_queue = True self.max_threads = 40 self.pending_streams = defaultdict(dict) + self.pending_diff_streams = defaultdict(dict) self.show_error = True self.head = head if css is not None and os.path.exists(css): @@ -1483,6 +1484,38 @@ Received outputs: data[i] = output_data return data + def handle_streaming_diffs( + self, + fn_index: int, + data: list, + session_hash: str | None, + run: int | None, + final: bool, + ) -> list: + if session_hash is None or run is None: + return data + first_run = run not in self.pending_diff_streams[session_hash] + if first_run: + self.pending_diff_streams[session_hash][run] = [None] * len(data) + last_diffs = self.pending_diff_streams[session_hash][run] + + for i in range(len(self.dependencies[fn_index]["outputs"])): + if final: + data[i] = last_diffs[i] + continue + + if first_run: + last_diffs[i] = data[i] + else: + prev_chunk = last_diffs[i] + last_diffs[i] = data[i] + data[i] = utils.diff(prev_chunk, data[i]) + + if final: + del self.pending_diff_streams[session_hash][run] + + return data + async def process_api( self, fn_index: int, @@ -1565,11 +1598,19 @@ Received outputs: data = self.postprocess_data(fn_index, result["prediction"], state) is_generating, iterator = result["is_generating"], result["iterator"] if is_generating or was_generating: + run = id(old_iterator) if was_generating else id(iterator) data = self.handle_streaming_outputs( fn_index, data, session_hash=session_hash, - run=id(old_iterator) if was_generating else id(iterator), + run=run, + ) + data = self.handle_streaming_diffs( + fn_index, + data, + session_hash=session_hash, + run=run, + final=not is_generating, ) block_fn.total_runtime += result["duration"] @@ -1611,7 +1652,7 @@ Received outputs: "is_colab": utils.colab_check(), "stylesheets": self.stylesheets, "theme": self.theme.name, - "protocol": "sse_v1", + "protocol": "sse_v2", "body_css": { "body_background_fill": self.theme._get_computed_value( "body_background_fill" diff --git a/gradio/components/textbox.py b/gradio/components/textbox.py index c1be9029c0..6dd9b8c731 100644 --- a/gradio/components/textbox.py +++ b/gradio/components/textbox.py @@ -6,9 +6,7 @@ from typing import Any, Callable, Literal from gradio_client.documentation import document, set_documentation_group -from gradio.components.base import ( - FormComponent, -) +from gradio.components.base import FormComponent from gradio.events import Events set_documentation_group("component") diff --git a/gradio/queueing.py b/gradio/queueing.py index c871d7989e..8ed2f5cf76 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -584,10 +584,7 @@ class Queue: response = None err = e for event in awake_events: - if response is None: - relevant_response = err - else: - relevant_response = old_response or old_err + relevant_response = response or err or old_err self.send_message( event, ServerMessage.process_completed, diff --git a/gradio/utils.py b/gradio/utils.py index a819953856..ce6021375e 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1043,3 +1043,47 @@ class LRUCache(OrderedDict, Generic[K, V]): def get_cache_folder() -> Path: return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples")) + + +def diff(old, new): + def compare_objects(obj1, obj2, path=None): + if path is None: + path = [] + edits = [] + + if obj1 == obj2: + return edits + + if type(obj1) != type(obj2): + edits.append(("replace", path, obj2)) + return edits + + if isinstance(obj1, str) and obj2.startswith(obj1): + edits.append(("append", path, obj2[len(obj1) :])) + return edits + + if isinstance(obj1, list): + common_length = min(len(obj1), len(obj2)) + for i in range(common_length): + edits.extend(compare_objects(obj1[i], obj2[i], path + [i])) + for i in range(common_length, len(obj1)): + edits.append(("delete", path + [i], None)) + for i in range(common_length, len(obj2)): + edits.append(("add", path + [i], obj2[i])) + return edits + + if isinstance(obj1, dict): + for key in obj1: + if key in obj2: + edits.extend(compare_objects(obj1[key], obj2[key], path + [key])) + else: + edits.append(("delete", path + [key], None)) + for key in obj2: + if key not in obj1: + edits.append(("add", path + [key], obj2[key])) + return edits + + edits.append(("replace", path, obj2)) + return edits + + return compare_objects(old, new) diff --git a/test/test_blocks.py b/test/test_blocks.py index cc4627c4ba..e01c4dd730 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -265,6 +265,49 @@ class TestBlocksMethods: assert outputs == ["a", "b", "c"] demo.queue().launch(prevent_thread_lock=True) + def test_varying_output_forms_with_generators(self, connect): + generations = [ + {"a": 1}, + {"a": 1, "b": [1, 3]}, + {"b": [1, 3, 2]}, + 1, + 2, + 3, + [1, 2, {"x": 4, "y": 6}], + {"data": [1, 2, {"x": 4, "y": 6}]}, + None, + 1.2, + ] + + def generator(): + yield from generations + + def generator_random(): + indices = list(range(len(generations))) + random.shuffle(indices) + for i in indices: + time.sleep(random.random() / 5) + yield generations[i] + + with gr.Blocks() as demo: + btn1 = gr.Button() + btn2 = gr.Button() + output_json = gr.JSON() + btn1.click(generator, None, output_json, api_name="generator") + btn2.click(generator_random, None, output_json, api_name="generator_random") + + with connect(demo) as client: + outputs = [] + for output in client.submit(api_name="/generator"): + outputs.append(output) + assert outputs == generations + + outputs = [] + for output in client.submit(api_name="/generator_random"): + outputs.append(output) + for generation in generations: + assert generation in outputs + def test_socket_reuse(self): try: io = gr.Interface(lambda x: x, gr.Textbox(), gr.Textbox())