From ab653608045ff9462db7ad9fe63e1c60bf20e773 Mon Sep 17 00:00:00 2001 From: Hannah Date: Wed, 22 May 2024 11:02:56 +0100 Subject: [PATCH] =?UTF-8?q?Allow=20JS=20Client=20to=20work=20with=20authen?= =?UTF-8?q?ticated=20spaces=20=F0=9F=8D=AA=20(#8299)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * get cookie from /login and apply to fetch requests in connect + duplicate funcs * add error msg * add tests * improve error msgs * remove unused var * add changeset * remove comment * add error msg * add private space test --------- Co-authored-by: gradio-pr-bot --- .changeset/lemon-bugs-brake.md | 6 + client/js/src/client.ts | 54 ++++- client/js/src/constants.ts | 4 + client/js/src/helpers/api_info.ts | 6 +- client/js/src/helpers/init_helpers.ts | 96 ++++++++- client/js/src/test/api_info.test.ts | 4 +- client/js/src/test/handlers.ts | 251 +++++++++++++++++++++++- client/js/src/test/init.test.ts | 4 +- client/js/src/test/init_helpers.test.ts | 54 ++++- client/js/src/types.ts | 1 + client/js/src/utils/duplicate.ts | 29 ++- client/js/src/utils/post_data.ts | 3 +- client/js/src/utils/upload_files.ts | 3 +- client/js/src/utils/view_api.ts | 6 +- 14 files changed, 489 insertions(+), 32 deletions(-) create mode 100644 .changeset/lemon-bugs-brake.md diff --git a/.changeset/lemon-bugs-brake.md b/.changeset/lemon-bugs-brake.md new file mode 100644 index 0000000000..98d6710152 --- /dev/null +++ b/.changeset/lemon-bugs-brake.md @@ -0,0 +1,6 @@ +--- +"@gradio/client": minor +"gradio": minor +--- + +feat:Allow JS Client to work with authenticated spaces 🍪 diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 57e97a35da..9efa11b503 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -23,8 +23,10 @@ import { submit } from "./utils/submit"; import { RE_SPACE_NAME, process_endpoint } from "./helpers/api_info"; import { map_names_to_ids, + resolve_cookies, resolve_config, - get_jwt + get_jwt, + parse_and_set_cookies } from "./helpers/init_helpers"; import { check_space_status } from "./helpers/spaces"; import { open_stream } from "./utils/stream"; @@ -47,6 +49,8 @@ export class Client { jwt: string | false = false; last_status: Record = {}; + private cookies: string | null = null; + // streaming stream_status = { open: false }; pending_stream_messages: Record = {}; @@ -56,7 +60,12 @@ export class Client { heartbeat_event: EventSource | null = null; fetch(input: RequestInfo | URL, init?: RequestInit): Promise { - return fetch(input, init); + const headers = new Headers(init?.headers || {}); + if (this && this.cookies) { + headers.append("Cookie", this.cookies); + } + + return fetch(input, { ...init, headers }); } async stream(url: URL): Promise { @@ -108,6 +117,7 @@ export class Client { ) => Promise; open_stream: () => Promise; private resolve_config: (endpoint: string) => Promise; + private resolve_cookies: () => Promise; constructor(app_reference: string, options: ClientOptions = {}) { this.app_reference = app_reference; this.options = options; @@ -120,6 +130,7 @@ export class Client { this.predict = predict.bind(this); this.open_stream = open_stream.bind(this); this.resolve_config = resolve_config.bind(this); + this.resolve_cookies = resolve_cookies.bind(this); this.upload = upload.bind(this); } @@ -135,8 +146,23 @@ export class Client { } try { + if (this.options.auth) { + await this.resolve_cookies(); + } + await this._resolve_config().then(async ({ config }) => { - this.config = config; + if (config) { + this.config = config; + if (this.config && this.config.connect_heartbeat) { + if (this.config.space_id && this.options.hf_token) { + this.jwt = await get_jwt( + this.config.space_id, + this.options.hf_token, + this.cookies + ); + } + } + } if (config.space_id && this.options.hf_token) { this.jwt = await get_jwt(config.space_id, this.options.hf_token); @@ -156,8 +182,8 @@ export class Client { this.heartbeat_event = await this.stream(heartbeat_url); // Just connect to the endpoint without parsing the response. Ref: https://github.com/gradio-app/gradio/pull/7974#discussion_r1557717540 } }); - } catch (e) { - throw Error(CONFIG_ERROR_MSG + (e as Error).message); + } catch (e: any) { + throw Error(e); } this.api_info = await this.view_api(); @@ -201,9 +227,8 @@ export class Client { } return this.config_success(config); - } catch (e) { - console.error(e); - if (space_id) { + } catch (e: any) { + if (space_id && status_callback) { check_space_status( space_id, RE_SPACE_NAME.test(space_id) ? "space_name" : "subdomain", @@ -217,6 +242,7 @@ export class Client { load_status: "error", detail: "NOT_FOUND" }); + throw Error(e); } } } @@ -246,6 +272,9 @@ export class Client { } async handle_space_success(status: SpaceStatus): Promise { + if (!this) { + throw new Error(CONFIG_ERROR_MSG); + } const { status_callback } = this.options; if (status_callback) status_callback(status); if (status.status === "running") { @@ -259,7 +288,6 @@ export class Client { return _config as Config; } catch (e) { - console.error(e); if (status_callback) { status_callback({ status: "error", @@ -268,6 +296,7 @@ export class Client { detail: "NOT_FOUND" }); } + throw e; } } } @@ -333,7 +362,8 @@ export class Client { const response = await this.fetch(`${root_url}/component_server/`, { method: "POST", body: body, - headers + headers, + credentials: "include" }); if (!response.ok) { @@ -349,6 +379,10 @@ export class Client { } } + public set_cookies(raw_cookies: string): void { + this.cookies = parse_and_set_cookies(raw_cookies).join("; "); + } + private prepare_return_obj(): client_return { return { config: this.config, diff --git a/client/js/src/constants.ts b/client/js/src/constants.ts index 79ab3eb118..65451ac7b4 100644 --- a/client/js/src/constants.ts +++ b/client/js/src/constants.ts @@ -25,3 +25,7 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. "; export const SPACE_STATUS_ERROR_MSG = "Could not get space status. "; export const API_INFO_ERROR_MSG = "Could not get API info. "; export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. "; +export const UNAUTHORIZED_MSG = "Not authorized to access this space. "; +export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. "; +export const MISSING_CREDENTIALS_MSG = + "Login credentials are required to access this space."; diff --git a/client/js/src/helpers/api_info.ts b/client/js/src/helpers/api_info.ts index 9ac0add27c..ee514f4789 100644 --- a/client/js/src/helpers/api_info.ts +++ b/client/js/src/helpers/api_info.ts @@ -1,5 +1,5 @@ import type { Status } from "../types"; -import { QUEUE_FULL_MSG } from "../constants"; +import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants"; import type { ApiData, ApiInfo, Config, JsApiData } from "../types"; import { determine_protocol } from "./init_helpers"; @@ -36,9 +36,7 @@ export async function process_endpoint( ...determine_protocol(_host) }; } catch (e) { - throw new Error( - "Space metadata could not be loaded. " + (e as Error).message - ); + throw new Error(SPACE_METADATA_ERROR_MSG); } } diff --git a/client/js/src/helpers/init_helpers.ts b/client/js/src/helpers/init_helpers.ts index e024528338..2fcb9b03a8 100644 --- a/client/js/src/helpers/init_helpers.ts +++ b/client/js/src/helpers/init_helpers.ts @@ -1,6 +1,15 @@ import type { Config } from "../types"; -import { CONFIG_ERROR_MSG, CONFIG_URL } from "../constants"; +import { + CONFIG_ERROR_MSG, + CONFIG_URL, + INVALID_CREDENTIALS_MSG, + LOGIN_URL, + MISSING_CREDENTIALS_MSG, + SPACE_METADATA_ERROR_MSG, + UNAUTHORIZED_MSG +} from "../constants"; import { Client } from ".."; +import { process_endpoint } from "./api_info"; /** * This function is used to resolve the URL for making requests when the app has a root path. @@ -25,12 +34,14 @@ export function resolve_root( export async function get_jwt( space: string, - token: `hf_${string}` + token: `hf_${string}`, + cookies?: string | null ): Promise { try { const r = await fetch(`https://huggingface.co/api/spaces/${space}/jwt`, { headers: { - Authorization: `Bearer ${token}` + Authorization: `Bearer ${token}`, + ...(cookies ? { Cookie: cookies } : {}) } }); @@ -76,14 +87,22 @@ export async function resolve_config( return { ...config, path } as Config; } else if (endpoint) { const response = await this.fetch(`${endpoint}/${CONFIG_URL}`, { - headers + headers, + credentials: "include" }); + if (response?.status === 401 && !this.options.auth) { + throw new Error(MISSING_CREDENTIALS_MSG); + } else if (response?.status === 401 && this.options.auth) { + throw new Error(INVALID_CREDENTIALS_MSG); + } if (response?.status === 200) { let config = await response.json(); config.path = config.path ?? ""; config.root = endpoint; return config; + } else if (response?.status === 401) { + throw new Error(UNAUTHORIZED_MSG); } throw new Error(CONFIG_ERROR_MSG); } @@ -91,6 +110,63 @@ export async function resolve_config( throw new Error(CONFIG_ERROR_MSG); } +export async function resolve_cookies(this: Client): Promise { + const { http_protocol, host } = await process_endpoint( + this.app_reference, + this.options.hf_token + ); + + try { + if (this.options.auth) { + const cookie_header = await get_cookie_header( + http_protocol, + host, + this.options.auth, + this.fetch, + this.options.hf_token + ); + + if (cookie_header) this.set_cookies(cookie_header); + } + } catch (e: unknown) { + throw Error((e as Error).message); + } +} + +// separating this from client-bound resolve_cookies so that it can be used in duplicate +export async function get_cookie_header( + http_protocol: string, + host: string, + auth: [string, string], + _fetch: typeof fetch, + hf_token?: `hf_${string}` +): Promise { + const formData = new FormData(); + formData.append("username", auth?.[0]); + formData.append("password", auth?.[1]); + + let headers: { Authorization?: string } = {}; + + if (hf_token) { + headers.Authorization = `Bearer ${hf_token}`; + } + + const res = await _fetch(`${http_protocol}//${host}/${LOGIN_URL}`, { + headers, + method: "POST", + body: formData, + credentials: "include" + }); + + if (res.status === 200) { + return res.headers.get("set-cookie"); + } else if (res.status === 401) { + throw new Error(INVALID_CREDENTIALS_MSG); + } else { + throw new Error(SPACE_METADATA_ERROR_MSG); + } +} + export function determine_protocol(endpoint: string): { ws_protocol: "ws" | "wss"; http_protocol: "http:" | "https:"; @@ -128,3 +204,15 @@ export function determine_protocol(endpoint: string): { host: endpoint }; } + +export const parse_and_set_cookies = (cookie_header: string): string[] => { + let cookies: string[] = []; + const parts = cookie_header.split(/,(?=\s*[^\s=;]+=[^\s=;]+)/); + parts.forEach((cookie) => { + const [cookie_name, cookie_value] = cookie.split(";")[0].split("="); + if (cookie_name && cookie_value) { + cookies.push(`${cookie_name.trim()}=${cookie_value.trim()}`); + } + }); + return cookies; +}; diff --git a/client/js/src/test/api_info.test.ts b/client/js/src/test/api_info.test.ts index 2cfca02a67..13fb49ad92 100644 --- a/client/js/src/test/api_info.test.ts +++ b/client/js/src/test/api_info.test.ts @@ -435,9 +435,7 @@ describe("process_endpoint", () => { try { await process_endpoint(app_reference, hf_token); } catch (error) { - expect(error.message).toEqual( - SPACE_METADATA_ERROR_MSG + "Unexpected end of JSON input" - ); + expect(error.message).toEqual(SPACE_METADATA_ERROR_MSG); } }); diff --git a/client/js/src/test/handlers.ts b/client/js/src/test/handlers.ts index e3bb6d18d5..222f474951 100644 --- a/client/js/src/test/handlers.ts +++ b/client/js/src/test/handlers.ts @@ -6,7 +6,8 @@ import { RUNTIME_URL, SLEEPTIME_URL, UPLOAD_URL, - BROKEN_CONNECTION_MSG + BROKEN_CONNECTION_MSG, + LOGIN_URL } from "../constants"; import { response_api_info, @@ -22,16 +23,24 @@ const root_url = "https://huggingface.co"; const direct_space_url = "https://hmb-hello-world.hf.space"; const private_space_url = "https://hmb-secret-world.hf.space"; +const private_auth_space_url = "https://hmb-private-auth-space.hf.space"; const server_error_space_url = "https://hmb-server-error.hf.space"; const upload_server_test_space_url = "https://hmb-server-test.hf.space"; -const server_error_reference = "hmb/server_error"; +const auth_app_space_url = "https://hmb-auth-space.hf.space"; +const unauth_app_space_url = "https://hmb-unauth-space.hf.space"; +const invalid_auth_space_url = "https://hmb-invalid-auth-space.hf.space"; +const server_error_reference = "hmb/server_error"; const app_reference = "hmb/hello_world"; const broken_app_reference = "hmb/bye_world"; const duplicate_app_reference = "gradio/hello_world"; const private_app_reference = "hmb/secret_world"; const server_test_app_reference = "hmb/server_test"; +const auth_app_reference = "hmb/auth_space"; +const unauth_app_reference = "hmb/unauth_space"; +const invalid_auth_app_reference = "hmb/invalid_auth_space"; +const private_auth_app_reference = "hmb/private_auth_space"; export const handlers: RequestHandler[] = [ // /host requests @@ -58,6 +67,23 @@ export const handlers: RequestHandler[] = [ } }); }), + http.get( + `${root_url}/api/spaces/${private_auth_app_reference}/${HOST_URL}`, + () => { + return new HttpResponse( + JSON.stringify({ + subdomain: "hmb-private-auth-space", + host: "https://hmb-private-auth-space.hf.space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + } + ), http.get( `${root_url}/api/spaces/${private_app_reference}/${HOST_URL}`, ({ request }) => { @@ -120,6 +146,68 @@ export const handlers: RequestHandler[] = [ ); } ), + http.get(`${root_url}/api/spaces/${auth_app_reference}/${HOST_URL}`, () => { + return new HttpResponse( + JSON.stringify({ + subdomain: "hmb-auth-space", + host: "https://hmb-auth-space.hf.space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + }), + http.get( + `${root_url}/api/spaces/${invalid_auth_app_reference}/${HOST_URL}`, + () => { + return new HttpResponse( + JSON.stringify({ + subdomain: "hmb-invalid-auth-space", + host: "https://hmb-invalid-auth-space.hf.space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + } + ), + http.get( + `${root_url}/api/spaces/${duplicate_app_reference}/${HOST_URL}`, + () => { + return new HttpResponse( + JSON.stringify({ + subdomain: "gradio-hello-world", + host: "https://gradio-hello-world.hf.space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + } + ), + http.get(`${root_url}/api/spaces/${unauth_app_reference}/${HOST_URL}`, () => { + return new HttpResponse( + JSON.stringify({ + subdomain: "hmb-unath-space", + host: "https://hmb-unauth-space.hf.space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + }), // /info requests http.get(`${direct_space_url}/${API_INFO_URL}`, () => { return new HttpResponse(JSON.stringify(response_api_info), { @@ -153,6 +241,22 @@ export const handlers: RequestHandler[] = [ } }); }), + http.get(`${auth_app_space_url}/${API_INFO_URL}`, async () => { + return new HttpResponse(JSON.stringify(response_api_info), { + status: 200, + headers: { + "Content-Type": "application/json" + } + }); + }), + http.get(`${private_auth_space_url}/${API_INFO_URL}`, async () => { + return new HttpResponse(JSON.stringify(response_api_info), { + status: 200, + headers: { + "Content-Type": "application/json" + } + }); + }), // /config requests http.get(`${direct_space_url}/${CONFIG_URL}`, () => { return new HttpResponse(JSON.stringify(config_response), { @@ -190,6 +294,20 @@ export const handlers: RequestHandler[] = [ } ); }), + http.get(`${private_auth_space_url}/${CONFIG_URL}`, () => { + return new HttpResponse( + JSON.stringify({ + ...config_response, + root: "https://hmb-private-auth-space.hf.space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + }), http.get(`${direct_space_url}/${CONFIG_URL}`, () => { return new HttpResponse(JSON.stringify(config_response), { status: 500, @@ -206,6 +324,42 @@ export const handlers: RequestHandler[] = [ } }); }), + http.get(`${invalid_auth_space_url}/${CONFIG_URL}`, () => { + return new HttpResponse(JSON.stringify({ detail: "Unauthorized" }), { + status: 401, + headers: { + "Content-Type": "application/json" + } + }); + }), + http.get(`${auth_app_space_url}/${CONFIG_URL}`, ({ request }) => { + return new HttpResponse( + JSON.stringify({ + ...config_response, + root: "https://hmb-auth-space.hf.space", + space_id: "hmb/auth_space" + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + }), + http.get(`${unauth_app_space_url}/${CONFIG_URL}`, () => { + return new HttpResponse( + JSON.stringify({ + detail: "Unauthorized" + }), + { + status: 401, + headers: { + "Content-Type": "application/json" + } + } + ); + }), // /whoami requests http.get(`${root_url}/api/whoami-v2`, () => { return new HttpResponse(JSON.stringify(whoami_response), { @@ -387,6 +541,20 @@ export const handlers: RequestHandler[] = [ } }); }), + http.get(`${root_url}/api/spaces/${unauth_app_reference}`, () => { + return new HttpResponse( + JSON.stringify({ + id: unauth_app_reference, + runtime: { ...runtime_response } + }), + { + status: 200, + headers: { + "Content-Type": "application/json" + } + } + ); + }), // jwt requests http.get(`${root_url}/api/spaces/${app_reference}/jwt`, () => { return new HttpResponse( @@ -434,5 +602,84 @@ export const handlers: RequestHandler[] = [ "Content-Type": "application/json" } }); + }), + // login requests + http.post(`${auth_app_space_url}/${LOGIN_URL}`, async ({ request }) => { + let username; + let password; + + await request.formData().then((data) => { + username = data.get("username"); + password = data.get("password"); + }); + + if (username === "admin" && password === "pass1234") { + return new HttpResponse( + JSON.stringify({ + success: true + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "Set-Cookie": + "access-token-123=abc; HttpOnly; Path=/; SameSite=none; Secure", + // @ts-ignore - multiple Set-Cookie headers are returned + "Set-Cookie": + "access-token-unsecure-123=abc; HttpOnly; Path=/; SameSite=none; Secure" + } + } + ); + } + + return new HttpResponse(null, { + status: 401, + headers: { + "Content-Type": "application/json" + } + }); + }), + http.post(`${invalid_auth_space_url}/${LOGIN_URL}`, async () => { + return new HttpResponse(null, { + status: 401, + headers: { + "Content-Type": "application/json" + } + }); + }), + http.post(`${private_auth_space_url}/${LOGIN_URL}`, async ({ request }) => { + let username; + let password; + + await request.formData().then((data) => { + username = data.get("username"); + password = data.get("password"); + }); + + if (username === "admin" && password === "pass1234") { + return new HttpResponse( + JSON.stringify({ + success: true + }), + { + status: 200, + headers: { + "Content-Type": "application/json", + "Set-Cookie": + "access-token-123=abc; HttpOnly; Path=/; SameSite=none; Secure", + // @ts-ignore - multiple Set-Cookie headers are returned + "Set-Cookie": + "access-token-unsecure-123=abc; HttpOnly; Path=/; SameSite=none; Secure" + } + } + ); + } + + return new HttpResponse(null, { + status: 401, + headers: { + "Content-Type": "application/json" + } + }); }) ]; diff --git a/client/js/src/test/init.test.ts b/client/js/src/test/init.test.ts index 67de6bb711..b240700954 100644 --- a/client/js/src/test/init.test.ts +++ b/client/js/src/test/init.test.ts @@ -67,7 +67,7 @@ describe("Client class", () => { Client.connect("hmb/secret_world", { hf_token: "hf_bad_token" }) - ).rejects.toThrow(CONFIG_ERROR_MSG); + ).rejects.toThrowError(SPACE_METADATA_ERROR_MSG); }); test("viewing the api info of a running app", async () => { @@ -77,7 +77,7 @@ describe("Client class", () => { test("viewing the api info of a non-existent app", async () => { const app = Client.connect(broken_app_reference); - await expect(app).rejects.toThrow(CONFIG_ERROR_MSG); + await expect(app).rejects.toThrowError(); }); }); diff --git a/client/js/src/test/init_helpers.test.ts b/client/js/src/test/init_helpers.test.ts index 71700ddca5..2462dd4ea2 100644 --- a/client/js/src/test/init_helpers.test.ts +++ b/client/js/src/test/init_helpers.test.ts @@ -1,10 +1,13 @@ import { resolve_root, get_jwt, - determine_protocol + determine_protocol, + parse_and_set_cookies } from "../helpers/init_helpers"; import { initialise_server } from "./server"; import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest"; +import { Client } from "../client"; +import { INVALID_CREDENTIALS_MSG, MISSING_CREDENTIALS_MSG } from "../constants"; const server = initialise_server(); @@ -92,3 +95,52 @@ describe("determine_protocol", () => { }); }); }); + +describe("parse_and_set_cookies", () => { + it("should return an empty array when the cookie header is empty", () => { + const cookie_header = ""; + const result = parse_and_set_cookies(cookie_header); + expect(result).toEqual([]); + }); + + it("should parse the cookie header and return an array of cookies", () => { + const cookie_header = "access-token-123=abc;access-token-unsecured-456=def"; + const result = parse_and_set_cookies(cookie_header); + expect(result).toEqual(["access-token-123=abc"]); + }); +}); + +describe("resolve_cookies", () => { + it("should set the cookies when correct auth credentials are provided", async () => { + const client = await Client.connect("hmb/auth_space", { + auth: ["admin", "pass1234"] + }); + + const api = client.view_api(); + expect((await api).named_endpoints["/predict"]).toBeDefined(); + }); + + it("should connect to a private and authenticated space", async () => { + const client = await Client.connect("hmb/private_auth_space", { + hf_token: "hf_123", + auth: ["admin", "pass1234"] + }); + + const api = client.view_api(); + expect((await api).named_endpoints["/predict"]).toBeDefined(); + }); + + it("should not set the cookies when auth credentials are invalid", async () => { + await expect( + Client.connect("hmb/invalid_auth_space", { + auth: ["admin", "wrong_password"] + }) + ).rejects.toThrowError(INVALID_CREDENTIALS_MSG); + }); + + it("should not set the cookies when auth option is not provided in an auth space", async () => { + await expect(Client.connect("hmb/unauth_space")).rejects.toThrowError( + MISSING_CREDENTIALS_MSG + ); + }); +}); diff --git a/client/js/src/types.ts b/client/js/src/types.ts index 838151b70a..3cccf5e2b3 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -215,6 +215,7 @@ export interface DuplicateOptions extends ClientOptions { export interface ClientOptions { hf_token?: `hf_${string}`; status_callback?: SpaceStatusCallback | null; + auth?: [string, string] | null; } export interface FileData { diff --git a/client/js/src/utils/duplicate.ts b/client/js/src/utils/duplicate.ts index c334fcc3ff..27cecd3ffb 100644 --- a/client/js/src/utils/duplicate.ts +++ b/client/js/src/utils/duplicate.ts @@ -6,12 +6,17 @@ import { import type { DuplicateOptions } from "../types"; import { Client } from "../client"; import { SPACE_METADATA_ERROR_MSG } from "../constants"; +import { + get_cookie_header, + parse_and_set_cookies +} from "../helpers/init_helpers"; +import { process_endpoint } from "../helpers/api_info"; export async function duplicate( app_reference: string, options: DuplicateOptions ): Promise { - const { hf_token, private: _private, hardware, timeout } = options; + const { hf_token, private: _private, hardware, timeout, auth } = options; if (hardware && !hardware_types.includes(hardware)) { throw new Error( @@ -20,9 +25,29 @@ export async function duplicate( .join(",")}.` ); } + + const { http_protocol, host } = await process_endpoint( + app_reference, + hf_token + ); + + let cookies: string[] | null = null; + + if (auth) { + const cookie_header = await get_cookie_header( + http_protocol, + host, + auth, + fetch + ); + + if (cookie_header) cookies = parse_and_set_cookies(cookie_header); + } + const headers = { Authorization: `Bearer ${hf_token}`, - "Content-Type": "application/json" + "Content-Type": "application/json", + ...(cookies ? { Cookie: cookies.join("; ") } : {}) }; const user = ( diff --git a/client/js/src/utils/post_data.ts b/client/js/src/utils/post_data.ts index b5c7d8258a..53264a3c86 100644 --- a/client/js/src/utils/post_data.ts +++ b/client/js/src/utils/post_data.ts @@ -19,7 +19,8 @@ export async function post_data( var response = await this.fetch(url, { method: "POST", body: JSON.stringify(body), - headers: { ...headers, ...additional_headers } + headers: { ...headers, ...additional_headers }, + credentials: "include" }); } catch (e) { return [{ error: BROKEN_CONNECTION_MSG }, 500]; diff --git a/client/js/src/utils/upload_files.ts b/client/js/src/utils/upload_files.ts index 5f3150debe..dd702461a8 100644 --- a/client/js/src/utils/upload_files.ts +++ b/client/js/src/utils/upload_files.ts @@ -33,7 +33,8 @@ export async function upload_files( response = await this.fetch(upload_url, { method: "POST", body: formData, - headers + headers, + credentials: "include" }); } catch (e) { throw new Error(BROKEN_CONNECTION_MSG + (e as Error).message); diff --git a/client/js/src/utils/view_api.ts b/client/js/src/utils/view_api.ts index 9093732f60..05e0c60412 100644 --- a/client/js/src/utils/view_api.ts +++ b/client/js/src/utils/view_api.ts @@ -34,11 +34,13 @@ export async function view_api(this: Client): Promise { serialize: false, config: JSON.stringify(config) }), - headers + headers, + credentials: "include" }); } else { response = await this.fetch(`${config?.root}/${API_INFO_URL}`, { - headers + headers, + credentials: "include" }); }