[ZeroGPU] postMessage zerogpu-headers (#7877)

* Backup

* postMessage origin

* postMessage "zerogpu-headers" -> fetch headers

* Fix api_factory headers

* Oops

* ruff format

* Prettier

* Fix from_config

* Works with direct URL (without auth headers)

* Check that we're on Spaces

* Honor formatting
This commit is contained in:
Charles 2024-04-09 21:44:15 +02:00 committed by GitHub
parent b561a27816
commit d28dab8224
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 12 deletions

View File

@ -12,7 +12,8 @@ import {
set_space_timeout,
hardware_types,
resolve_root,
apply_diff
apply_diff,
post_message
} from "./utils.js";
import type {
@ -185,7 +186,8 @@ export function api_factory(
async function post_data(
url: string,
body: unknown,
token?: `hf_${string}`
token?: `hf_${string}`,
additional_headers?: Record<string, string>
): Promise<[PostResponse, number]> {
const headers: {
Authorization?: string;
@ -199,7 +201,7 @@ export function api_factory(
var response = await fetch_implementation(url, {
method: "POST",
body: JSON.stringify(body),
headers
headers: { ...headers, ...additional_headers }
});
} catch (e) {
return [{ error: BROKEN_CONNECTION_MSG }, 500];
@ -438,15 +440,18 @@ export function api_factory(
): SubmitReturn {
let fn_index: number;
let api_info;
let dependency;
if (typeof endpoint === "number") {
fn_index = endpoint;
api_info = api.unnamed_endpoints[fn_index];
dependency = config.dependencies[endpoint];
} else {
const trimmed_endpoint = endpoint.replace(/^\//, "");
fn_index = api_map[trimmed_endpoint];
api_info = api.named_endpoints[endpoint.trim()];
dependency = config.dependencies[api_map[trimmed_endpoint]];
}
if (typeof fn_index !== "number") {
@ -776,15 +781,27 @@ export function api_factory(
fn_index,
time: new Date()
});
post_data(
`${config.root}/queue/join?${url_params}`,
{
...payload,
session_hash
},
hf_token
).then(([response, status]) => {
let hostname = window.location.hostname;
let hfhubdev = "dev.spaces.huggingface.tech";
const origin = hostname.includes(".dev.")
? `https://moon-${hostname.split(".")[1]}.${hfhubdev}`
: `https://huggingface.co`;
const zerogpu_auth_promise =
dependency.zerogpu && window.parent != window && config.space_id
? post_message<Headers>("zerogpu-headers", origin)
: Promise.resolve(null);
const post_data_promise = zerogpu_auth_promise.then((headers) => {
return post_data(
`${config.root}/queue/join?${url_params}`,
{
...payload,
session_hash
},
hf_token,
headers
);
});
post_data_promise.then(([response, status]) => {
if (status === 503) {
fire_event({
type: "status",

View File

@ -298,3 +298,17 @@ export function apply_diff(
return obj;
}
export function post_message<Res = any>(
message: any,
origin: string
): Promise<Res> {
return new Promise((res, _rej) => {
const channel = new MessageChannel();
channel.port1.onmessage = (({ data }) => {
channel.port1.close();
res(data);
}) as (ev: MessageEvent<Res>) => void;
window.parent.postMessage(message, origin, [channel.port2]);
});
}

View File

@ -850,6 +850,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
original_mapping[o] for o in dependency["outputs"]
]
dependency.pop("status_tracker", None)
dependency.pop("zerogpu")
dependency["preprocess"] = False
dependency["postprocess"] = False
if is_then_event:
@ -1148,6 +1149,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
"trigger_only_on_success": trigger_only_on_success,
"trigger_mode": trigger_mode,
"show_api": show_api,
"zerogpu": hasattr(fn, "zerogpu"),
}
self.dependencies.append(dependency)
return dependency, len(self.dependencies) - 1