mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Allow JS Client to work with authenticated spaces 🍪 (#8299)
* get cookie from /login and apply to fetch requests in connect + duplicate funcs * add error msg * add tests * improve error msgs * remove unused var * add changeset * remove comment * add error msg * add private space test --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
2d705bcf74
commit
ab65360804
6
.changeset/lemon-bugs-brake.md
Normal file
6
.changeset/lemon-bugs-brake.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"@gradio/client": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Allow JS Client to work with authenticated spaces 🍪
|
@ -23,8 +23,10 @@ import { submit } from "./utils/submit";
|
||||
import { RE_SPACE_NAME, process_endpoint } from "./helpers/api_info";
|
||||
import {
|
||||
map_names_to_ids,
|
||||
resolve_cookies,
|
||||
resolve_config,
|
||||
get_jwt
|
||||
get_jwt,
|
||||
parse_and_set_cookies
|
||||
} from "./helpers/init_helpers";
|
||||
import { check_space_status } from "./helpers/spaces";
|
||||
import { open_stream } from "./utils/stream";
|
||||
@ -47,6 +49,8 @@ export class Client {
|
||||
jwt: string | false = false;
|
||||
last_status: Record<string, Status["stage"]> = {};
|
||||
|
||||
private cookies: string | null = null;
|
||||
|
||||
// streaming
|
||||
stream_status = { open: false };
|
||||
pending_stream_messages: Record<string, any[][]> = {};
|
||||
@ -56,7 +60,12 @@ export class Client {
|
||||
heartbeat_event: EventSource | null = null;
|
||||
|
||||
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
|
||||
return fetch(input, init);
|
||||
const headers = new Headers(init?.headers || {});
|
||||
if (this && this.cookies) {
|
||||
headers.append("Cookie", this.cookies);
|
||||
}
|
||||
|
||||
return fetch(input, { ...init, headers });
|
||||
}
|
||||
|
||||
async stream(url: URL): Promise<EventSource> {
|
||||
@ -108,6 +117,7 @@ export class Client {
|
||||
) => Promise<SubmitReturn>;
|
||||
open_stream: () => Promise<void>;
|
||||
private resolve_config: (endpoint: string) => Promise<Config | undefined>;
|
||||
private resolve_cookies: () => Promise<void>;
|
||||
constructor(app_reference: string, options: ClientOptions = {}) {
|
||||
this.app_reference = app_reference;
|
||||
this.options = options;
|
||||
@ -120,6 +130,7 @@ export class Client {
|
||||
this.predict = predict.bind(this);
|
||||
this.open_stream = open_stream.bind(this);
|
||||
this.resolve_config = resolve_config.bind(this);
|
||||
this.resolve_cookies = resolve_cookies.bind(this);
|
||||
this.upload = upload.bind(this);
|
||||
}
|
||||
|
||||
@ -135,8 +146,23 @@ export class Client {
|
||||
}
|
||||
|
||||
try {
|
||||
if (this.options.auth) {
|
||||
await this.resolve_cookies();
|
||||
}
|
||||
|
||||
await this._resolve_config().then(async ({ config }) => {
|
||||
this.config = config;
|
||||
if (config) {
|
||||
this.config = config;
|
||||
if (this.config && this.config.connect_heartbeat) {
|
||||
if (this.config.space_id && this.options.hf_token) {
|
||||
this.jwt = await get_jwt(
|
||||
this.config.space_id,
|
||||
this.options.hf_token,
|
||||
this.cookies
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (config.space_id && this.options.hf_token) {
|
||||
this.jwt = await get_jwt(config.space_id, this.options.hf_token);
|
||||
@ -156,8 +182,8 @@ export class Client {
|
||||
this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540
|
||||
}
|
||||
});
|
||||
} catch (e) {
|
||||
throw Error(CONFIG_ERROR_MSG + (e as Error).message);
|
||||
} catch (e: any) {
|
||||
throw Error(e);
|
||||
}
|
||||
|
||||
this.api_info = await this.view_api();
|
||||
@ -201,9 +227,8 @@ export class Client {
|
||||
}
|
||||
|
||||
return this.config_success(config);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
if (space_id) {
|
||||
} catch (e: any) {
|
||||
if (space_id && status_callback) {
|
||||
check_space_status(
|
||||
space_id,
|
||||
RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain",
|
||||
@ -217,6 +242,7 @@ export class Client {
|
||||
load_status: "error",
|
||||
detail: "NOT_FOUND"
|
||||
});
|
||||
throw Error(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -246,6 +272,9 @@ export class Client {
|
||||
}
|
||||
|
||||
async handle_space_success(status: SpaceStatus): Promise<Config | void> {
|
||||
if (!this) {
|
||||
throw new Error(CONFIG_ERROR_MSG);
|
||||
}
|
||||
const { status_callback } = this.options;
|
||||
if (status_callback) status_callback(status);
|
||||
if (status.status === "running") {
|
||||
@ -259,7 +288,6 @@ export class Client {
|
||||
|
||||
return _config as Config;
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
if (status_callback) {
|
||||
status_callback({
|
||||
status: "error",
|
||||
@ -268,6 +296,7 @@ export class Client {
|
||||
detail: "NOT_FOUND"
|
||||
});
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -333,7 +362,8 @@ export class Client {
|
||||
const response = await this.fetch(`${root_url}/component_server/`, {
|
||||
method: "POST",
|
||||
body: body,
|
||||
headers
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@ -349,6 +379,10 @@ export class Client {
|
||||
}
|
||||
}
|
||||
|
||||
public set_cookies(raw_cookies: string): void {
|
||||
this.cookies = parse_and_set_cookies(raw_cookies).join("; ");
|
||||
}
|
||||
|
||||
private prepare_return_obj(): client_return {
|
||||
return {
|
||||
config: this.config,
|
||||
|
@ -25,3 +25,7 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. ";
|
||||
export const SPACE_STATUS_ERROR_MSG = "Could not get space status. ";
|
||||
export const API_INFO_ERROR_MSG = "Could not get API info. ";
|
||||
export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. ";
|
||||
export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
|
||||
export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
|
||||
export const MISSING_CREDENTIALS_MSG =
|
||||
"Login credentials are required to access this space.";
|
||||
|
@ -1,5 +1,5 @@
|
||||
import type { Status } from "../types";
|
||||
import { QUEUE_FULL_MSG } from "../constants";
|
||||
import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants";
|
||||
import type { ApiData, ApiInfo, Config, JsApiData } from "../types";
|
||||
import { determine_protocol } from "./init_helpers";
|
||||
|
||||
@ -36,9 +36,7 @@ export async function process_endpoint(
|
||||
...determine_protocol(_host)
|
||||
};
|
||||
} catch (e) {
|
||||
throw new Error(
|
||||
"Space metadata could not be loaded. " + (e as Error).message
|
||||
);
|
||||
throw new Error(SPACE_METADATA_ERROR_MSG);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,15 @@
|
||||
import type { Config } from "../types";
|
||||
import { CONFIG_ERROR_MSG, CONFIG_URL } from "../constants";
|
||||
import {
|
||||
CONFIG_ERROR_MSG,
|
||||
CONFIG_URL,
|
||||
INVALID_CREDENTIALS_MSG,
|
||||
LOGIN_URL,
|
||||
MISSING_CREDENTIALS_MSG,
|
||||
SPACE_METADATA_ERROR_MSG,
|
||||
UNAUTHORIZED_MSG
|
||||
} from "../constants";
|
||||
import { Client } from "..";
|
||||
import { process_endpoint } from "./api_info";
|
||||
|
||||
/**
|
||||
* This function is used to resolve the URL for making requests when the app has a root path.
|
||||
@ -25,12 +34,14 @@ export function resolve_root(
|
||||
|
||||
export async function get_jwt(
|
||||
space: string,
|
||||
token: `hf_${string}`
|
||||
token: `hf_${string}`,
|
||||
cookies?: string | null
|
||||
): Promise<string | false> {
|
||||
try {
|
||||
const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`
|
||||
Authorization: `Bearer ${token}`,
|
||||
...(cookies ? { Cookie: cookies } : {})
|
||||
}
|
||||
});
|
||||
|
||||
@ -76,14 +87,22 @@ export async function resolve_config(
|
||||
return { ...config, path } as Config;
|
||||
} else if (endpoint) {
|
||||
const response = await this.fetch(`${endpoint}/${CONFIG_URL}`, {
|
||||
headers
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
|
||||
if (response?.status === 401 && !this.options.auth) {
|
||||
throw new Error(MISSING_CREDENTIALS_MSG);
|
||||
} else if (response?.status === 401 && this.options.auth) {
|
||||
throw new Error(INVALID_CREDENTIALS_MSG);
|
||||
}
|
||||
if (response?.status === 200) {
|
||||
let config = await response.json();
|
||||
config.path = config.path ?? "";
|
||||
config.root = endpoint;
|
||||
return config;
|
||||
} else if (response?.status === 401) {
|
||||
throw new Error(UNAUTHORIZED_MSG);
|
||||
}
|
||||
throw new Error(CONFIG_ERROR_MSG);
|
||||
}
|
||||
@ -91,6 +110,63 @@ export async function resolve_config(
|
||||
throw new Error(CONFIG_ERROR_MSG);
|
||||
}
|
||||
|
||||
export async function resolve_cookies(this: Client): Promise<void> {
|
||||
const { http_protocol, host } = await process_endpoint(
|
||||
this.app_reference,
|
||||
this.options.hf_token
|
||||
);
|
||||
|
||||
try {
|
||||
if (this.options.auth) {
|
||||
const cookie_header = await get_cookie_header(
|
||||
http_protocol,
|
||||
host,
|
||||
this.options.auth,
|
||||
this.fetch,
|
||||
this.options.hf_token
|
||||
);
|
||||
|
||||
if (cookie_header) this.set_cookies(cookie_header);
|
||||
}
|
||||
} catch (e: unknown) {
|
||||
throw Error((e as Error).message);
|
||||
}
|
||||
}
|
||||
|
||||
// separating this from client-bound resolve_cookies so that it can be used in duplicate
|
||||
export async function get_cookie_header(
|
||||
http_protocol: string,
|
||||
host: string,
|
||||
auth: [string, string],
|
||||
_fetch: typeof fetch,
|
||||
hf_token?: `hf_${string}`
|
||||
): Promise<string | null> {
|
||||
const formData = new FormData();
|
||||
formData.append("username", auth?.[0]);
|
||||
formData.append("password", auth?.[1]);
|
||||
|
||||
let headers: { Authorization?: string } = {};
|
||||
|
||||
if (hf_token) {
|
||||
headers.Authorization = `Bearer ${hf_token}`;
|
||||
}
|
||||
|
||||
const res = await _fetch(`${http_protocol}//${host}/${LOGIN_URL}`, {
|
||||
headers,
|
||||
method: "POST",
|
||||
body: formData,
|
||||
credentials: "include"
|
||||
});
|
||||
|
||||
if (res.status === 200) {
|
||||
return res.headers.get("set-cookie");
|
||||
} else if (res.status === 401) {
|
||||
throw new Error(INVALID_CREDENTIALS_MSG);
|
||||
} else {
|
||||
throw new Error(SPACE_METADATA_ERROR_MSG);
|
||||
}
|
||||
}
|
||||
|
||||
export function determine_protocol(endpoint: string): {
|
||||
ws_protocol: "ws" | "wss";
|
||||
http_protocol: "http:" | "https:";
|
||||
@ -128,3 +204,15 @@ export function determine_protocol(endpoint: string): {
|
||||
host: endpoint
|
||||
};
|
||||
}
|
||||
|
||||
export const parse_and_set_cookies = (cookie_header: string): string[] => {
|
||||
let cookies: string[] = [];
|
||||
const parts = cookie_header.split(/,(?=\s*[^\s=;]+=[^\s=;]+)/);
|
||||
parts.forEach((cookie) => {
|
||||
const [cookie_name, cookie_value] = cookie.split(";")[0].split("=");
|
||||
if (cookie_name && cookie_value) {
|
||||
cookies.push(`${cookie_name.trim()}=${cookie_value.trim()}`);
|
||||
}
|
||||
});
|
||||
return cookies;
|
||||
};
|
||||
|
@ -435,9 +435,7 @@ describe("process_endpoint", () => {
|
||||
try {
|
||||
await process_endpoint(app_reference, hf_token);
|
||||
} catch (error) {
|
||||
expect(error.message).toEqual(
|
||||
SPACE_METADATA_ERROR_MSG + "Unexpected end of JSON input"
|
||||
);
|
||||
expect(error.message).toEqual(SPACE_METADATA_ERROR_MSG);
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -6,7 +6,8 @@ import {
|
||||
RUNTIME_URL,
|
||||
SLEEPTIME_URL,
|
||||
UPLOAD_URL,
|
||||
BROKEN_CONNECTION_MSG
|
||||
BROKEN_CONNECTION_MSG,
|
||||
LOGIN_URL
|
||||
} from "../constants";
|
||||
import {
|
||||
response_api_info,
|
||||
@ -22,16 +23,24 @@ const root_url = "https://huggingface.co";
|
||||
|
||||
const direct_space_url = "https://hmb-hello-world.hf.space";
|
||||
const private_space_url = "https://hmb-secret-world.hf.space";
|
||||
const private_auth_space_url = "https://hmb-private-auth-space.hf.space";
|
||||
|
||||
const server_error_space_url = "https://hmb-server-error.hf.space";
|
||||
const upload_server_test_space_url = "https://hmb-server-test.hf.space";
|
||||
const server_error_reference = "hmb/server_error";
|
||||
const auth_app_space_url = "https://hmb-auth-space.hf.space";
|
||||
const unauth_app_space_url = "https://hmb-unauth-space.hf.space";
|
||||
const invalid_auth_space_url = "https://hmb-invalid-auth-space.hf.space";
|
||||
|
||||
const server_error_reference = "hmb/server_error";
|
||||
const app_reference = "hmb/hello_world";
|
||||
const broken_app_reference = "hmb/bye_world";
|
||||
const duplicate_app_reference = "gradio/hello_world";
|
||||
const private_app_reference = "hmb/secret_world";
|
||||
const server_test_app_reference = "hmb/server_test";
|
||||
const auth_app_reference = "hmb/auth_space";
|
||||
const unauth_app_reference = "hmb/unauth_space";
|
||||
const invalid_auth_app_reference = "hmb/invalid_auth_space";
|
||||
const private_auth_app_reference = "hmb/private_auth_space";
|
||||
|
||||
export const handlers: RequestHandler[] = [
|
||||
// /host requests
|
||||
@ -58,6 +67,23 @@ export const handlers: RequestHandler[] = [
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.get(
|
||||
`${root_url}/api/spaces/${private_auth_app_reference}/${HOST_URL}`,
|
||||
() => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
subdomain: "hmb-private-auth-space",
|
||||
host: "https://hmb-private-auth-space.hf.space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
),
|
||||
http.get(
|
||||
`${root_url}/api/spaces/${private_app_reference}/${HOST_URL}`,
|
||||
({ request }) => {
|
||||
@ -120,6 +146,68 @@ export const handlers: RequestHandler[] = [
|
||||
);
|
||||
}
|
||||
),
|
||||
http.get(`${root_url}/api/spaces/${auth_app_reference}/${HOST_URL}`, () => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
subdomain: "hmb-auth-space",
|
||||
host: "https://hmb-auth-space.hf.space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}),
|
||||
http.get(
|
||||
`${root_url}/api/spaces/${invalid_auth_app_reference}/${HOST_URL}`,
|
||||
() => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
subdomain: "hmb-invalid-auth-space",
|
||||
host: "https://hmb-invalid-auth-space.hf.space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
),
|
||||
http.get(
|
||||
`${root_url}/api/spaces/${duplicate_app_reference}/${HOST_URL}`,
|
||||
() => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
subdomain: "gradio-hello-world",
|
||||
host: "https://gradio-hello-world.hf.space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
),
|
||||
http.get(`${root_url}/api/spaces/${unauth_app_reference}/${HOST_URL}`, () => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
subdomain: "hmb-unath-space",
|
||||
host: "https://hmb-unauth-space.hf.space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}),
|
||||
// /info requests
|
||||
http.get(`${direct_space_url}/${API_INFO_URL}`, () => {
|
||||
return new HttpResponse(JSON.stringify(response_api_info), {
|
||||
@ -153,6 +241,22 @@ export const handlers: RequestHandler[] = [
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.get(`${auth_app_space_url}/${API_INFO_URL}`, async () => {
|
||||
return new HttpResponse(JSON.stringify(response_api_info), {
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.get(`${private_auth_space_url}/${API_INFO_URL}`, async () => {
|
||||
return new HttpResponse(JSON.stringify(response_api_info), {
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
}),
|
||||
// /config requests
|
||||
http.get(`${direct_space_url}/${CONFIG_URL}`, () => {
|
||||
return new HttpResponse(JSON.stringify(config_response), {
|
||||
@ -190,6 +294,20 @@ export const handlers: RequestHandler[] = [
|
||||
}
|
||||
);
|
||||
}),
|
||||
http.get(`${private_auth_space_url}/${CONFIG_URL}`, () => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
...config_response,
|
||||
root: "https://hmb-private-auth-space.hf.space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}),
|
||||
http.get(`${direct_space_url}/${CONFIG_URL}`, () => {
|
||||
return new HttpResponse(JSON.stringify(config_response), {
|
||||
status: 500,
|
||||
@ -206,6 +324,42 @@ export const handlers: RequestHandler[] = [
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.get(`${invalid_auth_space_url}/${CONFIG_URL}`, () => {
|
||||
return new HttpResponse(JSON.stringify({ detail: "Unauthorized" }), {
|
||||
status: 401,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.get(`${auth_app_space_url}/${CONFIG_URL}`, ({ request }) => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
...config_response,
|
||||
root: "https://hmb-auth-space.hf.space",
|
||||
space_id: "hmb/auth_space"
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}),
|
||||
http.get(`${unauth_app_space_url}/${CONFIG_URL}`, () => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
detail: "Unauthorized"
|
||||
}),
|
||||
{
|
||||
status: 401,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}),
|
||||
// /whoami requests
|
||||
http.get(`${root_url}/api/whoami-v2`, () => {
|
||||
return new HttpResponse(JSON.stringify(whoami_response), {
|
||||
@ -387,6 +541,20 @@ export const handlers: RequestHandler[] = [
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.get(`${root_url}/api/spaces/${unauth_app_reference}`, () => {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
id: unauth_app_reference,
|
||||
runtime: { ...runtime_response }
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
}),
|
||||
// jwt requests
|
||||
http.get(`${root_url}/api/spaces/${app_reference}/jwt`, () => {
|
||||
return new HttpResponse(
|
||||
@ -434,5 +602,84 @@ export const handlers: RequestHandler[] = [
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
}),
|
||||
// login requests
|
||||
http.post(`${auth_app_space_url}/${LOGIN_URL}`, async ({ request }) => {
|
||||
let username;
|
||||
let password;
|
||||
|
||||
await request.formData().then((data) => {
|
||||
username = data.get("username");
|
||||
password = data.get("password");
|
||||
});
|
||||
|
||||
if (username === "admin" && password === "pass1234") {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
success: true
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"Set-Cookie":
|
||||
"access-token-123=abc; HttpOnly; Path=/; SameSite=none; Secure",
|
||||
// @ts-ignore - multiple Set-Cookie headers are returned
|
||||
"Set-Cookie":
|
||||
"access-token-unsecure-123=abc; HttpOnly; Path=/; SameSite=none; Secure"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return new HttpResponse(null, {
|
||||
status: 401,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.post(`${invalid_auth_space_url}/${LOGIN_URL}`, async () => {
|
||||
return new HttpResponse(null, {
|
||||
status: 401,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
}),
|
||||
http.post(`${private_auth_space_url}/${LOGIN_URL}`, async ({ request }) => {
|
||||
let username;
|
||||
let password;
|
||||
|
||||
await request.formData().then((data) => {
|
||||
username = data.get("username");
|
||||
password = data.get("password");
|
||||
});
|
||||
|
||||
if (username === "admin" && password === "pass1234") {
|
||||
return new HttpResponse(
|
||||
JSON.stringify({
|
||||
success: true
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"Set-Cookie":
|
||||
"access-token-123=abc; HttpOnly; Path=/; SameSite=none; Secure",
|
||||
// @ts-ignore - multiple Set-Cookie headers are returned
|
||||
"Set-Cookie":
|
||||
"access-token-unsecure-123=abc; HttpOnly; Path=/; SameSite=none; Secure"
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return new HttpResponse(null, {
|
||||
status: 401,
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
});
|
||||
})
|
||||
];
|
||||
|
@ -67,7 +67,7 @@ describe("Client class", () => {
|
||||
Client.connect("hmb/secret_world", {
|
||||
hf_token: "hf_bad_token"
|
||||
})
|
||||
).rejects.toThrow(CONFIG_ERROR_MSG);
|
||||
).rejects.toThrowError(SPACE_METADATA_ERROR_MSG);
|
||||
});
|
||||
|
||||
test("viewing the api info of a running app", async () => {
|
||||
@ -77,7 +77,7 @@ describe("Client class", () => {
|
||||
|
||||
test("viewing the api info of a non-existent app", async () => {
|
||||
const app = Client.connect(broken_app_reference);
|
||||
await expect(app).rejects.toThrow(CONFIG_ERROR_MSG);
|
||||
await expect(app).rejects.toThrowError();
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -1,10 +1,13 @@
|
||||
import {
|
||||
resolve_root,
|
||||
get_jwt,
|
||||
determine_protocol
|
||||
determine_protocol,
|
||||
parse_and_set_cookies
|
||||
} from "../helpers/init_helpers";
|
||||
import { initialise_server } from "./server";
|
||||
import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest";
|
||||
import { Client } from "../client";
|
||||
import { INVALID_CREDENTIALS_MSG, MISSING_CREDENTIALS_MSG } from "../constants";
|
||||
|
||||
const server = initialise_server();
|
||||
|
||||
@ -92,3 +95,52 @@ describe("determine_protocol", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("parse_and_set_cookies", () => {
|
||||
it("should return an empty array when the cookie header is empty", () => {
|
||||
const cookie_header = "";
|
||||
const result = parse_and_set_cookies(cookie_header);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it("should parse the cookie header and return an array of cookies", () => {
|
||||
const cookie_header = "access-token-123=abc;access-token-unsecured-456=def";
|
||||
const result = parse_and_set_cookies(cookie_header);
|
||||
expect(result).toEqual(["access-token-123=abc"]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("resolve_cookies", () => {
|
||||
it("should set the cookies when correct auth credentials are provided", async () => {
|
||||
const client = await Client.connect("hmb/auth_space", {
|
||||
auth: ["admin", "pass1234"]
|
||||
});
|
||||
|
||||
const api = client.view_api();
|
||||
expect((await api).named_endpoints["/predict"]).toBeDefined();
|
||||
});
|
||||
|
||||
it("should connect to a private and authenticated space", async () => {
|
||||
const client = await Client.connect("hmb/private_auth_space", {
|
||||
hf_token: "hf_123",
|
||||
auth: ["admin", "pass1234"]
|
||||
});
|
||||
|
||||
const api = client.view_api();
|
||||
expect((await api).named_endpoints["/predict"]).toBeDefined();
|
||||
});
|
||||
|
||||
it("should not set the cookies when auth credentials are invalid", async () => {
|
||||
await expect(
|
||||
Client.connect("hmb/invalid_auth_space", {
|
||||
auth: ["admin", "wrong_password"]
|
||||
})
|
||||
).rejects.toThrowError(INVALID_CREDENTIALS_MSG);
|
||||
});
|
||||
|
||||
it("should not set the cookies when auth option is not provided in an auth space", async () => {
|
||||
await expect(Client.connect("hmb/unauth_space")).rejects.toThrowError(
|
||||
MISSING_CREDENTIALS_MSG
|
||||
);
|
||||
});
|
||||
});
|
||||
|
@ -215,6 +215,7 @@ export interface DuplicateOptions extends ClientOptions {
|
||||
export interface ClientOptions {
|
||||
hf_token?: `hf_${string}`;
|
||||
status_callback?: SpaceStatusCallback | null;
|
||||
auth?: [string, string] | null;
|
||||
}
|
||||
|
||||
export interface FileData {
|
||||
|
@ -6,12 +6,17 @@ import {
|
||||
import type { DuplicateOptions } from "../types";
|
||||
import { Client } from "../client";
|
||||
import { SPACE_METADATA_ERROR_MSG } from "../constants";
|
||||
import {
|
||||
get_cookie_header,
|
||||
parse_and_set_cookies
|
||||
} from "../helpers/init_helpers";
|
||||
import { process_endpoint } from "../helpers/api_info";
|
||||
|
||||
export async function duplicate(
|
||||
app_reference: string,
|
||||
options: DuplicateOptions
|
||||
): Promise<Client> {
|
||||
const { hf_token, private: _private, hardware, timeout } = options;
|
||||
const { hf_token, private: _private, hardware, timeout, auth } = options;
|
||||
|
||||
if (hardware && !hardware_types.includes(hardware)) {
|
||||
throw new Error(
|
||||
@ -20,9 +25,29 @@ export async function duplicate(
|
||||
.join(",")}.`
|
||||
);
|
||||
}
|
||||
|
||||
const { http_protocol, host } = await process_endpoint(
|
||||
app_reference,
|
||||
hf_token
|
||||
);
|
||||
|
||||
let cookies: string[] | null = null;
|
||||
|
||||
if (auth) {
|
||||
const cookie_header = await get_cookie_header(
|
||||
http_protocol,
|
||||
host,
|
||||
auth,
|
||||
fetch
|
||||
);
|
||||
|
||||
if (cookie_header) cookies = parse_and_set_cookies(cookie_header);
|
||||
}
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${hf_token}`,
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
...(cookies ? { Cookie: cookies.join("; ") } : {})
|
||||
};
|
||||
|
||||
const user = (
|
||||
|
@ -19,7 +19,8 @@ export async function post_data(
|
||||
var response = await this.fetch(url, {
|
||||
method: "POST",
|
||||
body: JSON.stringify(body),
|
||||
headers: { ...headers, ...additional_headers }
|
||||
headers: { ...headers, ...additional_headers },
|
||||
credentials: "include"
|
||||
});
|
||||
} catch (e) {
|
||||
return [{ error: BROKEN_CONNECTION_MSG }, 500];
|
||||
|
@ -33,7 +33,8 @@ export async function upload_files(
|
||||
response = await this.fetch(upload_url, {
|
||||
method: "POST",
|
||||
body: formData,
|
||||
headers
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
} catch (e) {
|
||||
throw new Error(BROKEN_CONNECTION_MSG + (e as Error).message);
|
||||
|
@ -34,11 +34,13 @@ export async function view_api(this: Client): Promise<any> {
|
||||
serialize: false,
|
||||
config: JSON.stringify(config)
|
||||
}),
|
||||
headers
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
} else {
|
||||
response = await this.fetch(`${config?.root}/${API_INFO_URL}`, {
|
||||
headers
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user