Make <UploadProgress /> Wasm-compatible (#6965)

* Make <UploadProgress /> Wasm-compatible

* add changeset

* Fix <DownloadLink /> not to prefetch the data in the Wasm mode

* add changeset

* Fix <DownloadLink /> to check the `window` object existence for SSR

* Lite: Fix and improve the file upload progress SSE (#6978)

* Update the Wasm ASGI connection to be able handle ReadableStream, which is used for example in <Upload />"

* Fix wasm_proxied_fetch() not to pass a leading '?' in the query_string to WorkerProxy.httpRequest() because it's required by the ASGI spec

* Fix FileUploadProgress.update() to merge a new item to the existing one in the queue

* Fix the SSE stream async task in the /upload_progress endpoint removing an unreached code block. `await asyncio.sleep()` has been moved from the unreached block to a live location, so the stream cadence has been reduced

* Fix `FileUploadProgress` to manage the `is_done` flag independent from the queue because it has a different semantics and checking it is a priority over reading other progress events to abort the SSE stream when uploading is done

* Refactoring

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2024-01-11 02:27:58 +09:00 committed by GitHub
parent 0f0498bf97
commit 5d00dd37ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 197 additions and 63 deletions

View File

@ -0,0 +1,8 @@
---
"@gradio/app": minor
"@gradio/upload": minor
"@gradio/wasm": minor
"gradio": minor
---
feat:Make <UploadProgress /> Wasm-compatible

View File

@ -298,38 +298,70 @@ class GradioUploadFile(UploadFile):
class FileUploadProgressUnit:
filename: str
chunk_size: int
@python_dataclass
class FileUploadProgressTracker:
deque: deque[FileUploadProgressUnit]
is_done: bool
class FileUploadProgressNotTrackedError(Exception):
pass
class FileUploadProgressNotQueuedError(Exception):
pass
class FileUploadProgress:
def __init__(self) -> None:
self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}
self._statuses: dict[str, FileUploadProgressTracker] = {}
def track(self, upload_id: str):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
self._statuses[upload_id] = FileUploadProgressTracker(deque(), False)
def update(self, upload_id: str, filename: str, message_bytes: bytes):
def append(self, upload_id: str, filename: str, message_bytes: bytes):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
self._statuses[upload_id].append(
FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
)
self.track(upload_id)
queue = self._statuses[upload_id].deque
if len(queue) == 0:
queue.append(FileUploadProgressUnit(filename, len(message_bytes)))
else:
last_unit = queue.popleft()
if last_unit.filename != filename:
queue.append(FileUploadProgressUnit(filename, len(message_bytes)))
else:
queue.append(
FileUploadProgressUnit(
filename,
last_unit.chunk_size + len(message_bytes),
)
)
def set_done(self, upload_id: str):
self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))
if upload_id not in self._statuses:
self.track(upload_id)
self._statuses[upload_id].is_done = True
def is_done(self, upload_id: str):
if upload_id not in self._statuses:
raise FileUploadProgressNotTrackedError()
return self._statuses[upload_id].is_done
def stop_tracking(self, upload_id: str):
if upload_id in self._statuses:
del self._statuses[upload_id]
def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
def pop(self, upload_id: str) -> FileUploadProgressUnit:
if upload_id not in self._statuses:
return deque()
return self._statuses[upload_id]
def is_tracked(self, upload_id: str):
return upload_id in self._statuses
raise FileUploadProgressNotTrackedError()
try:
return self._statuses[upload_id].deque.pop()
except IndexError as e:
raise FileUploadProgressNotQueuedError() from e
class GradioMultiPartParser:
@ -382,7 +414,7 @@ class GradioMultiPartParser:
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self.upload_progress is not None:
self.upload_progress.update(
self.upload_progress.append(
self.upload_id, # type: ignore
self._current_part.file.filename, # type: ignore
message_bytes,

View File

@ -58,6 +58,8 @@ from gradio.oauth import attach_oauth
from gradio.queueing import Estimation
from gradio.route_utils import ( # noqa: F401
FileUploadProgress,
FileUploadProgressNotQueuedError,
FileUploadProgressNotTrackedError,
GradioMultiPartParser,
GradioUploadFile,
MultiPartException,
@ -732,30 +734,26 @@ class App(FastAPI):
heartbeat_rate = 15
check_rate = 0.05
message = None
try:
if update := file_upload_statuses.status(upload_id).popleft():
if update.is_done:
message = {"msg": "done"}
is_done = True
else:
message = {
"msg": "update",
"orig_name": update.filename,
"chunk_size": update.chunk_size,
}
if file_upload_statuses.is_done(upload_id):
message = {"msg": "done"}
is_done = True
else:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
message = {"msg": "heartbeat"}
last_heartbeat = time.perf_counter()
if message:
update = file_upload_statuses.pop(upload_id)
message = {
"msg": "update",
"orig_name": update.filename,
"chunk_size": update.chunk_size,
}
yield f"data: {json.dumps(message)}\n\n"
except FileUploadProgressNotTrackedError:
return
except FileUploadProgressNotQueuedError:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
message = {"msg": "heartbeat"}
yield f"data: {json.dumps(message)}\n\n"
except IndexError:
if not file_upload_statuses.is_tracked(upload_id):
return
# pop from empty queue
continue
last_heartbeat = time.perf_counter()
return StreamingResponse(
sse_stream(request),

View File

@ -99,6 +99,9 @@
}
export let fetch_implementation: typeof fetch = fetch;
setContext("fetch_implementation", fetch_implementation);
export let EventSource_factory: (url: URL) => EventSource = (url) =>
new EventSource(url);
setContext("EventSource_factory", EventSource_factory);
export let space: string | null;
export let host: string | null;

View File

@ -44,16 +44,12 @@ export async function wasm_proxied_fetch(
headers[key] = value;
});
const bodyArrayBuffer = await new Response(request.body).arrayBuffer();
const body: Parameters<WorkerProxy["httpRequest"]>[0]["body"] =
new Uint8Array(bodyArrayBuffer);
const response = await workerProxy.httpRequest({
path: url.pathname,
query_string: url.search,
query_string: url.searchParams.toString(), // The `query_string` field in the ASGI spec must not contain the leading `?`.
method,
headers,
body
body: request.body
});
return new Response(response.body, {
status: response.status,

View File

@ -161,7 +161,8 @@ export function create(options: Options): GradioAppController {
client,
upload_files,
mount_css: overridden_mount_css,
fetch_implementation: overridden_fetch
fetch_implementation: overridden_fetch,
EventSource_factory
}
});
}

View File

@ -1,6 +1,6 @@
<script lang="ts">
import { FileData } from "@gradio/client";
import { onMount, createEventDispatcher } from "svelte";
import { onMount, createEventDispatcher, getContext } from "svelte";
type FileDataWithProgress = FileData & { progress: number };
@ -36,9 +36,12 @@
return (file.progress * 100) / (file.size || 0) || 0;
}
const EventSource_factory = getContext<(url: URL) => EventSource>(
"EventSource_factory"
);
onMount(() => {
event_source = new EventSource(
`${root}/upload_progress?upload_id=${upload_id}`
event_source = EventSource_factory(
new URL(`${root}/upload_progress?upload_id=${upload_id}`)
);
// Event listener for progress updates
event_source.onmessage = async function (event) {

View File

@ -1,9 +1,9 @@
export interface HttpRequest {
method: "GET" | "POST" | "PUT" | "DELETE";
path: string;
query_string: string;
query_string: string; // This field must not contain the leading `?`, as it's directly used in the ASGI spec which requires this.
headers: Record<string, string>;
body?: Uint8Array;
body?: Uint8Array | ReadableStream<Uint8Array> | null;
}
export interface HttpResponse {
status: number;

View File

@ -245,11 +245,37 @@ export class WorkerProxy extends EventTarget {
asgiMessagePort.start();
asgiMessagePort.postMessage({
type: "http.request",
more_body: false,
body: request.body
} satisfies ReceiveEvent);
if (request.body instanceof ReadableStream) {
// The following code reading the stream is based on the example in https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream/getReader#examples
const reader = request.body.getReader();
reader.read().then(function process({
done,
value
}): Promise<void> | void {
if (done) {
asgiMessagePort.postMessage({
type: "http.request",
more_body: false,
body: undefined
} satisfies ReceiveEvent);
return;
}
asgiMessagePort.postMessage({
type: "http.request",
more_body: !done,
body: value
} satisfies ReceiveEvent);
return reader.read().then(process);
});
} else {
asgiMessagePort.postMessage({
type: "http.request",
more_body: false,
body: request.body ?? undefined
} satisfies ReceiveEvent);
}
});
}

View File

@ -8,18 +8,78 @@
}
type $$Props = DownloadLinkAttributes;
import { resolve_wasm_src } from ".";
import { getWorkerProxyContext } from "./context";
import { should_proxy_wasm_src } from "./file-url";
import { getHeaderValue } from "../src/http";
export let href: DownloadLinkAttributes["href"] = undefined;
export let download: DownloadLinkAttributes["download"];
const dispatch = createEventDispatcher();
let is_downloading = false;
const worker_proxy = getWorkerProxyContext();
async function wasm_click_handler(): Promise<void> {
if (is_downloading) {
return;
}
dispatch("click");
if (href == null) {
throw new Error("href is not defined.");
}
if (worker_proxy == null) {
throw new Error("Wasm worker proxy is not available.");
}
const url = new URL(href);
const path = url.pathname;
is_downloading = true;
worker_proxy
.httpRequest({
method: "GET",
path,
headers: {},
query_string: ""
})
.then((response) => {
if (response.status !== 200) {
throw new Error(`Failed to get file ${path} from the Wasm worker.`);
}
const blob = new Blob([response.body], {
type: getHeaderValue(response.headers, "content-type")
});
const blobUrl = URL.createObjectURL(blob);
const link = document.createElement("a");
link.href = blobUrl;
link.download = download;
link.click();
URL.revokeObjectURL(blobUrl);
})
.finally(() => {
is_downloading = false;
});
}
</script>
{#await resolve_wasm_src(href) then resolved_href}
{#if worker_proxy && should_proxy_wasm_src(href)}
{#if is_downloading}
<slot />
{:else}
<a {...$$restProps} {href} on:click|preventDefault={wasm_click_handler}>
<slot />
</a>
{/if}
{:else}
<a
href={resolved_href}
target={window.__is_colab__ ? "_blank" : null}
{href}
target={typeof window !== "undefined" && window.__is_colab__
? "_blank"
: null}
rel="noopener noreferrer"
{download}
{...$$restProps}
@ -27,6 +87,4 @@
>
<slot />
</a>
{:catch error}
<p style="color: red;">{error.message}</p>
{/await}
{/if}

View File

@ -4,18 +4,26 @@ import { getHeaderValue } from "../src/http";
type MediaSrc = string | undefined | null;
export async function resolve_wasm_src(src: MediaSrc): Promise<MediaSrc> {
export function should_proxy_wasm_src(src: MediaSrc): boolean {
if (src == null) {
return src;
return false;
}
const url = new URL(src);
if (!is_self_host(url)) {
// `src` is not accessing a local server resource, so we don't need to proxy this request to the Wasm worker.
return src;
return false;
}
if (url.protocol !== "http:" && url.protocol !== "https:") {
// `src` can be a data URL.
return false;
}
return true;
}
export async function resolve_wasm_src(src: MediaSrc): Promise<MediaSrc> {
if (src == null || !should_proxy_wasm_src(src)) {
return src;
}
@ -25,6 +33,7 @@ export async function resolve_wasm_src(src: MediaSrc): Promise<MediaSrc> {
return src;
}
const url = new URL(src);
const path = url.pathname;
return maybeWorkerProxy
.httpRequest({