Client node fix (#8252)

* fix client in node

* run all client tests in ci

* add changeset

* fix types

* add changeset

* format

* types

* add changeset

* format

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
pngwn 2024-05-09 12:06:28 +01:00 committed by GitHub
parent 05fe4918c0
commit 22df61a26a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 95 additions and 62 deletions

View File

@ -0,0 +1,18 @@
---
"@gradio/app": patch
"@gradio/audio": patch
"@gradio/client": patch
"@gradio/dataframe": patch
"@gradio/file": patch
"@gradio/gallery": patch
"@gradio/image": patch
"@gradio/imageeditor": patch
"@gradio/model3d": patch
"@gradio/multimodaltextbox": patch
"@gradio/simpleimage": patch
"@gradio/upload": patch
"@gradio/video": patch
"gradio": patch
---
fix:Client node fix

View File

@ -65,6 +65,8 @@ jobs:
run: pnpm ts:check
- name: unit tests
run: pnpm test:run
- name: client tests
run: pnpm --filter @gradio/client test
- name: do check
if: always()
uses: "gradio-app/github/actions/commit-status@main"

View File

@ -59,19 +59,18 @@ export class Client {
return fetch(input, init);
}
stream_factory(url: URL): EventSource | null {
async stream(url: URL): Promise<EventSource> {
if (typeof window === "undefined" || typeof EventSource === "undefined") {
import("eventsource")
.then((EventSourceModule) => {
return new EventSourceModule.default(url.toString());
})
.catch((error) =>
console.error("Failed to load EventSource module:", error)
);
try {
const EventSourceModule = await import("eventsource");
return new EventSourceModule.default(url.toString()) as EventSource;
} catch (error) {
console.error("Failed to load EventSource module:", error);
throw error;
}
} else {
return new EventSource(url.toString());
}
return null;
}
view_api: () => Promise<ApiInfo<JsApiData>>;
@ -107,7 +106,7 @@ export class Client {
data?: unknown[],
event_data?: unknown
) => Promise<unknown>;
open_stream: () => void;
open_stream: () => Promise<void>;
private resolve_config: (endpoint: string) => Promise<Config | undefined>;
constructor(app_reference: string, options: ClientOptions = {}) {
this.app_reference = app_reference;
@ -144,7 +143,7 @@ export class Client {
const heartbeat_url = new URL(
`${this.config.root}/heartbeat/${this.session_hash}`
);
this.heartbeat_event = this.stream_factory(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
if (this.config.space_id && this.options.hf_token) {
this.jwt = await get_jwt(

View File

@ -1,11 +1,13 @@
import { vi } from "vitest";
Object.defineProperty(window, "EventSource", {
writable: true,
value: vi.fn().mockImplementation(() => ({
close: vi.fn(() => {}),
addEventListener: vi.fn(),
onmessage: vi.fn((_event: MessageEvent) => {}),
onerror: vi.fn((_event: Event) => {})
}))
});
if (process.env.TEST_MODE !== "node") {
Object.defineProperty(window, "EventSource", {
writable: true,
value: vi.fn().mockImplementation(() => ({
close: vi.fn(() => {}),
addEventListener: vi.fn(),
onmessage: vi.fn((_event: MessageEvent) => {}),
onerror: vi.fn((_event: Event) => {})
}))
});
}

View File

@ -1,11 +1,21 @@
import { vi } from "vitest";
import { vi, type Mock } from "vitest";
import { Client } from "../client";
import { initialise_server } from "./server";
import { describe, it, expect, afterEach } from "vitest";
import {
describe,
it,
expect,
afterEach,
beforeAll,
afterAll,
beforeEach
} from "vitest";
import "./mock_eventsource.ts";
import NodeEventSource from "eventsource";
const server = initialise_server();
const IS_NODE = process.env.TEST_MODE === "node";
beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
@ -13,12 +23,14 @@ afterAll(() => server.close());
describe("open_stream", () => {
let mock_eventsource: any;
let app: any;
let app: Client;
beforeEach(async () => {
app = await Client.connect("hmb/hello_world");
app.stream_factory = vi.fn().mockImplementation(() => {
mock_eventsource = new EventSource("");
app.stream = vi.fn().mockImplementation(() => {
mock_eventsource = IS_NODE
? new NodeEventSource("")
: new EventSource("");
return mock_eventsource;
});
});
@ -30,21 +42,21 @@ describe("open_stream", () => {
it("should throw an error if config is not defined", () => {
app.config = undefined;
expect(() => {
app.open_stream();
}).toThrow("Could not resolve app config");
expect(async () => {
await app.open_stream();
}).rejects.toThrow("Could not resolve app config");
});
it("should connect to the SSE endpoint and handle messages", async () => {
app.open_stream();
await app.open_stream();
const eventsource_mock_call = app.stream_factory.mock.calls[0][0];
const eventsource_mock_call = (app.stream as Mock).mock.calls[0][0];
expect(eventsource_mock_call.href).toMatch(
/https:\/\/hmb-hello-world\.hf\.space\/queue\/data\?session_hash/
);
expect(app.stream_factory).toHaveBeenCalledWith(eventsource_mock_call);
expect(app.stream).toHaveBeenCalledWith(eventsource_mock_call);
const onMessageCallback = mock_eventsource.onmessage;
const onErrorCallback = mock_eventsource.onerror;

View File

@ -1,7 +1,7 @@
import { BROKEN_CONNECTION_MSG } from "../constants";
import type { Client } from "../client";
export function open_stream(this: Client): void {
export async function open_stream(this: Client): Promise<void> {
let {
event_callbacks,
unclosed_events,
@ -28,7 +28,7 @@ export function open_stream(this: Client): void {
url.searchParams.set("__sign", jwt);
}
stream = this.stream_factory(url);
stream = await this.stream(url);
if (!stream) {
console.warn("Cannot connect to SSE endpoint: " + url.toString());

View File

@ -372,7 +372,7 @@ export function submit(
url.searchParams.set("__sign", this.jwt);
}
stream = this.stream_factory(url);
stream = await this.stream(url);
if (!stream) {
return Promise.reject(
@ -503,7 +503,7 @@ export function submit(
headers
);
});
post_data_promise.then(([response, status]: any) => {
post_data_promise.then(async ([response, status]: any) => {
if (status === 503) {
fire_event({
type: "status",
@ -655,7 +655,7 @@ export function submit(
event_callbacks[event_id] = callback;
unclosed_events.add(event_id);
if (!stream_status.open) {
this.open_stream();
await this.open_stream();
}
}
});

View File

@ -131,7 +131,7 @@ export function create(options: Options): GradioAppController {
return wasm_proxied_fetch(worker_proxy, input, init);
}
stream_factory(url: URL): EventSource {
async stream(url: URL): Promise<EventSource> {
return wasm_proxied_stream_factory(worker_proxy, url);
}
}

View File

@ -226,7 +226,7 @@
{waveform_options}
{trim_region_settings}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
>
<UploadText i18n={gradio.i18n} type="audio" />
</InteractiveAudio>

View File

@ -35,7 +35,7 @@
export let editable = true;
export let max_file_size: number | null = null;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
$: dispatch("drag", dragging);

View File

@ -151,6 +151,6 @@
{line_breaks}
{column_widths}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
/>
</Block>

View File

@ -40,7 +40,7 @@
export let line_breaks = true;
export let column_widths: string[] = [];
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
let selected: false | [number, number] = false;
export let display_value: string[][] | null = null;

View File

@ -88,7 +88,7 @@
{:else}
<FileUpload
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
{label}
{show_label}
{value}

View File

@ -20,7 +20,7 @@
export let i18n: I18nFormatter;
export let max_file_size: number | null = null;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
async function handle_upload({
detail

View File

@ -78,7 +78,7 @@
file_types={["image"]}
i18n={gradio.i18n}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
on:upload={(e) => {
const files = Array.isArray(e.detail) ? e.detail : [e.detail];
value = files.map((x) => ({ image: x, caption: null }));

View File

@ -157,7 +157,7 @@
max_file_size={gradio.max_file_size}
i18n={gradio.i18n}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
>
{#if active_source === "upload" || !active_source}
<UploadText i18n={gradio.i18n} type="image" />

View File

@ -27,7 +27,7 @@
export let i18n: I18nFormatter;
export let max_file_size: number | null = null;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
let upload_input: Upload;
let uploading = false;

View File

@ -208,7 +208,7 @@
{layers}
status={loading_status?.status}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
></InteractiveImageEditor>
</Block>
{/if}

View File

@ -47,7 +47,7 @@
export let canvas_size: [number, number] | undefined;
export let realtime: boolean;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
const dispatch = createEventDispatcher<{
clear?: never;

View File

@ -26,7 +26,7 @@
export let mirror_webcam = true;
export let i18n: I18nFormatter;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
const { active_tool } = getContext<ToolContext>(TOOL_KEY);
const { pixi, dimensions, register_context, reset, editor_box } =

View File

@ -128,7 +128,7 @@
i18n={gradio.i18n}
max_file_size={gradio.max_file_size}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
>
<UploadText i18n={gradio.i18n} type="file" />
</Model3DUpload>

View File

@ -25,7 +25,7 @@
null
];
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
async function handle_upload({
detail

View File

@ -98,6 +98,6 @@
}}
disabled={!interactive}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
/>
</Block>

View File

@ -36,7 +36,7 @@
export let file_types: string[] | null = null;
export let max_file_size: number | null = null;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
let upload_component: Upload;
let hidden_upload: HTMLInputElement;

View File

@ -91,7 +91,7 @@
<ImageUploader
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
bind:value
{root}
on:clear={() => gradio.dispatch("clear")}

View File

@ -12,7 +12,7 @@
export let show_label: boolean;
export let root: string;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
let upload_component: Upload;
let uploading = false;

View File

@ -20,7 +20,7 @@
export let show_progress = true;
export let max_file_size: number | null = null;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
let upload_id: string;
let file_data: FileData[];

View File

@ -7,9 +7,9 @@
export let upload_id: string;
export let root: string;
export let files: FileData[];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
let stream: ReturnType<Client["stream_factory"]>;
let stream: Awaited<ReturnType<Client["stream"]>>;
let progress = false;
let current_file_upload: FileDataWithProgress;
let file_to_display: FileDataWithProgress;
@ -37,8 +37,8 @@
return (file.progress * 100) / (file.size || 0) || 0;
}
onMount(() => {
stream = stream_handler(
onMount(async () => {
stream = await stream_handler(
new URL(`${root}/upload_progress?upload_id=${upload_id}`)
);

View File

@ -208,7 +208,7 @@
i18n={gradio.i18n}
max_file_size={gradio.max_file_size}
upload={gradio.client.upload}
stream_handler={gradio.client.stream_factory}
stream_handler={gradio.client.stream}
>
<UploadText i18n={gradio.i18n} type="video" />
</Video>

View File

@ -30,7 +30,7 @@
export let handle_reset_value: () => void = () => {};
export let max_file_size: number | null = null;
export let upload: Client["upload"];
export let stream_handler: Client["stream_factory"];
export let stream_handler: Client["stream"];
const dispatch = createEventDispatcher<{
change: FileData | null;