mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
Ensure connectivity to private HF spaces with SSE protocol (#8181)
* add msw setup and initialisation tests * add changeset * add eventsource polyfill for node and browser envs * add changeset * add changeset * config tweak * types * update eventsource usage * add changeset * add walk_and_store_blobs improvements and add tests * add changeset * api_info tests * add direct space URL link tests * fix tests * add view_api tests * add post_message test * tweak * add spaces tests * jwt and protocol tests * add post_data tests * test tweaks * dynamically import eventsource * revet eventsource imports * add jwt param to sse requests * add stream test * add changeset * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
7aca673b38
commit
cf52ca6a51
7
.changeset/yummy-paws-eat.md
Normal file
7
.changeset/yummy-paws-eat.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/app": patch
|
||||
"@gradio/client": patch
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Ensure connectivity to private HF spaces with SSE protocol
|
@ -36,6 +36,16 @@ export class NodeBlob extends Blob {
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof window === "undefined") {
|
||||
import("eventsource")
|
||||
.then((EventSourceModule) => {
|
||||
global.EventSource = EventSourceModule.default as any;
|
||||
})
|
||||
.catch((error) =>
|
||||
console.error("Failed to load EventSource module:", error)
|
||||
);
|
||||
}
|
||||
|
||||
export class Client {
|
||||
app_reference: string;
|
||||
options: ClientOptions;
|
||||
@ -51,7 +61,7 @@ export class Client {
|
||||
stream_status = { open: false };
|
||||
pending_stream_messages: Record<string, any[][]> = {};
|
||||
pending_diff_streams: Record<string, any[][]> = {};
|
||||
event_callbacks: Record<string, () => Promise<void>> = {};
|
||||
event_callbacks: Record<string, (data?: unknown) => Promise<void>> = {};
|
||||
unclosed_events: Set<string> = new Set();
|
||||
heartbeat_event: EventSource | null = null;
|
||||
|
||||
|
11
client/js/src/test/mock_eventsource.ts
Normal file
11
client/js/src/test/mock_eventsource.ts
Normal file
@ -0,0 +1,11 @@
|
||||
import { vi } from "vitest";
|
||||
|
||||
Object.defineProperty(window, "EventSource", {
|
||||
writable: true,
|
||||
value: vi.fn().mockImplementation(() => ({
|
||||
close: vi.fn(() => {}),
|
||||
addEventListener: vi.fn(),
|
||||
onmessage: vi.fn((_event: MessageEvent) => {}),
|
||||
onerror: vi.fn((_event: Event) => {})
|
||||
}))
|
||||
});
|
67
client/js/src/test/stream.test.ts
Normal file
67
client/js/src/test/stream.test.ts
Normal file
@ -0,0 +1,67 @@
|
||||
import { vi } from "vitest";
|
||||
import { Client } from "../client";
|
||||
import { initialise_server } from "./server";
|
||||
|
||||
import { describe, it, expect, afterEach } from "vitest";
|
||||
import "./mock_eventsource.ts";
|
||||
|
||||
const server = initialise_server();
|
||||
|
||||
beforeAll(() => server.listen());
|
||||
afterEach(() => server.resetHandlers());
|
||||
afterAll(() => server.close());
|
||||
|
||||
describe("open_stream", () => {
|
||||
let mock_eventsource: any;
|
||||
let app: any;
|
||||
|
||||
beforeEach(async () => {
|
||||
app = await Client.connect("hmb/hello_world");
|
||||
app.eventSource_factory = vi.fn().mockImplementation(() => {
|
||||
mock_eventsource = new EventSource("");
|
||||
return mock_eventsource;
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should throw an error if config is not defined", () => {
|
||||
app.config = undefined;
|
||||
|
||||
expect(() => {
|
||||
app.open_stream();
|
||||
}).toThrow("Could not resolve app config");
|
||||
});
|
||||
|
||||
it("should connect to the SSE endpoint and handle messages", async () => {
|
||||
app.open_stream();
|
||||
|
||||
const eventsource_mock_call = app.eventSource_factory.mock.calls[0][0];
|
||||
|
||||
expect(eventsource_mock_call.href).toMatch(
|
||||
/https:\/\/hmb-hello-world\.hf\.space\/queue\/data\?session_hash/
|
||||
);
|
||||
|
||||
expect(app.eventSource_factory).toHaveBeenCalledWith(eventsource_mock_call);
|
||||
|
||||
const onMessageCallback = mock_eventsource.onmessage;
|
||||
const onErrorCallback = mock_eventsource.onerror;
|
||||
|
||||
const message = { msg: "hello jerry" };
|
||||
|
||||
onMessageCallback({ data: JSON.stringify(message) });
|
||||
expect(app.stream_status.open).toBe(true);
|
||||
|
||||
expect(app.event_callbacks).toEqual({});
|
||||
expect(app.pending_stream_messages).toEqual({});
|
||||
|
||||
const close_stream_message = { msg: "close_stream" };
|
||||
onMessageCallback({ data: JSON.stringify(close_stream_message) });
|
||||
expect(app.stream_status.open).toBe(false);
|
||||
|
||||
onErrorCallback({ data: JSON.stringify("404") });
|
||||
expect(app.stream_status.open).toBe(false);
|
||||
});
|
||||
});
|
@ -7,7 +7,8 @@ export function open_stream(this: Client): void {
|
||||
unclosed_events,
|
||||
pending_stream_messages,
|
||||
stream_status,
|
||||
config
|
||||
config,
|
||||
jwt
|
||||
} = this;
|
||||
|
||||
if (!config) {
|
||||
@ -22,10 +23,16 @@ export function open_stream(this: Client): void {
|
||||
}).toString();
|
||||
|
||||
let url = new URL(`${config.root}/queue/data?${params}`);
|
||||
|
||||
if (jwt) {
|
||||
url.searchParams.set("__sign", jwt);
|
||||
}
|
||||
|
||||
event_source = this.eventSource_factory(url);
|
||||
|
||||
if (!event_source) {
|
||||
throw new Error("Cannot connect to sse endpoint: " + url.toString());
|
||||
console.warn("Cannot connect to SSE endpoint: " + url.toString());
|
||||
return;
|
||||
}
|
||||
|
||||
event_source.onmessage = async function (event: MessageEvent) {
|
||||
@ -37,10 +44,8 @@ export function open_stream(this: Client): void {
|
||||
const event_id = _data.event_id;
|
||||
if (!event_id) {
|
||||
await Promise.all(
|
||||
Object.keys(event_callbacks).map(
|
||||
(event_id) =>
|
||||
// @ts-ignore
|
||||
event_callbacks[event_id](_data) // todo: check event_callbacks
|
||||
Object.keys(event_callbacks).map((event_id) =>
|
||||
event_callbacks[event_id](_data)
|
||||
)
|
||||
);
|
||||
} else if (event_callbacks[event_id] && config) {
|
||||
@ -70,7 +75,6 @@ export function open_stream(this: Client): void {
|
||||
event_source.onerror = async function () {
|
||||
await Promise.all(
|
||||
Object.keys(event_callbacks).map((event_id) =>
|
||||
// @ts-ignore
|
||||
event_callbacks[event_id]({
|
||||
msg: "unexpected_error",
|
||||
message: BROKEN_CONNECTION_MSG
|
||||
|
@ -368,11 +368,15 @@ export function submit(
|
||||
}${params}`
|
||||
);
|
||||
|
||||
if (this.jwt) {
|
||||
url.searchParams.set("__sign", this.jwt);
|
||||
}
|
||||
|
||||
event_source = this.eventSource_factory(url);
|
||||
|
||||
if (!event_source) {
|
||||
throw new Error(
|
||||
"Cannot connect to sse endpoint: " + url.toString()
|
||||
return Promise.reject(
|
||||
new Error("Cannot connect to SSE endpoint: " + url.toString())
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -77,8 +77,10 @@
|
||||
"@gradio/utils": "workspace:^",
|
||||
"@gradio/video": "workspace:^",
|
||||
"@gradio/wasm": "workspace:^",
|
||||
"@types/eventsource": "^1.1.15",
|
||||
"cross-env": "^7.0.3",
|
||||
"d3-dsv": "^3.0.1",
|
||||
"eventsource": "^2.0.2",
|
||||
"mime-types": "^2.1.34",
|
||||
"postcss": "^8.4.21",
|
||||
"postcss-prefix-selector": "^1.16.0"
|
||||
|
@ -520,12 +520,18 @@ importers:
|
||||
'@gradio/wasm':
|
||||
specifier: workspace:^
|
||||
version: link:../wasm
|
||||
'@types/eventsource':
|
||||
specifier: ^1.1.15
|
||||
version: 1.1.15
|
||||
cross-env:
|
||||
specifier: ^7.0.3
|
||||
version: 7.0.3
|
||||
d3-dsv:
|
||||
specifier: ^3.0.1
|
||||
version: 3.0.1
|
||||
eventsource:
|
||||
specifier: ^2.0.2
|
||||
version: 2.0.2
|
||||
mime-types:
|
||||
specifier: ^2.1.34
|
||||
version: 2.1.34
|
||||
|
Loading…
Reference in New Issue
Block a user