mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-13 11:57:29 +08:00
[ZeroGPU] postMessage zerogpu-headers
(#7877)
* Backup * postMessage origin * postMessage "zerogpu-headers" -> fetch headers * Fix api_factory headers * Oops * ruff format * Prettier * Fix from_config * Works with direct URL (without auth headers) * Check that we're on Spaces * Honor formatting
This commit is contained in:
parent
b561a27816
commit
d28dab8224
@ -12,7 +12,8 @@ import {
|
||||
set_space_timeout,
|
||||
hardware_types,
|
||||
resolve_root,
|
||||
apply_diff
|
||||
apply_diff,
|
||||
post_message
|
||||
} from "./utils.js";
|
||||
|
||||
import type {
|
||||
@ -185,7 +186,8 @@ export function api_factory(
|
||||
async function post_data(
|
||||
url: string,
|
||||
body: unknown,
|
||||
token?: `hf_${string}`
|
||||
token?: `hf_${string}`,
|
||||
additional_headers?: Record<string, string>
|
||||
): Promise<[PostResponse, number]> {
|
||||
const headers: {
|
||||
Authorization?: string;
|
||||
@ -199,7 +201,7 @@ export function api_factory(
|
||||
var response = await fetch_implementation(url, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(body),
|
||||
headers
|
||||
headers: { ...headers, ...additional_headers }
|
||||
});
|
||||
} catch (e) {
|
||||
return [{ error: BROKEN_CONNECTION_MSG }, 500];
|
||||
@ -438,15 +440,18 @@ export function api_factory(
|
||||
): SubmitReturn {
|
||||
let fn_index: number;
|
||||
let api_info;
|
||||
let dependency;
|
||||
|
||||
if (typeof endpoint === "number") {
|
||||
fn_index = endpoint;
|
||||
api_info = api.unnamed_endpoints[fn_index];
|
||||
dependency = config.dependencies[endpoint];
|
||||
} else {
|
||||
const trimmed_endpoint = endpoint.replace(/^\//, "");
|
||||
|
||||
fn_index = api_map[trimmed_endpoint];
|
||||
api_info = api.named_endpoints[endpoint.trim()];
|
||||
dependency = config.dependencies[api_map[trimmed_endpoint]];
|
||||
}
|
||||
|
||||
if (typeof fn_index !== "number") {
|
||||
@ -776,15 +781,27 @@ export function api_factory(
|
||||
fn_index,
|
||||
time: new Date()
|
||||
});
|
||||
|
||||
post_data(
|
||||
`${config.root}/queue/join?${url_params}`,
|
||||
{
|
||||
...payload,
|
||||
session_hash
|
||||
},
|
||||
hf_token
|
||||
).then(([response, status]) => {
|
||||
let hostname = window.location.hostname;
|
||||
let hfhubdev = "dev.spaces.huggingface.tech";
|
||||
const origin = hostname.includes(".dev.")
|
||||
? `https://moon-${hostname.split(".")[1]}.${hfhubdev}`
|
||||
: `https://huggingface.co`;
|
||||
const zerogpu_auth_promise =
|
||||
dependency.zerogpu && window.parent != window && config.space_id
|
||||
? post_message<Headers>("zerogpu-headers", origin)
|
||||
: Promise.resolve(null);
|
||||
const post_data_promise = zerogpu_auth_promise.then((headers) => {
|
||||
return post_data(
|
||||
`${config.root}/queue/join?${url_params}`,
|
||||
{
|
||||
...payload,
|
||||
session_hash
|
||||
},
|
||||
hf_token,
|
||||
headers
|
||||
);
|
||||
});
|
||||
post_data_promise.then(([response, status]) => {
|
||||
if (status === 503) {
|
||||
fire_event({
|
||||
type: "status",
|
||||
|
@ -298,3 +298,17 @@ export function apply_diff(
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
export function post_message<Res = any>(
|
||||
message: any,
|
||||
origin: string
|
||||
): Promise<Res> {
|
||||
return new Promise((res, _rej) => {
|
||||
const channel = new MessageChannel();
|
||||
channel.port1.onmessage = (({ data }) => {
|
||||
channel.port1.close();
|
||||
res(data);
|
||||
}) as (ev: MessageEvent<Res>) => void;
|
||||
window.parent.postMessage(message, origin, [channel.port2]);
|
||||
});
|
||||
}
|
||||
|
@ -850,6 +850,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
original_mapping[o] for o in dependency["outputs"]
|
||||
]
|
||||
dependency.pop("status_tracker", None)
|
||||
dependency.pop("zerogpu")
|
||||
dependency["preprocess"] = False
|
||||
dependency["postprocess"] = False
|
||||
if is_then_event:
|
||||
@ -1148,6 +1149,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
"trigger_only_on_success": trigger_only_on_success,
|
||||
"trigger_mode": trigger_mode,
|
||||
"show_api": show_api,
|
||||
"zerogpu": hasattr(fn, "zerogpu"),
|
||||
}
|
||||
self.dependencies.append(dependency)
|
||||
return dependency, len(self.dependencies) - 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user