Handle gradio apps using state in the JS Client (#8439)

* send `null` for each `state` param in space api

* add changeset

* test

* remove state value from payload from server

* tweak

* test

* test

* Revert "test"

This reverts commit 182045ec7ca9448cbff22621dcc087b8487db8d1.

* Revert "test"

This reverts commit 70e074dfdddc2801d78176654675a076b7d89c1e.

* fixes

* add changeset

* fixes

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
This commit is contained in:
Hannah 2024-06-05 11:53:06 +02:00 committed by GitHub
parent 5c8915b113
commit 63d36fbbf4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 304 additions and 19 deletions

View File

@ -0,0 +1,8 @@
---
"@gradio/app": patch
"@gradio/client": patch
"@gradio/preview": patch
"gradio": patch
---
fix:Handle gradio apps using `state` in the JS Client

View File

@ -6,6 +6,7 @@ import type {
DuplicateOptions,
EndpointInfo,
JsApiData,
PredictReturn,
SpaceStatus,
Status,
SubmitReturn,
@ -114,7 +115,7 @@ export class Client {
endpoint: string | number,
data: unknown[] | Record<string, unknown>,
event_data?: unknown
) => Promise<SubmitReturn>;
) => Promise<PredictReturn>;
open_stream: () => Promise<void>;
private resolve_config: (endpoint: string) => Promise<Config | undefined>;
private resolve_cookies: () => Promise<void>;

View File

@ -5,7 +5,9 @@ import type {
Config,
EndpointInfo,
JsApiData,
DataType
DataType,
Dependency,
ComponentMeta
} from "../types";
export function update_object(
@ -118,3 +120,62 @@ export function post_message<Res = any>(
window.parent.postMessage(message, origin, [channel.port2]);
});
}
/**
* Handles the payload by filtering out state inputs and returning an array of resolved payload values.
* We send null values for state inputs to the server, but we don't want to include them in the resolved payload.
*
* @param resolved_payload - The resolved payload values received from the client or the server
* @param dependency - The dependency object.
* @param components - The array of component metadata.
* @param with_null_state - Optional. Specifies whether to include null values for state inputs. Default is false.
* @returns An array of resolved payload values, filtered based on the dependency and component metadata.
*/
export function handle_payload(
resolved_payload: unknown[],
dependency: Dependency,
components: ComponentMeta[],
type: "input" | "output",
with_null_state = false
): unknown[] {
if (type === "input" && !with_null_state) {
throw new Error("Invalid code path. Cannot skip state inputs for input.");
}
// data comes from the server with null state values so we skip
if (type === "output" && with_null_state) {
return resolved_payload;
}
let updated_payload: unknown[] = [];
let payload_index = 0;
for (let i = 0; i < dependency.inputs.length; i++) {
const input_id = dependency.inputs[i];
const component = components.find((c) => c.id === input_id);
if (component?.type === "state") {
// input + with_null_state needs us to fill state with null values
if (with_null_state) {
if (resolved_payload.length === dependency.inputs.length) {
const value = resolved_payload[payload_index];
updated_payload.push(value);
payload_index++;
} else {
updated_payload.push(null);
}
} else {
// this is output & !with_null_state, we skip state inputs
// the server payload always comes with null state values so we move along the payload index
payload_index++;
continue;
}
// input & !with_null_state isn't a case we care about, server needs null
continue;
} else {
const value = resolved_payload[payload_index];
updated_payload.push(value);
payload_index++;
}
}
return updated_payload;
}

View File

@ -16,7 +16,6 @@ import { initialise_server } from "./server";
import { transformed_api_info } from "./test_data";
const server = initialise_server();
const IS_NODE = process.env.TEST_MODE === "node";
beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());

View File

