Improve file handling in JS Client (#8462)

* add handler for URLs, Blobs and Files

* add changeset

* remove NodeBlob

* add local file handling

* handle buffers

* add test

* type tweaks

* fix node test with file

* test

* fix test

* handle nested files

* env tweaks

* tweak

* fix test

* use file instead of blob

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Hannah 2024-06-06 13:56:16 +02:00 committed by GitHub
parent 8c18114495
commit 6447dface4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 256 additions and 47 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---
fix:Improve file handling in JS Client

View File

@ -33,12 +33,6 @@ import { check_space_status } from "./helpers/spaces";
import { open_stream } from "./utils/stream";
import { API_INFO_ERROR_MSG, CONFIG_ERROR_MSG } from "./constants";
// Thin Blob subclass kept only so Node-side code can distinguish/construct
// blobs under a stable exported name. The explicit pass-through constructor
// was redundant — the implicit derived constructor forwards its arguments to
// `super` with the identical signature — so it has been removed.
export class NodeBlob extends Blob {}
export class Client {
app_reference: string;
options: ClientOptions;
@ -141,8 +135,6 @@ export class Client {
!global.WebSocket
) {
const ws = await import("ws");
// @ts-ignore
NodeBlob = (await import("node:buffer")).Blob;
global.WebSocket = ws.WebSocket as unknown as typeof WebSocket;
}

View File

@ -30,3 +30,7 @@ export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
export const MISSING_CREDENTIALS_MSG =
"Login credentials are required to access this space.";
// Thrown when a local file path is handled outside a Node.js environment
// (browsers have no filesystem access).
export const NODEJS_FS_ERROR_MSG =
"File system access is only available in Node.js environments";
// Thrown when the client config exposes neither `root` nor `root_url`,
// so no upload endpoint can be derived.
export const ROOT_URL_ERROR_MSG = "Root URL not found in client config";
// Logged when reading or uploading a local file fails during payload processing.
export const FILE_PROCESSING_ERROR_MSG = "Error uploading file";

View File

@ -1,14 +1,18 @@
import { NodeBlob } from "../client";
import type {
ApiData,
BlobRef,
Config,
EndpointInfo,
JsApiData,
DataType,
Dependency,
ComponentMeta
import {
type ApiData,
type BlobRef,
type Config,
type EndpointInfo,
type JsApiData,
type DataType,
Command,
type Dependency,
type ComponentMeta
} from "../types";
import { FileData } from "../upload";
const is_node =
typeof process !== "undefined" && process.versions && process.versions.node;
export function update_object(
object: { [x: string]: any },
@ -66,11 +70,10 @@ export async function walk_and_store_blobs(
(globalThis.Buffer && data instanceof globalThis.Buffer) ||
data instanceof Blob
) {
const is_image = type === "Image";
return [
{
path: path,
blob: is_image ? false : new NodeBlob([data]),
blob: new Blob([data]),
type
}
];
@ -121,6 +124,56 @@ export function post_message<Res = any>(
});
}
/**
 * Normalizes the many ways a caller can reference a file into one of the
 * shapes the client understands:
 *  - http(s) URL string  -> FileData-like reference (no upload needed)
 *  - local path string   -> Command placeholder (Node.js only; uploaded later)
 *  - File                -> FileData-like object carrying the blob
 *  - Buffer              -> wrapped in a Blob
 *  - Blob                -> returned as-is
 * @throws Error for any other input (including local paths in the browser).
 */
export function handle_file(
	file_or_url: File | string | Blob | Buffer
): FileData | Blob | Command {
	if (typeof file_or_url === "string") {
		if (
			file_or_url.startsWith("http://") ||
			file_or_url.startsWith("https://")
		) {
			// Remote URL: reference it directly; the server fetches it.
			return {
				path: file_or_url,
				url: file_or_url,
				orig_name: file_or_url.split("/").pop() ?? "unknown",
				meta: { _type: "gradio.FileData" }
			};
		}

		if (is_node) {
			// Local file path: defer reading/uploading via a Command placeholder.
			return new Command("upload_file", {
				path: file_or_url,
				name: file_or_url,
				orig_path: file_or_url
			});
		}
		// Non-URL string in the browser falls through to the error below.
	} else if (typeof File !== "undefined" && file_or_url instanceof File) {
		// The guard has already narrowed to File, so the repeated
		// `instanceof File` ternaries of the previous version were dead code.
		return {
			path: file_or_url.name,
			orig_name: file_or_url.name,
			// @ts-ignore — FileData does not declare `blob`, but the upload path consumes it
			blob: file_or_url,
			size: file_or_url.size,
			mime_type: file_or_url.type,
			meta: { _type: "gradio.FileData" }
		};
	} else if (typeof Buffer !== "undefined" && file_or_url instanceof Buffer) {
		// The `typeof Buffer` guard prevents a ReferenceError in browsers,
		// where `Buffer` is not defined; previously this line could throw
		// ReferenceError instead of the intended Error below.
		return new Blob([file_or_url]);
	} else if (file_or_url instanceof Blob) {
		return file_or_url;
	}
	throw new Error(
		"Invalid input: must be a URL, File, Blob, or Buffer object."
	);
}
/**
* Handles the payload by filtering out state inputs and returning an array of resolved payload values.
* We send null values for state inputs to the server, but we don't want to include them in the resolved payload.

View File

@ -4,6 +4,7 @@ export { predict } from "./utils/predict";
export { submit } from "./utils/submit";
export { upload_files } from "./utils/upload_files";
export { FileData, upload, prepare_files } from "./upload";
export { handle_file } from "./helpers/data";
export type {
SpaceStatus,

View File

@ -4,11 +4,14 @@ import {
walk_and_store_blobs,
skip_queue,
post_message,
handle_file,
handle_payload
} from "../helpers/data";
import { NodeBlob } from "../client";
import { config_response, endpoint_info } from "./test_data";
import { BlobRef } from "../types";
import { BlobRef, Command } from "../types";
import { FileData } from "../upload";
const IS_NODE = process.env.TEST_MODE === "node";
describe("walk_and_store_blobs", () => {
it("should convert a Buffer to a Blob", async () => {
@ -16,7 +19,7 @@ describe("walk_and_store_blobs", () => {
const parts = await walk_and_store_blobs(buffer, "text");
expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});
it("should return a Blob when passed a Blob", async () => {
@ -29,19 +32,7 @@ describe("walk_and_store_blobs", () => {
endpoint_info
);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
});
it("should return blob: false when passed an image", async () => {
const blob = new Blob([]);
const parts = await walk_and_store_blobs(
blob,
"Image",
[],
true,
endpoint_info
);
expect(parts[0].blob).toBe(false);
expect(parts[0].blob).toBeInstanceOf(Blob);
});
it("should handle arrays", async () => {
@ -49,7 +40,7 @@ describe("walk_and_store_blobs", () => {
const parts = await walk_and_store_blobs([image]);
expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(parts[0].path).toEqual(["0"]);
});
@ -58,7 +49,7 @@ describe("walk_and_store_blobs", () => {
const parts = await walk_and_store_blobs({ a: { b: { data: { image } } } });
expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(parts[0].path).toEqual(["a", "b", "data", "image"]);
});
@ -80,7 +71,7 @@ describe("walk_and_store_blobs", () => {
]
});
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});
it("should handle deep structures with arrays (with equality check)", async () => {
@ -104,8 +95,8 @@ describe("walk_and_store_blobs", () => {
let ref = obj;
path.forEach((p) => (ref = ref[p]));
// since ref is a Blob and blob is a NodeBlob, we deep equal check the two buffers instead
if (ref instanceof Blob && blob instanceof NodeBlob) {
// since ref is a Blob and blob is a Blob, we deep equal check the two buffers instead
if (ref instanceof Blob && blob instanceof Blob) {
const refBuffer = Buffer.from(await ref.arrayBuffer());
const blobBuffer = Buffer.from(await blob.arrayBuffer());
return refBuffer.equals(blobBuffer);
@ -114,7 +105,7 @@ describe("walk_and_store_blobs", () => {
return ref === blob;
}
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(map_path(obj, parts)).toBeTruthy();
});
@ -123,7 +114,7 @@ describe("walk_and_store_blobs", () => {
const parts = await walk_and_store_blobs(buffer, undefined, ["blob"]);
expect(parts).toHaveLength(1);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
expect(parts[0].path).toEqual(["blob"]);
});
@ -133,7 +124,7 @@ describe("walk_and_store_blobs", () => {
expect(parts).toHaveLength(1);
expect(parts[0].path).toEqual([]);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});
it("should convert an object with deep structures to BlobRefs", async () => {
@ -150,7 +141,7 @@ describe("walk_and_store_blobs", () => {
expect(parts).toHaveLength(1);
expect(parts[0].path).toEqual(["a", "b", "data", "image"]);
expect(parts[0].blob).toBeInstanceOf(NodeBlob);
expect(parts[0].blob).toBeInstanceOf(Blob);
});
});
describe("update_object", () => {
@ -278,6 +269,57 @@ describe("post_message", () => {
});
});
describe("handle_file", () => {
it("should handle a Blob object and return the blob", () => {
const blob = new Blob(["test data"], { type: "image/png" });
const result = handle_file(blob) as FileData;
expect(result).toBe(blob);
});
it("should handle a Buffer object and return it as a blob", () => {
const buffer = Buffer.from("test data");
const result = handle_file(buffer) as FileData;
expect(result).toBeInstanceOf(Blob);
});
it("should handle a local file path and return a Command object", () => {
const file_path = "./owl.png";
const result = handle_file(file_path) as Command;
expect(result).toBeInstanceOf(Command);
expect(result).toEqual({
type: "command",
command: "upload_file",
meta: { path: "./owl.png", name: "./owl.png", orig_path: "./owl.png" },
fileData: undefined
});
});
it("should handle a File object and return it as FileData", () => {
if (IS_NODE) {
return;
}
const file = new File(["test image"], "test.png", { type: "image/png" });
const result = handle_file(file) as FileData;
expect(result.path).toBe("test.png");
expect(result.orig_name).toBe("test.png");
expect(result.blob).toBeInstanceOf(Blob);
expect(result.size).toBe(file.size);
expect(result.mime_type).toBe("image/png");
expect(result.meta).toEqual({ _type: "gradio.FileData" });
});
it("should throw an error for invalid input", () => {
const invalid_input = 123;
expect(() => {
// @ts-ignore
handle_file(invalid_input);
}).toThrowError(
"Invalid input: must be a URL, File, Blob, or Buffer object."
);
});
});
describe("handle_payload", () => {
it("should return an input payload with null in place of `state` when with_null_state is true", () => {
const resolved_payload = [2];

View File

@ -51,6 +51,27 @@ export interface BlobRef {
export type DataType = string | Buffer | Record<string, any> | any[];
// custom class used for uploading local files
export class Command {
	// Discriminator used when scanning payloads for pending local uploads.
	type: string = "command";
	command: string;
	meta: {
		path: string;
		name: string;
		orig_path: string;
	};
	// Populated after the referenced local file has been uploaded.
	fileData?: FileData;

	constructor(
		command: string,
		meta: { path: string; name: string; orig_path: string }
	) {
		this.command = command;
		this.meta = meta;
	}
}
// Function Signature Types
export type SubmitFunction = (

View File

@ -1,7 +1,17 @@
import { update_object, walk_and_store_blobs } from "../helpers/data";
import type { ApiData, EndpointInfo, JsApiData } from "../types";
import {
Command,
type ApiData,
type EndpointInfo,
type JsApiData
} from "../types";
import { FileData } from "../upload";
import type { Client } from "..";
import {
FILE_PROCESSING_ERROR_MSG,
NODEJS_FS_ERROR_MSG,
ROOT_URL_ERROR_MSG
} from "../constants";
export async function handle_blob(
this: Client,
@ -11,6 +21,8 @@ export async function handle_blob(
): Promise<unknown[]> {
const self = this;
await process_local_file_commands(self, data);
const blobRefs = await walk_and_store_blobs(
data,
undefined,
@ -45,3 +57,81 @@ export async function handle_blob(
return data;
}
/**
 * Walks `data` for Command placeholders and uploads the local files they
 * reference, replacing each placeholder in place with the server response.
 * @throws Error when the client config exposes neither `root` nor `root_url`.
 */
export async function process_local_file_commands(
	client: Client,
	data: unknown[]
): Promise<void> {
	const has_root = Boolean(client.config?.root || client.config?.root_url);
	if (!has_root) {
		throw new Error(ROOT_URL_ERROR_MSG);
	}
	await recursively_process_commands(client, data);
}
/**
 * Depth-first traversal over arbitrarily nested arrays/objects, dispatching
 * every Command encountered to process_single_command (which mutates the
 * parent container in place). `path` records the traversal route.
 */
async function recursively_process_commands(
	client: Client,
	data: any,
	path: string[] = []
): Promise<void> {
	for (const key in data) {
		const value = data[key];
		if (value instanceof Command) {
			await process_single_command(client, data, key);
		} else if (value !== null && typeof value === "object") {
			await recursively_process_commands(client, value, path.concat(key));
		}
	}
}
async function process_single_command(
client: Client,
data: any,
key: string
): Promise<void> {
let cmd_item = data[key] as Command;
const root = client.config?.root || client.config?.root_url;
if (!root) {
throw new Error(ROOT_URL_ERROR_MSG);
}
try {
let fileBuffer: Buffer;
let fullPath: string;
// check if running in a Node.js environment
if (
typeof process !== "undefined" &&
process.versions &&
process.versions.node
) {
const fs = await import("fs/promises");
const path = await import("path");
fullPath = path.resolve(process.cwd(), cmd_item.meta.path);
fileBuffer = await fs.readFile(fullPath); // Read file from disk
} else {
throw new Error(NODEJS_FS_ERROR_MSG);
}
const file = new Blob([fileBuffer], { type: "application/octet-stream" });
const response = await client.upload_files(root, [file]);
const file_url = response.files && response.files[0];
if (file_url) {
const fileData = new FileData({
path: file_url,
orig_name: cmd_item.meta.name || ""
});
// replace the command object with the fileData object
data[key] = fileData;
}
} catch (error) {
console.error(FILE_PROCESSING_ERROR_MSG, error);
}
}

View File

@ -27,7 +27,7 @@ export async function upload_files(
});
try {
const upload_url = upload_id
? `${root_url}/upload?upload_id=${upload_id}`
? `${root_url}/${UPLOAD_URL}?upload_id=${upload_id}`
: `${root_url}/${UPLOAD_URL}`;
response = await this.fetch(upload_url, {