Provide status updates on file uploads (#6307)

* Backend

* Backend

* add changeset

* Clean up + close connection

* Lint

* Fix tests

* Apply opacity transition

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2023-11-07 19:02:31 -05:00 committed by GitHub
parent e9bb445a63
commit f1409f95ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 281 additions and 41 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/client": minor
"@gradio/upload": minor
"gradio": minor
---
feat:Provide status updates on file uploads

View File

@ -156,7 +156,8 @@ interface Client {
upload_files: (
root: string,
files: File[],
token?: `hf_${string}`
token?: `hf_${string}`,
upload_id?: string
) => Promise<UploadResponse>;
client: (
app_reference: string,
@ -208,7 +209,8 @@ export function api_factory(
async function upload_files(
root: string,
files: (Blob | File)[],
token?: `hf_${string}`
token?: `hf_${string}`,
upload_id?: string
): Promise<UploadResponse> {
const headers: {
Authorization?: string;
@ -225,7 +227,10 @@ export function api_factory(
formData.append("files", file);
});
try {
var response = await fetch_implementation(`${root}/upload`, {
const upload_url = upload_id
? `${root}/upload?upload_id=${upload_id}`
: `${root}/upload`;
var response = await fetch_implementation(upload_url, {
method: "POST",
body: formData,
headers

View File

@ -88,6 +88,7 @@ export function get_fetchable_url_or_file(
export async function upload(
file_data: FileData[],
root: string,
upload_id?: string,
upload_fn: typeof upload_files = upload_files
): Promise<(FileData | null)[] | null> {
let files = (Array.isArray(file_data) ? file_data : [file_data]).map(
@ -95,7 +96,7 @@ export async function upload(
);
return await Promise.all(
await upload_fn(root, files).then(
await upload_fn(root, files, undefined, upload_id).then(
async (response: { files?: string[]; error?: string }) => {
if (response.error) {
throw new Error(response.error);

View File

@ -2,6 +2,8 @@ from __future__ import annotations
import hashlib
import json
from collections import deque
from dataclasses import dataclass as python_dataclass
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
@ -294,6 +296,44 @@ class GradioUploadFile(UploadFile):
self.sha = hashlib.sha1()
@python_dataclass(frozen=True)
class FileUploadProgressUnit:
filename: str
chunk_size: int
is_done: bool
class FileUploadProgress:
def __init__(self) -> None:
self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}
def track(self, upload_id: str):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
def update(self, upload_id: str, filename: str, message_bytes: bytes):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
self._statuses[upload_id].append(
FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
)
def set_done(self, upload_id: str):
self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))
def stop_tracking(self, upload_id: str):
if upload_id in self._statuses:
del self._statuses[upload_id]
def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
if upload_id not in self._statuses:
return deque()
return self._statuses[upload_id]
def is_tracked(self, upload_id: str):
return upload_id in self._statuses
class GradioMultiPartParser:
"""Vendored from starlette.MultipartParser.
@ -315,6 +355,8 @@ class GradioMultiPartParser:
*,
max_files: Union[int, float] = 1000,
max_fields: Union[int, float] = 1000,
upload_id: str | None = None,
upload_progress: FileUploadProgress | None = None,
) -> None:
assert (
multipart is not None
@ -324,6 +366,8 @@ class GradioMultiPartParser:
self.max_files = max_files
self.max_fields = max_fields
self.items: List[Tuple[str, Union[str, UploadFile]]] = []
self.upload_id = upload_id
self.upload_progress = upload_progress
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
@ -339,6 +383,10 @@ class GradioMultiPartParser:
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self.upload_progress is not None:
self.upload_progress.update(
self.upload_id, self._current_part.file.filename, message_bytes # type: ignore
)
if self._current_part.file is None:
self._current_part.data += message_bytes
else:
@ -464,4 +512,6 @@ class GradioMultiPartParser:
raise exc
parser.finalize()
if self.upload_progress is not None:
self.upload_progress.set_done(self.upload_id) # type: ignore
return FormData(self.items)

View File

@ -56,6 +56,7 @@ from gradio.helpers import CACHED_FOLDER
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation, Event
from gradio.route_utils import ( # noqa: F401
FileUploadProgress,
GradioMultiPartParser,
GradioUploadFile,
MultiPartException,
@ -121,6 +122,9 @@ def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> N
shutil.move(file, dest)
file_upload_statuses = FileUploadProgress()
class App(FastAPI):
"""
FastAPI App Wrapper
@ -681,8 +685,57 @@ class App(FastAPI):
async def get_queue_status():
return app.get_blocks()._queue.get_estimation()
@app.get("/upload_progress")
def get_upload_progress(upload_id: str, request: fastapi.Request):
async def sse_stream(request: fastapi.Request):
last_heartbeat = time.perf_counter()
is_done = False
while True:
if await request.is_disconnected():
file_upload_statuses.stop_tracking(upload_id)
return
if is_done:
file_upload_statuses.stop_tracking(upload_id)
return
heartbeat_rate = 15
check_rate = 0.05
message = None
try:
if update := file_upload_statuses.status(upload_id).popleft():
if update.is_done:
message = {"msg": "done"}
is_done = True
else:
message = {
"msg": "update",
"orig_name": update.filename,
"chunk_size": update.chunk_size,
}
else:
await asyncio.sleep(check_rate)
if time.perf_counter() - last_heartbeat > heartbeat_rate:
message = {"msg": "heartbeat"}
last_heartbeat = time.perf_counter()
if message:
yield f"data: {json.dumps(message)}\n\n"
except IndexError:
if not file_upload_statuses.is_tracked(upload_id):
return
# pop from empty queue
continue
return StreamingResponse(
sse_stream(request),
media_type="text/event-stream",
)
@app.post("/upload", dependencies=[Depends(login_check)])
async def upload_file(request: fastapi.Request, bg_tasks: BackgroundTasks):
async def upload_file(
request: fastapi.Request,
bg_tasks: BackgroundTasks,
upload_id: Optional[str] = None,
):
content_type_header = request.headers.get("Content-Type")
content_type: bytes
content_type, _ = parse_options_header(content_type_header)
@ -690,11 +743,15 @@ class App(FastAPI):
raise HTTPException(status_code=400, detail="Invalid content type.")
try:
if upload_id:
file_upload_statuses.track(upload_id)
multipart_parser = GradioMultiPartParser(
request.headers,
request.stream(),
max_files=1000,
max_fields=1000,
upload_id=upload_id if upload_id else None,
upload_progress=file_upload_statuses if upload_id else None,
)
form = await multipart_parser.parse()
except MultiPartException as exc:

View File

@ -7,7 +7,7 @@ test("Audio click-to-upload uploads audio successfuly.", async ({ page }) => {
const uploader = await page.locator("input[type=file]");
await Promise.all([
uploader.setInputFiles(["../../test/test_files/audio_sample.wav"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
@ -21,7 +21,7 @@ test("Audio click-to-upload uploads audio successfuly.", async ({ page }) => {
await Promise.all([
uploader.setInputFiles(["../../test/test_files/audio_sample.wav"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);
await expect(page.getByLabel("# Change Events")).toHaveValue("3");
@ -39,7 +39,7 @@ test("Audio drag-and-drop uploads a file to the server correctly.", async ({
"audio_sample.wav",
"audio/wav"
),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
await expect(page.getByLabel("# Upload Events")).toHaveValue("1");

View File

@ -9,7 +9,7 @@ test("Video click-to-upload uploads video successfuly. Clear, play, and pause bu
const uploader = await page.locator("input[type=file]");
await Promise.all([
uploader.setInputFiles(["./test/files/file_test.ogg"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*?*")
]);
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
@ -28,7 +28,7 @@ test("Video click-to-upload uploads video successfuly. Clear, play, and pause bu
await Promise.all([
uploader.setInputFiles(["./test/files/file_test.ogg"]),
page.waitForResponse("**/upload")
page.waitForResponse("**/upload?*")
]);
await expect(page.getByLabel("# Change Events")).toHaveValue("3");
@ -50,7 +50,7 @@ test("Video drag-and-drop uploads a file to the server correctly.", async ({
"file_test.ogg",
"video/*"
);
await page.waitForResponse("**/upload");
await page.waitForResponse("**/upload?*");
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
await expect(page.getByLabel("# Upload Events")).toHaveValue("1");
});

View File

@ -3,6 +3,7 @@
import type { FileData } from "@gradio/client";
import { upload_files, upload, prepare_files } from "@gradio/client";
import { _ } from "svelte-i18n";
import UploadProgress from "./UploadProgress.svelte";
export let filetype: string | null = null;
export let dragging = false;
@ -15,6 +16,10 @@
export let hidden = false;
export let include_sources = false;
let uploading = false;
let upload_id: string;
let file_data: FileData[];
// Needed for wasm support
const upload_fn = getContext<typeof upload_files>("upload_files");
@ -36,7 +41,9 @@
file_data: FileData[]
): Promise<(FileData | null)[]> {
await tick();
const _file_data = await upload(file_data, root, upload_fn);
upload_id = Math.random().toString(36).substring(2, 15);
uploading = true;
const _file_data = await upload(file_data, root, upload_id, upload_fn);
dispatch("load", file_count === "single" ? _file_data?.[0] : _file_data);
return _file_data || [];
}
@ -48,7 +55,7 @@
return;
}
let _files: File[] = files.map((f) => new File([f], f.name));
let file_data = await prepare_files(_files);
file_data = await prepare_files(_files);
return await handle_upload(file_data);
}
@ -90,35 +97,46 @@
}
</script>
<button
class:hidden
class:center
class:boundedheight
class:flex
style:height={include_sources ? "calc(100% - 40px" : "100%"}
on:drag|preventDefault|stopPropagation
on:dragstart|preventDefault|stopPropagation
on:dragend|preventDefault|stopPropagation
on:dragover|preventDefault|stopPropagation
on:dragenter|preventDefault|stopPropagation
on:dragleave|preventDefault|stopPropagation
on:drop|preventDefault|stopPropagation
on:click={open_file_upload}
on:drop={loadFilesFromDrop}
on:dragenter={updateDragging}
on:dragleave={updateDragging}
>
<slot />
<input
type="file"
bind:this={hidden_upload}
on:change={load_files_from_upload}
accept={filetype}
multiple={file_count === "multiple" || undefined}
webkitdirectory={file_count === "directory" || undefined}
mozdirectory={file_count === "directory" || undefined}
{#if uploading}
<UploadProgress
{root}
{upload_id}
files={file_data}
on:done={() => {
uploading = false;
}}
/>
</button>
{:else}
<button
class:hidden
class:center
class:boundedheight
class:flex
style:height={include_sources ? "calc(100% - 40px" : "100%"}
on:drag|preventDefault|stopPropagation
on:dragstart|preventDefault|stopPropagation
on:dragend|preventDefault|stopPropagation
on:dragover|preventDefault|stopPropagation
on:dragenter|preventDefault|stopPropagation
on:dragleave|preventDefault|stopPropagation
on:drop|preventDefault|stopPropagation
on:click={open_file_upload}
on:drop={loadFilesFromDrop}
on:dragenter={updateDragging}
on:dragleave={updateDragging}
>
<slot />
<input
type="file"
bind:this={hidden_upload}
on:change={load_files_from_upload}
accept={filetype}
multiple={file_count === "multiple" || undefined}
webkitdirectory={file_count === "directory" || undefined}
mozdirectory={file_count === "directory" || undefined}
/>
</button>
{/if}
<style>
button {

View File

@ -0,0 +1,102 @@
<script lang="ts">
import { FileData } from "@gradio/client";
import { onMount, createEventDispatcher } from "svelte";
type FileDataWithProgress = FileData & { progress: number };
export let upload_id: string;
export let root: string;
export let files: FileData[];
let event_source: EventSource;
let progress = false;
let files_with_progress: FileDataWithProgress[] = files.map((file) => {
return {
...file,
progress: 0
};
});
const dispatch = createEventDispatcher();
function handleProgress(filename: string, chunk_size: number): void {
// Find the corresponding file in the array and update its progress
files_with_progress = files_with_progress.map((file) => {
if (file.orig_name === filename) {
file.progress += chunk_size;
}
return file;
});
}
function getProgress(file: FileDataWithProgress): number {
return (file.progress * 100) / (file.size || 0) || 0;
}
onMount(() => {
event_source = new EventSource(
`${root}/upload_progress?upload_id=${upload_id}`
);
// Event listener for progress updates
event_source.onmessage = async function (event) {
const _data = JSON.parse(event.data);
if (!progress) progress = true;
if (_data.msg === "done") {
event_source.close();
dispatch("done");
} else {
handleProgress(_data.orig_name, _data.chunk_size);
}
};
});
</script>
<div class="wrap" class:progress>
{#each files_with_progress as file, index}
<div class="file-info">
<span>Uploading {file.orig_name}...</span>
</div>
<div class="progress-bar-wrap">
<div class="progress-bar" style="width: {getProgress(file)}%;"></div>
</div>
{/each}
</div>
<style>
.wrap {
margin-top: var(--size-7);
overflow-y: auto;
opacity: 0;
transition: opacity 0.5s ease-in-out;
background: var(--block-background-fill);
}
.wrap.progress {
opacity: 1;
}
.progress-bar-wrap {
border: 1px solid var(--border-color-primary);
background: var(--background-fill-primary);
height: var(--size-4);
}
.progress-bar {
transform-origin: left;
background-color: var(--loader-color);
height: var(--size-full);
}
.file-info {
height: 100%;
justify-content: center;
text-align: center;
width: 100%;
}
.file-info span {
color: var(--body-text-color);
font-size: var(--text-med);
font-family: var(--font-mono);
}
</style>