@ -3,7 +3,8 @@ import {
update_object,
walk_and_store_blobs,
skip_queue,
post_message
post_message,
handle_payload
} from "../helpers/data";
import { NodeBlob } from "../client";
import { config_response, endpoint_info } from "./test_data";
@ -276,3 +277,135 @@ describe("post_message", () => {
]);
});
});
describe("handle_payload", () => {
it("should return an input payload with null in place of `state` when with_null_state is true", () => {
const resolved_payload = [2];
const dependency = {
inputs: [1, 2]
};
const components = [
{ id: 1, type: "number" },
{ id: 2, type: "state" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"input",
with_null_state
);
expect(result).toEqual([2, null]);
});
it("should return an input payload with null in place of two `state` components when with_null_state is true", () => {
const resolved_payload = ["hello", "goodbye"];
const dependency = {
inputs: [1, 2, 3, 4]
};
const components = [
{ id: 1, type: "textbox" },
{ id: 2, type: "state" },
{ id: 3, type: "textbox" },
{ id: 4, type: "state" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"input",
with_null_state
);
expect(result).toEqual(["hello", null, "goodbye", null]);
});
it("should return an output payload without the state component value when with_null_state is false", () => {
const resolved_payload = ["hello", null];
const dependency = {
inputs: [2, 3]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "state" }
];
const with_null_state = false;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"output",
with_null_state
);
expect(result).toEqual(["hello"]);
});
it("should return an ouput payload without the two state component values when with_null_state is false", () => {
const resolved_payload = ["hello", null, "world", null];
const dependency = {
inputs: [2, 3, 4, 5]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "state" },
{ id: 4, type: "textbox" },
{ id: 5, type: "state" }
];
const with_null_state = false;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"output",
with_null_state
);
expect(result).toEqual(["hello", "world"]);
});
it("should return an ouput payload with the two state component values when with_null_state is true", () => {
const resolved_payload = ["hello", null, "world", null];
const dependency = {
inputs: [2, 3, 4, 5]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "state" },
{ id: 4, type: "textbox" },
{ id: 5, type: "state" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
"output",
with_null_state
);
expect(result).toEqual(["hello", null, "world", null]);
});
it("should return the same payload where no state components are defined", () => {
const resolved_payload = ["hello", "world"];
const dependency = {
inputs: [2, 3]
};
const components = [
{ id: 2, type: "textbox" },
{ id: 3, type: "textbox" }
];
const with_null_state = true;
const result = handle_payload(
resolved_payload,
// @ts-ignore
dependency,
components,
with_null_state
);
expect(result).toEqual(["hello", "world"]);
});
});

View File

