Ensure connectivity to private HF spaces with SSE protocol (#8181)

* add msw setup and initialisation tests

* add changeset

* add eventsource polyfill for node and browser envs

* add changeset

* add changeset

* config tweak

* types

* update eventsource usage

* add changeset

* add walk_and_store_blobs improvements and add tests

* add changeset

* api_info tests

* add direct space URL link tests

* fix tests

* add view_api tests

* add post_message test

* tweak

* add spaces tests

* jwt and protocol tests

* add post_data tests

* test tweaks

* dynamically import eventsource

* revet eventsource imports

* add jwt param to sse requests

* add stream test

* add changeset

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Hannah 2024-05-02 23:05:56 +02:00 committed by GitHub
parent 7aca673b38
commit cf52ca6a51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 121 additions and 10 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/app": patch
"@gradio/client": patch
"gradio": patch
---
fix:Ensure connectivity to private HF spaces with SSE protocol

View File

@ -36,6 +36,16 @@ export class NodeBlob extends Blob {
}
}
if (typeof window === "undefined") {
import("eventsource")
.then((EventSourceModule) => {
global.EventSource = EventSourceModule.default as any;
})
.catch((error) =>
console.error("Failed to load EventSource module:", error)
);
}
export class Client {
app_reference: string;
options: ClientOptions;
@ -51,7 +61,7 @@ export class Client {
stream_status = { open: false };
pending_stream_messages: Record<string, any[][]> = {};
pending_diff_streams: Record<string, any[][]> = {};
event_callbacks: Record<string, () => Promise<void>> = {};
event_callbacks: Record<string, (data?: unknown) => Promise<void>> = {};
unclosed_events: Set<string> = new Set();
heartbeat_event: EventSource | null = null;

View File

@ -0,0 +1,11 @@
import { vi } from "vitest";
Object.defineProperty(window, "EventSource", {
writable: true,
value: vi.fn().mockImplementation(() => ({
close: vi.fn(() => {}),
addEventListener: vi.fn(),
onmessage: vi.fn((_event: MessageEvent) => {}),
onerror: vi.fn((_event: Event) => {})
}))
});

View File

@ -0,0 +1,67 @@
import { vi } from "vitest";
import { Client } from "../client";
import { initialise_server } from "./server";
import { describe, it, expect, afterEach } from "vitest";
import "./mock_eventsource.ts";
const server = initialise_server();
beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
afterAll(() => server.close());
describe("open_stream", () => {
let mock_eventsource: any;
let app: any;
beforeEach(async () => {
app = await Client.connect("hmb/hello_world");
app.eventSource_factory = vi.fn().mockImplementation(() => {
mock_eventsource = new EventSource("");
return mock_eventsource;
});
});
afterEach(() => {
vi.clearAllMocks();
});
it("should throw an error if config is not defined", () => {
app.config = undefined;
expect(() => {
app.open_stream();
}).toThrow("Could not resolve app config");
});
it("should connect to the SSE endpoint and handle messages", async () => {
app.open_stream();
const eventsource_mock_call = app.eventSource_factory.mock.calls[0][0];
expect(eventsource_mock_call.href).toMatch(
/https:\/\/hmb-hello-world\.hf\.space\/queue\/data\?session_hash/
);
expect(app.eventSource_factory).toHaveBeenCalledWith(eventsource_mock_call);
const onMessageCallback = mock_eventsource.onmessage;
const onErrorCallback = mock_eventsource.onerror;
const message = { msg: "hello jerry" };
onMessageCallback({ data: JSON.stringify(message) });
expect(app.stream_status.open).toBe(true);
expect(app.event_callbacks).toEqual({});
expect(app.pending_stream_messages).toEqual({});
const close_stream_message = { msg: "close_stream" };
onMessageCallback({ data: JSON.stringify(close_stream_message) });
expect(app.stream_status.open).toBe(false);
onErrorCallback({ data: JSON.stringify("404") });
expect(app.stream_status.open).toBe(false);
});
});

View File

@ -7,7 +7,8 @@ export function open_stream(this: Client): void {
unclosed_events,
pending_stream_messages,
stream_status,
config
config,
jwt
} = this;
if (!config) {
@ -22,10 +23,16 @@ export function open_stream(this: Client): void {
}).toString();
let url = new URL(`${config.root}/queue/data?${params}`);
if (jwt) {
url.searchParams.set("__sign", jwt);
}
event_source = this.eventSource_factory(url);
if (!event_source) {
throw new Error("Cannot connect to sse endpoint: " + url.toString());
console.warn("Cannot connect to SSE endpoint: " + url.toString());
return;
}
event_source.onmessage = async function (event: MessageEvent) {
@ -37,10 +44,8 @@ export function open_stream(this: Client): void {
const event_id = _data.event_id;
if (!event_id) {
await Promise.all(
Object.keys(event_callbacks).map(
(event_id) =>
// @ts-ignore
event_callbacks[event_id](_data) // todo: check event_callbacks
Object.keys(event_callbacks).map((event_id) =>
event_callbacks[event_id](_data)
)
);
} else if (event_callbacks[event_id] && config) {
@ -70,7 +75,6 @@ export function open_stream(this: Client): void {
event_source.onerror = async function () {
await Promise.all(
Object.keys(event_callbacks).map((event_id) =>
// @ts-ignore
event_callbacks[event_id]({
msg: "unexpected_error",
message: BROKEN_CONNECTION_MSG

View File

@ -368,11 +368,15 @@ export function submit(
}${params}`
);
if (this.jwt) {
url.searchParams.set("__sign", this.jwt);
}
event_source = this.eventSource_factory(url);
if (!event_source) {
throw new Error(
"Cannot connect to sse endpoint: " + url.toString()
return Promise.reject(
new Error("Cannot connect to SSE endpoint: " + url.toString())
);
}

View File

@ -77,8 +77,10 @@
"@gradio/utils": "workspace:^",
"@gradio/video": "workspace:^",
"@gradio/wasm": "workspace:^",
"@types/eventsource": "^1.1.15",
"cross-env": "^7.0.3",
"d3-dsv": "^3.0.1",
"eventsource": "^2.0.2",
"mime-types": "^2.1.34",
"postcss": "^8.4.21",
"postcss-prefix-selector": "^1.16.0"

View File

@ -520,12 +520,18 @@ importers:
'@gradio/wasm':
specifier: workspace:^
version: link:../wasm
'@types/eventsource':
specifier: ^1.1.15
version: 1.1.15
cross-env:
specifier: ^7.0.3
version: 7.0.3
d3-dsv:
specifier: ^3.0.1
version: 3.0.1
eventsource:
specifier: ^2.0.2
version: 2.0.2
mime-types:
specifier: ^2.1.34
version: 2.1.34