Improve URL handling in JS Client (#8258)

* amend endpoint logic

* add changeset

* improve url joining for local URLs

* handle relative paths

* use join_urls in /info  to ensure correct endpoints

* add relative url logic and throw error for node

* tweaks

* remove relative paths support (wont work)

* tweak

* tweak func to throw error and amend tests

* tweak

* accomodate . in space names

* replace error with const msg

* tweak tests with error var

* revert map() to reduce() due to misinterpreted base URL

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Hannah 2024-05-22 18:47:29 +01:00 committed by GitHub
parent 881f0a9e2c
commit 1f8e5c44e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 106 additions and 11 deletions

View File

@ -0,0 +1,6 @@
---
"@gradio/client": patch
"gradio": patch
---
fix:Improve URL handling in JS Client

View File

@ -25,6 +25,7 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. ";
export const SPACE_STATUS_ERROR_MSG = "Could not get space status. ";
export const API_INFO_ERROR_MSG = "Could not get API info. ";
export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. ";
export const INVALID_URL_MSG = "Invalid URL. A full URL path is required.";
export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
export const MISSING_CREDENTIALS_MSG =

View File

@ -1,9 +1,14 @@
import type { Status } from "../types";
import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants";
import {
HOST_URL,
INVALID_URL_MSG,
QUEUE_FULL_MSG,
SPACE_METADATA_ERROR_MSG
} from "../constants";
import type { ApiData, ApiInfo, Config, JsApiData } from "../types";
import { determine_protocol } from "./init_helpers";
export const RE_SPACE_NAME = /^[^\/]*\/[^\/]*$/;
export const RE_SPACE_NAME = /^[a-zA-Z0-9_\-\.]+\/[a-zA-Z0-9_\-\.]+$/;
export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/;
export async function process_endpoint(
@ -20,12 +25,13 @@ export async function process_endpoint(
headers.Authorization = `Bearer ${hf_token}`;
}
const _app_reference = app_reference.trim();
const _app_reference = app_reference.trim().replace(/\/$/, "");
if (RE_SPACE_NAME.test(_app_reference)) {
// app_reference is a HF space name
try {
const res = await fetch(
`https://huggingface.co/api/spaces/${_app_reference}/host`,
`https://huggingface.co/api/spaces/${_app_reference}/${HOST_URL}`,
{ headers }
);
@ -41,6 +47,7 @@ export async function process_endpoint(
}
if (RE_SPACE_DOMAIN.test(_app_reference)) {
// app_reference is a direct HF space domain
const { ws_protocol, http_protocol, host } =
determine_protocol(_app_reference);
@ -58,6 +65,18 @@ export async function process_endpoint(
};
}
export const join_urls = (...urls: string[]): string => {
try {
return urls.reduce((base_url: string, part: string) => {
base_url = base_url.replace(/\/+$/, "");
part = part.replace(/^\/+/, "");
return new URL(part, base_url + "/").toString();
});
} catch (e) {
throw new Error(INVALID_URL_MSG);
}
};
export function transform_api_info(
api_info: ApiInfo<ApiData>,
config: Config,

View File

@ -9,7 +9,7 @@ import {
UNAUTHORIZED_MSG
} from "../constants";
import { Client } from "..";
import { process_endpoint } from "./api_info";
import { join_urls, process_endpoint } from "./api_info";
/**
* This function is used to resolve the URL for making requests when the app has a root path.
@ -86,7 +86,8 @@ export async function resolve_config(
config.root = config_root;
return { ...config, path } as Config;
} else if (endpoint) {
const response = await this.fetch(`${endpoint}/${CONFIG_URL}`, {
const config_url = join_urls(endpoint, CONFIG_URL);
const response = await this.fetch(config_url, {
headers,
credentials: "include"
});
@ -173,7 +174,7 @@ export function determine_protocol(endpoint: string): {
host: string;
} {
if (endpoint.startsWith("http")) {
const { protocol, host } = new URL(endpoint);
const { protocol, host, pathname } = new URL(endpoint);
if (host.endsWith("hf.space")) {
return {
@ -185,7 +186,7 @@ export function determine_protocol(endpoint: string): {
return {
ws_protocol: protocol === "https:" ? "wss" : "ws",
http_protocol: protocol as "http:" | "https:",
host
host: host + (pathname !== "/" ? pathname : "")
};
} else if (endpoint.startsWith("file:")) {
// This case is only expected to be used for the Wasm mode (Gradio-lite),

View File

@ -1,16 +1,22 @@
import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants";
import {
INVALID_URL_MSG,
QUEUE_FULL_MSG,
SPACE_METADATA_ERROR_MSG
} from "../constants";
import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest";
import {
handle_message,
get_description,
get_type,
process_endpoint,
join_urls,
map_data_to_params
} from "../helpers/api_info";
import { initialise_server } from "./server";
import { transformed_api_info } from "./test_data";
const server = initialise_server();
const IS_NODE = process.env.TEST_MODE === "node";
beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
@ -453,6 +459,67 @@ describe("process_endpoint", () => {
const result = await process_endpoint("hmb/hello_world");
expect(result).toEqual(expected);
});
it("processes local server URLs correctly", async () => {
const local_url = "http://localhost:7860/gradio";
const response_local_url = await process_endpoint(local_url);
expect(response_local_url.space_id).toBe(false);
expect(response_local_url.host).toBe("localhost:7860/gradio");
const local_url_2 = "http://localhost:7860/gradio/";
const response_local_url_2 = await process_endpoint(local_url_2);
expect(response_local_url_2.space_id).toBe(false);
expect(response_local_url_2.host).toBe("localhost:7860/gradio");
});
it("handles hugging face space references", async () => {
const space_id = "hmb/hello_world";
const response = await process_endpoint(space_id);
expect(response.space_id).toBe(space_id);
expect(response.host).toContain("hf.space");
});
it("handles hugging face domain URLs", async () => {
const app_reference = "https://hmb-hello-world.hf.space/";
const response = await process_endpoint(app_reference);
expect(response.space_id).toBe("hmb-hello-world");
expect(response.host).toBe("hmb-hello-world.hf.space");
});
});
describe("join_urls", () => {
it("joins URLs correctly", () => {
expect(join_urls("http://localhost:7860", "/gradio")).toBe(
"http://localhost:7860/gradio"
);
expect(join_urls("http://localhost:7860/", "/gradio")).toBe(
"http://localhost:7860/gradio"
);
expect(join_urls("http://localhost:7860", "app/", "/gradio")).toBe(
"http://localhost:7860/app/gradio"
);
expect(join_urls("http://localhost:7860/", "/app/", "/gradio/")).toBe(
"http://localhost:7860/app/gradio/"
);
expect(join_urls("http://127.0.0.1:8000/app", "/config")).toBe(
"http://127.0.0.1:8000/app/config"
);
expect(join_urls("http://127.0.0.1:8000/app/gradio", "/config")).toBe(
"http://127.0.0.1:8000/app/gradio/config"
);
});
it("throws an error when the URLs are not valid", () => {
expect(() => join_urls("localhost:7860", "/gradio")).toThrowError(
INVALID_URL_MSG
);
expect(() => join_urls("localhost:7860", "/gradio", "app")).toThrowError(
INVALID_URL_MSG
);
});
});
describe("map_data_params", () => {

View File

@ -3,7 +3,7 @@ import semiver from "semiver";
import { API_INFO_URL, BROKEN_CONNECTION_MSG } from "../constants";
import { Client } from "../client";
import { SPACE_FETCHER_URL } from "../constants";
import { transform_api_info } from "../helpers/api_info";
import { join_urls, transform_api_info } from "../helpers/api_info";
export async function view_api(this: Client): Promise<any> {
if (this.api_info) return this.api_info;
@ -38,7 +38,8 @@ export async function view_api(this: Client): Promise<any> {
credentials: "include"
});
} else {
response = await this.fetch(`${config?.root}/${API_INFO_URL}`, {
const url = join_urls(config.root, API_INFO_URL);
response = await this.fetch(url, {
headers,
credentials: "include"
});