mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Make <UploadProgress /> Wasm-compatible (#6965)
* Make <UploadProgress /> Wasm-compatible * add changeset * Fix <DownloadLink /> not to prefetch the data in the Wasm mode * add changeset * Fix <DownloadLink /> to check the `window` object existence for SSR * Lite: Fix and improve the file upload progress SSE (#6978) * Update the Wasm ASGI connection to be able handle ReadableStream, which is used for example in <Upload />" * Fix wasm_proxied_fetch() not to pass a leading '?' in the query_string to WorkerProxy.httpRequest() because it's required by the ASGI spec * Fix FileUploadProgress.update() to merge a new item to the existing one in the queue * Fix the SSE stream async task in the /upload_progress endpoint removing an unreached code block. `await asyncio.sleep()` has been moved from the unreached block to a live location, so the stream cadence has been reduced * Fix `FileUploadProgress` to manage the `is_done` flag independent from the queue because it has a different semantics and checking it is a priority over reading other progress events to abort the SSE stream when uploading is done * Refactoring --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>
This commit is contained in:
parent
0f0498bf97
commit
5d00dd37ca
8
.changeset/fair-webs-type.md
Normal file
8
.changeset/fair-webs-type.md
Normal file
@ -0,0 +1,8 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"@gradio/upload": minor
|
||||
"@gradio/wasm": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Make <UploadProgress /> Wasm-compatible
|
@ -298,38 +298,70 @@ class GradioUploadFile(UploadFile):
|
||||
class FileUploadProgressUnit:
|
||||
filename: str
|
||||
chunk_size: int
|
||||
|
||||
|
||||
@python_dataclass
|
||||
class FileUploadProgressTracker:
|
||||
deque: deque[FileUploadProgressUnit]
|
||||
is_done: bool
|
||||
|
||||
|
||||
class FileUploadProgressNotTrackedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FileUploadProgressNotQueuedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FileUploadProgress:
|
||||
def __init__(self) -> None:
|
||||
self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}
|
||||
self._statuses: dict[str, FileUploadProgressTracker] = {}
|
||||
|
||||
def track(self, upload_id: str):
|
||||
if upload_id not in self._statuses:
|
||||
self._statuses[upload_id] = deque()
|
||||
self._statuses[upload_id] = FileUploadProgressTracker(deque(), False)
|
||||
|
||||
def update(self, upload_id: str, filename: str, message_bytes: bytes):
|
||||
def append(self, upload_id: str, filename: str, message_bytes: bytes):
|
||||
if upload_id not in self._statuses:
|
||||
self._statuses[upload_id] = deque()
|
||||
self._statuses[upload_id].append(
|
||||
FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
|
||||
)
|
||||
self.track(upload_id)
|
||||
queue = self._statuses[upload_id].deque
|
||||
|
||||
if len(queue) == 0:
|
||||
queue.append(FileUploadProgressUnit(filename, len(message_bytes)))
|
||||
else:
|
||||
last_unit = queue.popleft()
|
||||
if last_unit.filename != filename:
|
||||
queue.append(FileUploadProgressUnit(filename, len(message_bytes)))
|
||||
else:
|
||||
queue.append(
|
||||
FileUploadProgressUnit(
|
||||
filename,
|
||||
last_unit.chunk_size + len(message_bytes),
|
||||
)
|
||||
)
|
||||
|
||||
def set_done(self, upload_id: str):
|
||||
self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))
|
||||
if upload_id not in self._statuses:
|
||||
self.track(upload_id)
|
||||
self._statuses[upload_id].is_done = True
|
||||
|
||||
def is_done(self, upload_id: str):
|
||||
if upload_id not in self._statuses:
|
||||
raise FileUploadProgressNotTrackedError()
|
||||
return self._statuses[upload_id].is_done
|
||||
|
||||
def stop_tracking(self, upload_id: str):
|
||||
if upload_id in self._statuses:
|
||||
del self._statuses[upload_id]
|
||||
|
||||
def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
|
||||
def pop(self, upload_id: str) -> FileUploadProgressUnit:
|
||||
if upload_id not in self._statuses:
|
||||
return deque()
|
||||
return self._statuses[upload_id]
|
||||
|
||||
def is_tracked(self, upload_id: str):
|
||||
return upload_id in self._statuses
|
||||
raise FileUploadProgressNotTrackedError()
|
||||
try:
|
||||
return self._statuses[upload_id].deque.pop()
|
||||
except IndexError as e:
|
||||
raise FileUploadProgressNotQueuedError() from e
|
||||
|
||||
|
||||
class GradioMultiPartParser:
|
||||
@ -382,7 +414,7 @@ class GradioMultiPartParser:
|
||||
def on_part_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message_bytes = data[start:end]
|
||||
if self.upload_progress is not None:
|
||||
self.upload_progress.update(
|
||||
self.upload_progress.append(
|
||||
self.upload_id, # type: ignore
|
||||
self._current_part.file.filename, # type: ignore
|
||||
message_bytes,
|
||||
|
@ -58,6 +58,8 @@ from gradio.oauth import attach_oauth
|
||||
from gradio.queueing import Estimation
|
||||
from gradio.route_utils import ( # noqa: F401
|
||||
FileUploadProgress,
|
||||
FileUploadProgressNotQueuedError,
|
||||
FileUploadProgressNotTrackedError,
|
||||
GradioMultiPartParser,
|
||||
GradioUploadFile,
|
||||
MultiPartException,
|
||||
@ -732,30 +734,26 @@ class App(FastAPI):
|
||||
|
||||
heartbeat_rate = 15
|
||||
check_rate = 0.05
|
||||
message = None
|
||||
try:
|
||||
if update := file_upload_statuses.status(upload_id).popleft():
|
||||
if update.is_done:
|
||||
message = {"msg": "done"}
|
||||
is_done = True
|
||||
else:
|
||||
message = {
|
||||
"msg": "update",
|
||||
"orig_name": update.filename,
|
||||
"chunk_size": update.chunk_size,
|
||||
}
|
||||
if file_upload_statuses.is_done(upload_id):
|
||||
message = {"msg": "done"}
|
||||
is_done = True
|
||||
else:
|
||||
await asyncio.sleep(check_rate)
|
||||
if time.perf_counter() - last_heartbeat > heartbeat_rate:
|
||||
message = {"msg": "heartbeat"}
|
||||
last_heartbeat = time.perf_counter()
|
||||
if message:
|
||||
update = file_upload_statuses.pop(upload_id)
|
||||
message = {
|
||||
"msg": "update",
|
||||
"orig_name": update.filename,
|
||||
"chunk_size": update.chunk_size,
|
||||
}
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
except FileUploadProgressNotTrackedError:
|
||||
return
|
||||
except FileUploadProgressNotQueuedError:
|
||||
await asyncio.sleep(check_rate)
|
||||
if time.perf_counter() - last_heartbeat > heartbeat_rate:
|
||||
message = {"msg": "heartbeat"}
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
except IndexError:
|
||||
if not file_upload_statuses.is_tracked(upload_id):
|
||||
return
|
||||
# pop from empty queue
|
||||
continue
|
||||
last_heartbeat = time.perf_counter()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_stream(request),
|
||||
|
@ -99,6 +99,9 @@
|
||||
}
|
||||
export let fetch_implementation: typeof fetch = fetch;
|
||||
setContext("fetch_implementation", fetch_implementation);
|
||||
export let EventSource_factory: (url: URL) => EventSource = (url) =>
|
||||
new EventSource(url);
|
||||
setContext("EventSource_factory", EventSource_factory);
|
||||
|
||||
export let space: string | null;
|
||||
export let host: string | null;
|
||||
|
@ -44,16 +44,12 @@ export async function wasm_proxied_fetch(
|
||||
headers[key] = value;
|
||||
});
|
||||
|
||||
const bodyArrayBuffer = await new Response(request.body).arrayBuffer();
|
||||
const body: Parameters<WorkerProxy["httpRequest"]>[0]["body"] =
|
||||
new Uint8Array(bodyArrayBuffer);
|
||||
|
||||
const response = await workerProxy.httpRequest({
|
||||
path: url.pathname,
|
||||
query_string: url.search,
|
||||
query_string: url.searchParams.toString(), // The `query_string` field in the ASGI spec must not contain the leading `?`.
|
||||
method,
|
||||
headers,
|
||||
body
|
||||
body: request.body
|
||||
});
|
||||
return new Response(response.body, {
|
||||
status: response.status,
|
||||
|
@ -161,7 +161,8 @@ export function create(options: Options): GradioAppController {
|
||||
client,
|
||||
upload_files,
|
||||
mount_css: overridden_mount_css,
|
||||
fetch_implementation: overridden_fetch
|
||||
fetch_implementation: overridden_fetch,
|
||||
EventSource_factory
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
<script lang="ts">
|
||||
import { FileData } from "@gradio/client";
|
||||
import { onMount, createEventDispatcher } from "svelte";
|
||||
import { onMount, createEventDispatcher, getContext } from "svelte";
|
||||
|
||||
type FileDataWithProgress = FileData & { progress: number };
|
||||
|
||||
@ -36,9 +36,12 @@
|
||||
return (file.progress * 100) / (file.size || 0) || 0;
|
||||
}
|
||||
|
||||
const EventSource_factory = getContext<(url: URL) => EventSource>(
|
||||
"EventSource_factory"
|
||||
);
|
||||
onMount(() => {
|
||||
event_source = new EventSource(
|
||||
`${root}/upload_progress?upload_id=${upload_id}`
|
||||
event_source = EventSource_factory(
|
||||
new URL(`${root}/upload_progress?upload_id=${upload_id}`)
|
||||
);
|
||||
// Event listener for progress updates
|
||||
event_source.onmessage = async function (event) {
|
||||
|
@ -1,9 +1,9 @@
|
||||
export interface HttpRequest {
|
||||
method: "GET" | "POST" | "PUT" | "DELETE";
|
||||
path: string;
|
||||
query_string: string;
|
||||
query_string: string; // This field must not contain the leading `?`, as it's directly used in the ASGI spec which requires this.
|
||||
headers: Record<string, string>;
|
||||
body?: Uint8Array;
|
||||
body?: Uint8Array | ReadableStream<Uint8Array> | null;
|
||||
}
|
||||
export interface HttpResponse {
|
||||
status: number;
|
||||
|
@ -245,11 +245,37 @@ export class WorkerProxy extends EventTarget {
|
||||
|
||||
asgiMessagePort.start();
|
||||
|
||||
asgiMessagePort.postMessage({
|
||||
type: "http.request",
|
||||
more_body: false,
|
||||
body: request.body
|
||||
} satisfies ReceiveEvent);
|
||||
if (request.body instanceof ReadableStream) {
|
||||
// The following code reading the stream is based on the example in https://developer.mozilla.org/en-US/docs/Web/API/ReadableStream/getReader#examples
|
||||
const reader = request.body.getReader();
|
||||
reader.read().then(function process({
|
||||
done,
|
||||
value
|
||||
}): Promise<void> | void {
|
||||
if (done) {
|
||||
asgiMessagePort.postMessage({
|
||||
type: "http.request",
|
||||
more_body: false,
|
||||
body: undefined
|
||||
} satisfies ReceiveEvent);
|
||||
return;
|
||||
}
|
||||
|
||||
asgiMessagePort.postMessage({
|
||||
type: "http.request",
|
||||
more_body: !done,
|
||||
body: value
|
||||
} satisfies ReceiveEvent);
|
||||
|
||||
return reader.read().then(process);
|
||||
});
|
||||
} else {
|
||||
asgiMessagePort.postMessage({
|
||||
type: "http.request",
|
||||
more_body: false,
|
||||
body: request.body ?? undefined
|
||||
} satisfies ReceiveEvent);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -8,18 +8,78 @@
|
||||
}
|
||||
type $$Props = DownloadLinkAttributes;
|
||||
|
||||
import { resolve_wasm_src } from ".";
|
||||
import { getWorkerProxyContext } from "./context";
|
||||
import { should_proxy_wasm_src } from "./file-url";
|
||||
import { getHeaderValue } from "../src/http";
|
||||
|
||||
export let href: DownloadLinkAttributes["href"] = undefined;
|
||||
export let download: DownloadLinkAttributes["download"];
|
||||
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
let is_downloading = false;
|
||||
const worker_proxy = getWorkerProxyContext();
|
||||
async function wasm_click_handler(): Promise<void> {
|
||||
if (is_downloading) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch("click");
|
||||
|
||||
if (href == null) {
|
||||
throw new Error("href is not defined.");
|
||||
}
|
||||
if (worker_proxy == null) {
|
||||
throw new Error("Wasm worker proxy is not available.");
|
||||
}
|
||||
|
||||
const url = new URL(href);
|
||||
const path = url.pathname;
|
||||
|
||||
is_downloading = true;
|
||||
worker_proxy
|
||||
.httpRequest({
|
||||
method: "GET",
|
||||
path,
|
||||
headers: {},
|
||||
query_string: ""
|
||||
})
|
||||
.then((response) => {
|
||||
if (response.status !== 200) {
|
||||
throw new Error(`Failed to get file ${path} from the Wasm worker.`);
|
||||
}
|
||||
const blob = new Blob([response.body], {
|
||||
type: getHeaderValue(response.headers, "content-type")
|
||||
});
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = blobUrl;
|
||||
link.download = download;
|
||||
link.click();
|
||||
|
||||
URL.revokeObjectURL(blobUrl);
|
||||
})
|
||||
.finally(() => {
|
||||
is_downloading = false;
|
||||
});
|
||||
}
|
||||
</script>
|
||||
|
||||
{#await resolve_wasm_src(href) then resolved_href}
|
||||
{#if worker_proxy && should_proxy_wasm_src(href)}
|
||||
{#if is_downloading}
|
||||
<slot />
|
||||
{:else}
|
||||
<a {...$$restProps} {href} on:click|preventDefault={wasm_click_handler}>
|
||||
<slot />
|
||||
</a>
|
||||
{/if}
|
||||
{:else}
|
||||
<a
|
||||
href={resolved_href}
|
||||
target={window.__is_colab__ ? "_blank" : null}
|
||||
{href}
|
||||
target={typeof window !== "undefined" && window.__is_colab__
|
||||
? "_blank"
|
||||
: null}
|
||||
rel="noopener noreferrer"
|
||||
{download}
|
||||
{...$$restProps}
|
||||
@ -27,6 +87,4 @@
|
||||
>
|
||||
<slot />
|
||||
</a>
|
||||
{:catch error}
|
||||
<p style="color: red;">{error.message}</p>
|
||||
{/await}
|
||||
{/if}
|
||||
|
@ -4,18 +4,26 @@ import { getHeaderValue } from "../src/http";
|
||||
|
||||
type MediaSrc = string | undefined | null;
|
||||
|
||||
export async function resolve_wasm_src(src: MediaSrc): Promise<MediaSrc> {
|
||||
export function should_proxy_wasm_src(src: MediaSrc): boolean {
|
||||
if (src == null) {
|
||||
return src;
|
||||
return false;
|
||||
}
|
||||
|
||||
const url = new URL(src);
|
||||
if (!is_self_host(url)) {
|
||||
// `src` is not accessing a local server resource, so we don't need to proxy this request to the Wasm worker.
|
||||
return src;
|
||||
return false;
|
||||
}
|
||||
if (url.protocol !== "http:" && url.protocol !== "https:") {
|
||||
// `src` can be a data URL.
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
export async function resolve_wasm_src(src: MediaSrc): Promise<MediaSrc> {
|
||||
if (src == null || !should_proxy_wasm_src(src)) {
|
||||
return src;
|
||||
}
|
||||
|
||||
@ -25,6 +33,7 @@ export async function resolve_wasm_src(src: MediaSrc): Promise<MediaSrc> {
|
||||
return src;
|
||||
}
|
||||
|
||||
const url = new URL(src);
|
||||
const path = url.pathname;
|
||||
return maybeWorkerProxy
|
||||
.httpRequest({
|
||||
|
Loading…
Reference in New Issue
Block a user