@ -1,6 +1,8 @@
// API Data Types
import { hardware_types } from "./helpers/spaces";
import type { SvelteComponent } from "svelte";
import type { ComponentType } from "svelte";
export interface ApiData {
label: string;
@ -62,7 +64,7 @@ export type PredictFunction = (
endpoint: string | number,
data: unknown[] | Record<string, unknown>,
event_data?: unknown
) => Promise<SubmitReturn>;
) => Promise<PredictReturn>;
// Event and Submission Types
@ -90,6 +92,14 @@ export type SubmitReturn = {
destroy: () => void;
};
export type PredictReturn = {
type: EventType;
time: Date;
data: unknown;
endpoint: string;
fn_index: number;
};
// Space Status Types
export type SpaceStatus = SpaceStatusNormal | SpaceStatusError;
@ -128,7 +138,7 @@ export interface Config {
analytics_enabled: boolean;
connect_heartbeat: boolean;
auth_message: string;
components: any[];
components: ComponentMeta[];
css: string | null;
js: string | null;
head: string | null;
@ -153,6 +163,45 @@ export interface Config {
max_file_size?: number;
}
// todo: DRY up types
export interface ComponentMeta {
type: string;
id: number;
has_modes: boolean;
props: SharedProps;
instance: SvelteComponent;
component: ComponentType<SvelteComponent>;
documentation?: Documentation;
children?: ComponentMeta[];
parent?: ComponentMeta;
value?: any;
component_class_id: string;
key: string | number | null;
rendered_in?: number;
}
interface SharedProps {
elem_id?: string;
elem_classes?: string[];
components?: string[];
server_fns?: string[];
interactive: boolean;
[key: string]: unknown;
root_url?: string;
}
export interface Documentation {
type?: TypeDescription;
description?: TypeDescription;
example_data?: string;
}
interface TypeDescription {
input_payload?: string;
response_object?: string;
payload?: string;
}
export interface Dependency {
id: number;
targets: [number, string][];
@ -218,6 +267,7 @@ export interface ClientOptions {
hf_token?: `hf_${string}`;
status_callback?: SpaceStatusCallback | null;
auth?: [string, string] | null;
with_null_state?: boolean;
}
export interface FileData {

View File

@ -1,11 +1,11 @@
import { Client } from "../client";
import type { Dependency, SubmitReturn } from "../types";
import type { Dependency, PredictReturn } from "../types";
export async function predict(
this: Client,
endpoint: string | number,
data: unknown[] | Record<string, unknown>
): Promise<SubmitReturn> {
): Promise<PredictReturn> {
let data_returned = false;
let status_complete = false;
let dependency: Dependency;
@ -38,7 +38,7 @@ export async function predict(
// if complete message comes before data, resolve here
if (status_complete) {
app.destroy();
resolve(d as SubmitReturn);
resolve(d as PredictReturn);
}
data_returned = true;
result = d;
@ -50,7 +50,7 @@ export async function predict(
// if complete message comes after data, resolve here
if (data_returned) {
app.destroy();
resolve(result as SubmitReturn);
resolve(result as PredictReturn);
}
}
});

View File

@ -14,7 +14,7 @@ import type {
Dependency
} from "../types";
import { skip_queue, post_message } from "../helpers/data";
import { skip_queue, post_message, handle_payload } from "../helpers/data";
import { resolve_root } from "../helpers/init_helpers";
import {
handle_message,
@ -47,7 +47,8 @@ export function submit(
pending_diff_streams,
event_callbacks,
unclosed_events,
post_data
post_data,
options
} = this;
if (!api_info) throw new Error("No API found");
@ -193,8 +194,15 @@ export function submit(
this.handle_blob(config.root, resolved_data, endpoint_info).then(
async (_payload) => {
let input_data = handle_payload(
_payload,
dependency,
config.components,
"input",
true
);
payload = {
data: _payload || [],
data: input_data || [],
event_data,
fn_index,
trigger_id
@ -225,7 +233,13 @@ export function submit(
type: "data",
endpoint: _endpoint,
fn_index,
data: data,
data: handle_payload(
data,
dependency,
config.components,
"output",
options.with_null_state
),
time: new Date(),
event_data,
trigger_id
@ -359,7 +373,13 @@ export function submit(
fire_event({
type: "data",
time: new Date(),
data: data.data,
data: handle_payload(
data.data,
dependency,
config.components,
"output",
options.with_null_state
),
endpoint: _endpoint,
fn_index,
event_data,
@ -482,7 +502,13 @@ export function submit(
fire_event({
type: "data",
time: new Date(),
data: data.data,
data: handle_payload(
data.data,
dependency,
config.components,
"output",
options.with_null_state
),
endpoint: _endpoint,
fn_index,
event_data,
@ -633,7 +659,13 @@ export function submit(
fire_event({
type: "data",
time: new Date(),
data: data.data,
data: handle_payload(
data.data,
dependency,
config.components,
"output",
options.with_null_state
),
endpoint: _endpoint,
fn_index
});

View File

@ -274,7 +274,8 @@
: host || space || src || location.origin;
app = await Client.connect(api_url, {
status_callback: handle_status
status_callback: handle_status,
with_null_state: true
});
if (!app.config) {

View File

@ -101,7 +101,7 @@ function find_frontend_folders(start_path: string): string[] {
function to_posix(_path: string): string {
const isExtendedLengthPath = /^\\\\\?\\/.test(_path);
const hasNonAscii = /[^\u0000-\u0080]+/.test(_path); // eslint-disable-line no-control-regex
const hasNonAscii = /[^\u0000-\u0080]+/.test(_path);
if (isExtendedLengthPath || hasNonAscii) {
return _path;