mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Provide status updates on file uploads (#6307)
* Backend * Backend * add changeset * Clean up + close connection * Lint * Fix tests * Apply opacity transition --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
e9bb445a63
commit
f1409f95ed
7
.changeset/tired-berries-tease.md
Normal file
7
.changeset/tired-berries-tease.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/client": minor
|
||||
"@gradio/upload": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Provide status updates on file uploads
|
@ -156,7 +156,8 @@ interface Client {
|
||||
upload_files: (
|
||||
root: string,
|
||||
files: File[],
|
||||
token?: `hf_${string}`
|
||||
token?: `hf_${string}`,
|
||||
upload_id?: string
|
||||
) => Promise<UploadResponse>;
|
||||
client: (
|
||||
app_reference: string,
|
||||
@ -208,7 +209,8 @@ export function api_factory(
|
||||
async function upload_files(
|
||||
root: string,
|
||||
files: (Blob | File)[],
|
||||
token?: `hf_${string}`
|
||||
token?: `hf_${string}`,
|
||||
upload_id?: string
|
||||
): Promise<UploadResponse> {
|
||||
const headers: {
|
||||
Authorization?: string;
|
||||
@ -225,7 +227,10 @@ export function api_factory(
|
||||
formData.append("files", file);
|
||||
});
|
||||
try {
|
||||
var response = await fetch_implementation(`${root}/upload`, {
|
||||
const upload_url = upload_id
|
||||
? `${root}/upload?upload_id=${upload_id}`
|
||||
: `${root}/upload`;
|
||||
var response = await fetch_implementation(upload_url, {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
headers
|
||||
|
@ -88,6 +88,7 @@ export function get_fetchable_url_or_file(
|
||||
export async function upload(
|
||||
file_data: FileData[],
|
||||
root: string,
|
||||
upload_id?: string,
|
||||
upload_fn: typeof upload_files = upload_files
|
||||
): Promise<(FileData | null)[] | null> {
|
||||
let files = (Array.isArray(file_data) ? file_data : [file_data]).map(
|
||||
@ -95,7 +96,7 @@ export async function upload(
|
||||
);
|
||||
|
||||
return await Promise.all(
|
||||
await upload_fn(root, files).then(
|
||||
await upload_fn(root, files, undefined, upload_id).then(
|
||||
async (response: { files?: string[]; error?: string }) => {
|
||||
if (response.error) {
|
||||
throw new Error(response.error);
|
||||
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections import deque
|
||||
from dataclasses import dataclass as python_dataclass
|
||||
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
|
||||
|
||||
@ -294,6 +296,44 @@ class GradioUploadFile(UploadFile):
|
||||
self.sha = hashlib.sha1()
|
||||
|
||||
|
||||
@python_dataclass(frozen=True)
|
||||
class FileUploadProgressUnit:
|
||||
filename: str
|
||||
chunk_size: int
|
||||
is_done: bool
|
||||
|
||||
|
||||
class FileUploadProgress:
|
||||
def __init__(self) -> None:
|
||||
self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}
|
||||
|
||||
def track(self, upload_id: str):
|
||||
if upload_id not in self._statuses:
|
||||
self._statuses[upload_id] = deque()
|
||||
|
||||
def update(self, upload_id: str, filename: str, message_bytes: bytes):
|
||||
if upload_id not in self._statuses:
|
||||
self._statuses[upload_id] = deque()
|
||||
self._statuses[upload_id].append(
|
||||
FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
|
||||
)
|
||||
|
||||
def set_done(self, upload_id: str):
|
||||
self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))
|
||||
|
||||
def stop_tracking(self, upload_id: str):
|
||||
if upload_id in self._statuses:
|
||||
del self._statuses[upload_id]
|
||||
|
||||
def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
|
||||
if upload_id not in self._statuses:
|
||||
return deque()
|
||||
return self._statuses[upload_id]
|
||||
|
||||
def is_tracked(self, upload_id: str):
|
||||
return upload_id in self._statuses
|
||||
|
||||
|
||||
class GradioMultiPartParser:
|
||||
"""Vendored from starlette.MultipartParser.
|
||||
|
||||
@ -315,6 +355,8 @@ class GradioMultiPartParser:
|
||||
*,
|
||||
max_files: Union[int, float] = 1000,
|
||||
max_fields: Union[int, float] = 1000,
|
||||
upload_id: str | None = None,
|
||||
upload_progress: FileUploadProgress | None = None,
|
||||
) -> None:
|
||||
assert (
|
||||
multipart is not None
|
||||
@ -324,6 +366,8 @@ class GradioMultiPartParser:
|
||||
self.max_files = max_files
|
||||
self.max_fields = max_fields
|
||||
self.items: List[Tuple[str, Union[str, UploadFile]]] = []
|
||||
self.upload_id = upload_id
|
||||
self.upload_progress = upload_progress
|
||||
self._current_files = 0
|
||||
self._current_fields = 0
|
||||
self._current_partial_header_name: bytes = b""
|
||||
@ -339,6 +383,10 @@ class GradioMultiPartParser:
|
||||
|
||||
def on_part_data(self, data: bytes, start: int, end: int) -> None:
|
||||
message_bytes = data[start:end]
|
||||
if self.upload_progress is not None:
|
||||
self.upload_progress.update(
|
||||
self.upload_id, self._current_part.file.filename, message_bytes # type: ignore
|
||||
)
|
||||
if self._current_part.file is None:
|
||||
self._current_part.data += message_bytes
|
||||
else:
|
||||
@ -464,4 +512,6 @@ class GradioMultiPartParser:
|
||||
raise exc
|
||||
|
||||
parser.finalize()
|
||||
if self.upload_progress is not None:
|
||||
self.upload_progress.set_done(self.upload_id) # type: ignore
|
||||
return FormData(self.items)
|
||||
|
@ -56,6 +56,7 @@ from gradio.helpers import CACHED_FOLDER
|
||||
from gradio.oauth import attach_oauth
|
||||
from gradio.queueing import Estimation, Event
|
||||
from gradio.route_utils import ( # noqa: F401
|
||||
FileUploadProgress,
|
||||
GradioMultiPartParser,
|
||||
GradioUploadFile,
|
||||
MultiPartException,
|
||||
@ -121,6 +122,9 @@ def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> N
|
||||
shutil.move(file, dest)
|
||||
|
||||
|
||||
file_upload_statuses = FileUploadProgress()
|
||||
|
||||
|
||||
class App(FastAPI):
|
||||
"""
|
||||
FastAPI App Wrapper
|
||||
@ -681,8 +685,57 @@ class App(FastAPI):
|
||||
async def get_queue_status():
|
||||
return app.get_blocks()._queue.get_estimation()
|
||||
|
||||
@app.get("/upload_progress")
|
||||
def get_upload_progress(upload_id: str, request: fastapi.Request):
|
||||
async def sse_stream(request: fastapi.Request):
|
||||
last_heartbeat = time.perf_counter()
|
||||
is_done = False
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
file_upload_statuses.stop_tracking(upload_id)
|
||||
return
|
||||
if is_done:
|
||||
file_upload_statuses.stop_tracking(upload_id)
|
||||
return
|
||||
|
||||
heartbeat_rate = 15
|
||||
check_rate = 0.05
|
||||
message = None
|
||||
try:
|
||||
if update := file_upload_statuses.status(upload_id).popleft():
|
||||
if update.is_done:
|
||||
message = {"msg": "done"}
|
||||
is_done = True
|
||||
else:
|
||||
message = {
|
||||
"msg": "update",
|
||||
"orig_name": update.filename,
|
||||
"chunk_size": update.chunk_size,
|
||||
}
|
||||
else:
|
||||
await asyncio.sleep(check_rate)
|
||||
if time.perf_counter() - last_heartbeat > heartbeat_rate:
|
||||
message = {"msg": "heartbeat"}
|
||||
last_heartbeat = time.perf_counter()
|
||||
if message:
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
except IndexError:
|
||||
if not file_upload_statuses.is_tracked(upload_id):
|
||||
return
|
||||
# pop from empty queue
|
||||
continue
|
||||
|
||||
return StreamingResponse(
|
||||
sse_stream(request),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@app.post("/upload", dependencies=[Depends(login_check)])
|
||||
async def upload_file(request: fastapi.Request, bg_tasks: BackgroundTasks):
|
||||
async def upload_file(
|
||||
request: fastapi.Request,
|
||||
bg_tasks: BackgroundTasks,
|
||||
upload_id: Optional[str] = None,
|
||||
):
|
||||
content_type_header = request.headers.get("Content-Type")
|
||||
content_type: bytes
|
||||
content_type, _ = parse_options_header(content_type_header)
|
||||
@ -690,11 +743,15 @@ class App(FastAPI):
|
||||
raise HTTPException(status_code=400, detail="Invalid content type.")
|
||||
|
||||
try:
|
||||
if upload_id:
|
||||
file_upload_statuses.track(upload_id)
|
||||
multipart_parser = GradioMultiPartParser(
|
||||
request.headers,
|
||||
request.stream(),
|
||||
max_files=1000,
|
||||
max_fields=1000,
|
||||
upload_id=upload_id if upload_id else None,
|
||||
upload_progress=file_upload_statuses if upload_id else None,
|
||||
)
|
||||
form = await multipart_parser.parse()
|
||||
except MultiPartException as exc:
|
||||
|
@ -7,7 +7,7 @@ test("Audio click-to-upload uploads audio successfuly.", async ({ page }) => {
|
||||
const uploader = await page.locator("input[type=file]");
|
||||
await Promise.all([
|
||||
uploader.setInputFiles(["../../test/test_files/audio_sample.wav"]),
|
||||
page.waitForResponse("**/upload")
|
||||
page.waitForResponse("**/upload?*")
|
||||
]);
|
||||
|
||||
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
|
||||
@ -21,7 +21,7 @@ test("Audio click-to-upload uploads audio successfuly.", async ({ page }) => {
|
||||
|
||||
await Promise.all([
|
||||
uploader.setInputFiles(["../../test/test_files/audio_sample.wav"]),
|
||||
page.waitForResponse("**/upload")
|
||||
page.waitForResponse("**/upload?*")
|
||||
]);
|
||||
|
||||
await expect(page.getByLabel("# Change Events")).toHaveValue("3");
|
||||
@ -39,7 +39,7 @@ test("Audio drag-and-drop uploads a file to the server correctly.", async ({
|
||||
"audio_sample.wav",
|
||||
"audio/wav"
|
||||
),
|
||||
page.waitForResponse("**/upload")
|
||||
page.waitForResponse("**/upload?*")
|
||||
]);
|
||||
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
|
||||
await expect(page.getByLabel("# Upload Events")).toHaveValue("1");
|
||||
|
@ -9,7 +9,7 @@ test("Video click-to-upload uploads video successfuly. Clear, play, and pause bu
|
||||
const uploader = await page.locator("input[type=file]");
|
||||
await Promise.all([
|
||||
uploader.setInputFiles(["./test/files/file_test.ogg"]),
|
||||
page.waitForResponse("**/upload")
|
||||
page.waitForResponse("**/upload?*?*")
|
||||
]);
|
||||
|
||||
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
|
||||
@ -28,7 +28,7 @@ test("Video click-to-upload uploads video successfuly. Clear, play, and pause bu
|
||||
|
||||
await Promise.all([
|
||||
uploader.setInputFiles(["./test/files/file_test.ogg"]),
|
||||
page.waitForResponse("**/upload")
|
||||
page.waitForResponse("**/upload?*")
|
||||
]);
|
||||
|
||||
await expect(page.getByLabel("# Change Events")).toHaveValue("3");
|
||||
@ -50,7 +50,7 @@ test("Video drag-and-drop uploads a file to the server correctly.", async ({
|
||||
"file_test.ogg",
|
||||
"video/*"
|
||||
);
|
||||
await page.waitForResponse("**/upload");
|
||||
await page.waitForResponse("**/upload?*");
|
||||
await expect(page.getByLabel("# Change Events")).toHaveValue("1");
|
||||
await expect(page.getByLabel("# Upload Events")).toHaveValue("1");
|
||||
});
|
||||
|
@ -3,6 +3,7 @@
|
||||
import type { FileData } from "@gradio/client";
|
||||
import { upload_files, upload, prepare_files } from "@gradio/client";
|
||||
import { _ } from "svelte-i18n";
|
||||
import UploadProgress from "./UploadProgress.svelte";
|
||||
|
||||
export let filetype: string | null = null;
|
||||
export let dragging = false;
|
||||
@ -15,6 +16,10 @@
|
||||
export let hidden = false;
|
||||
export let include_sources = false;
|
||||
|
||||
let uploading = false;
|
||||
let upload_id: string;
|
||||
let file_data: FileData[];
|
||||
|
||||
// Needed for wasm support
|
||||
const upload_fn = getContext<typeof upload_files>("upload_files");
|
||||
|
||||
@ -36,7 +41,9 @@
|
||||
file_data: FileData[]
|
||||
): Promise<(FileData | null)[]> {
|
||||
await tick();
|
||||
const _file_data = await upload(file_data, root, upload_fn);
|
||||
upload_id = Math.random().toString(36).substring(2, 15);
|
||||
uploading = true;
|
||||
const _file_data = await upload(file_data, root, upload_id, upload_fn);
|
||||
dispatch("load", file_count === "single" ? _file_data?.[0] : _file_data);
|
||||
return _file_data || [];
|
||||
}
|
||||
@ -48,7 +55,7 @@
|
||||
return;
|
||||
}
|
||||
let _files: File[] = files.map((f) => new File([f], f.name));
|
||||
let file_data = await prepare_files(_files);
|
||||
file_data = await prepare_files(_files);
|
||||
return await handle_upload(file_data);
|
||||
}
|
||||
|
||||
@ -90,35 +97,46 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<button
|
||||
class:hidden
|
||||
class:center
|
||||
class:boundedheight
|
||||
class:flex
|
||||
style:height={include_sources ? "calc(100% - 40px" : "100%"}
|
||||
on:drag|preventDefault|stopPropagation
|
||||
on:dragstart|preventDefault|stopPropagation
|
||||
on:dragend|preventDefault|stopPropagation
|
||||
on:dragover|preventDefault|stopPropagation
|
||||
on:dragenter|preventDefault|stopPropagation
|
||||
on:dragleave|preventDefault|stopPropagation
|
||||
on:drop|preventDefault|stopPropagation
|
||||
on:click={open_file_upload}
|
||||
on:drop={loadFilesFromDrop}
|
||||
on:dragenter={updateDragging}
|
||||
on:dragleave={updateDragging}
|
||||
>
|
||||
<slot />
|
||||
<input
|
||||
type="file"
|
||||
bind:this={hidden_upload}
|
||||
on:change={load_files_from_upload}
|
||||
accept={filetype}
|
||||
multiple={file_count === "multiple" || undefined}
|
||||
webkitdirectory={file_count === "directory" || undefined}
|
||||
mozdirectory={file_count === "directory" || undefined}
|
||||
{#if uploading}
|
||||
<UploadProgress
|
||||
{root}
|
||||
{upload_id}
|
||||
files={file_data}
|
||||
on:done={() => {
|
||||
uploading = false;
|
||||
}}
|
||||
/>
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class:hidden
|
||||
class:center
|
||||
class:boundedheight
|
||||
class:flex
|
||||
style:height={include_sources ? "calc(100% - 40px" : "100%"}
|
||||
on:drag|preventDefault|stopPropagation
|
||||
on:dragstart|preventDefault|stopPropagation
|
||||
on:dragend|preventDefault|stopPropagation
|
||||
on:dragover|preventDefault|stopPropagation
|
||||
on:dragenter|preventDefault|stopPropagation
|
||||
on:dragleave|preventDefault|stopPropagation
|
||||
on:drop|preventDefault|stopPropagation
|
||||
on:click={open_file_upload}
|
||||
on:drop={loadFilesFromDrop}
|
||||
on:dragenter={updateDragging}
|
||||
on:dragleave={updateDragging}
|
||||
>
|
||||
<slot />
|
||||
<input
|
||||
type="file"
|
||||
bind:this={hidden_upload}
|
||||
on:change={load_files_from_upload}
|
||||
accept={filetype}
|
||||
multiple={file_count === "multiple" || undefined}
|
||||
webkitdirectory={file_count === "directory" || undefined}
|
||||
mozdirectory={file_count === "directory" || undefined}
|
||||
/>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
button {
|
||||
|
102
js/upload/src/UploadProgress.svelte
Normal file
102
js/upload/src/UploadProgress.svelte
Normal file
@ -0,0 +1,102 @@
|
||||
<script lang="ts">
|
||||
import { FileData } from "@gradio/client";
|
||||
import { onMount, createEventDispatcher } from "svelte";
|
||||
|
||||
type FileDataWithProgress = FileData & { progress: number };
|
||||
|
||||
export let upload_id: string;
|
||||
export let root: string;
|
||||
export let files: FileData[];
|
||||
|
||||
let event_source: EventSource;
|
||||
let progress = false;
|
||||
|
||||
let files_with_progress: FileDataWithProgress[] = files.map((file) => {
|
||||
return {
|
||||
...file,
|
||||
progress: 0
|
||||
};
|
||||
});
|
||||
|
||||
const dispatch = createEventDispatcher();
|
||||
|
||||
function handleProgress(filename: string, chunk_size: number): void {
|
||||
// Find the corresponding file in the array and update its progress
|
||||
files_with_progress = files_with_progress.map((file) => {
|
||||
if (file.orig_name === filename) {
|
||||
file.progress += chunk_size;
|
||||
}
|
||||
return file;
|
||||
});
|
||||
}
|
||||
|
||||
function getProgress(file: FileDataWithProgress): number {
|
||||
return (file.progress * 100) / (file.size || 0) || 0;
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
event_source = new EventSource(
|
||||
`${root}/upload_progress?upload_id=${upload_id}`
|
||||
);
|
||||
// Event listener for progress updates
|
||||
event_source.onmessage = async function (event) {
|
||||
const _data = JSON.parse(event.data);
|
||||
if (!progress) progress = true;
|
||||
if (_data.msg === "done") {
|
||||
event_source.close();
|
||||
dispatch("done");
|
||||
} else {
|
||||
handleProgress(_data.orig_name, _data.chunk_size);
|
||||
}
|
||||
};
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="wrap" class:progress>
|
||||
{#each files_with_progress as file, index}
|
||||
<div class="file-info">
|
||||
<span>Uploading {file.orig_name}...</span>
|
||||
</div>
|
||||
<div class="progress-bar-wrap">
|
||||
<div class="progress-bar" style="width: {getProgress(file)}%;"></div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.wrap {
|
||||
margin-top: var(--size-7);
|
||||
overflow-y: auto;
|
||||
opacity: 0;
|
||||
transition: opacity 0.5s ease-in-out;
|
||||
background: var(--block-background-fill);
|
||||
}
|
||||
|
||||
.wrap.progress {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.progress-bar-wrap {
|
||||
border: 1px solid var(--border-color-primary);
|
||||
background: var(--background-fill-primary);
|
||||
height: var(--size-4);
|
||||
}
|
||||
.progress-bar {
|
||||
transform-origin: left;
|
||||
background-color: var(--loader-color);
|
||||
height: var(--size-full);
|
||||
}
|
||||
|
||||
.file-info {
|
||||
height: 100%;
|
||||
justify-content: center;
|
||||
text-align: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.file-info span {
|
||||
color: var(--body-text-color);
|
||||
font-size: var(--text-med);
|
||||
font-family: var(--font-mono);
|
||||
}
|
||||
</style>
|
Loading…
Reference in New Issue
Block a user