mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Improve URL handling in JS Client (#8258)
* amend endpoint logic * add changeset * improve url joining for local URLs * handle relative paths * use join_urls in /info to ensure correct endpoints * add relative url logic and throw error for node * tweaks * remove relative paths support (wont work) * tweak * tweak func to throw error and amend tests * tweak * accomodate . in space names * replace error with const msg * tweak tests with error var * revert map() to reduce() due to misinterpreted base URL --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
881f0a9e2c
commit
1f8e5c44e0
6
.changeset/sad-feet-brush.md
Normal file
6
.changeset/sad-feet-brush.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"@gradio/client": patch
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Improve URL handling in JS Client
|
@ -25,6 +25,7 @@ export const CONFIG_ERROR_MSG = "Could not resolve app config. ";
|
||||
export const SPACE_STATUS_ERROR_MSG = "Could not get space status. ";
|
||||
export const API_INFO_ERROR_MSG = "Could not get API info. ";
|
||||
export const SPACE_METADATA_ERROR_MSG = "Space metadata could not be loaded. ";
|
||||
export const INVALID_URL_MSG = "Invalid URL. A full URL path is required.";
|
||||
export const UNAUTHORIZED_MSG = "Not authorized to access this space. ";
|
||||
export const INVALID_CREDENTIALS_MSG = "Invalid credentials. Could not login. ";
|
||||
export const MISSING_CREDENTIALS_MSG =
|
||||
|
@ -1,9 +1,14 @@
|
||||
import type { Status } from "../types";
|
||||
import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants";
|
||||
import {
|
||||
HOST_URL,
|
||||
INVALID_URL_MSG,
|
||||
QUEUE_FULL_MSG,
|
||||
SPACE_METADATA_ERROR_MSG
|
||||
} from "../constants";
|
||||
import type { ApiData, ApiInfo, Config, JsApiData } from "../types";
|
||||
import { determine_protocol } from "./init_helpers";
|
||||
|
||||
export const RE_SPACE_NAME = /^[^\/]*\/[^\/]*$/;
|
||||
export const RE_SPACE_NAME = /^[a-zA-Z0-9_\-\.]+\/[a-zA-Z0-9_\-\.]+$/;
|
||||
export const RE_SPACE_DOMAIN = /.*hf\.space\/{0,1}$/;
|
||||
|
||||
export async function process_endpoint(
|
||||
@ -20,12 +25,13 @@ export async function process_endpoint(
|
||||
headers.Authorization = `Bearer ${hf_token}`;
|
||||
}
|
||||
|
||||
const _app_reference = app_reference.trim();
|
||||
const _app_reference = app_reference.trim().replace(/\/$/, "");
|
||||
|
||||
if (RE_SPACE_NAME.test(_app_reference)) {
|
||||
// app_reference is a HF space name
|
||||
try {
|
||||
const res = await fetch(
|
||||
`https://huggingface.co/api/spaces/${_app_reference}/host`,
|
||||
`https://huggingface.co/api/spaces/${_app_reference}/${HOST_URL}`,
|
||||
{ headers }
|
||||
);
|
||||
|
||||
@ -41,6 +47,7 @@ export async function process_endpoint(
|
||||
}
|
||||
|
||||
if (RE_SPACE_DOMAIN.test(_app_reference)) {
|
||||
// app_reference is a direct HF space domain
|
||||
const { ws_protocol, http_protocol, host } =
|
||||
determine_protocol(_app_reference);
|
||||
|
||||
@ -58,6 +65,18 @@ export async function process_endpoint(
|
||||
};
|
||||
}
|
||||
|
||||
export const join_urls = (...urls: string[]): string => {
|
||||
try {
|
||||
return urls.reduce((base_url: string, part: string) => {
|
||||
base_url = base_url.replace(/\/+$/, "");
|
||||
part = part.replace(/^\/+/, "");
|
||||
return new URL(part, base_url + "/").toString();
|
||||
});
|
||||
} catch (e) {
|
||||
throw new Error(INVALID_URL_MSG);
|
||||
}
|
||||
};
|
||||
|
||||
export function transform_api_info(
|
||||
api_info: ApiInfo<ApiData>,
|
||||
config: Config,
|
||||
|
@ -9,7 +9,7 @@ import {
|
||||
UNAUTHORIZED_MSG
|
||||
} from "../constants";
|
||||
import { Client } from "..";
|
||||
import { process_endpoint } from "./api_info";
|
||||
import { join_urls, process_endpoint } from "./api_info";
|
||||
|
||||
/**
|
||||
* This function is used to resolve the URL for making requests when the app has a root path.
|
||||
@ -86,7 +86,8 @@ export async function resolve_config(
|
||||
config.root = config_root;
|
||||
return { ...config, path } as Config;
|
||||
} else if (endpoint) {
|
||||
const response = await this.fetch(`${endpoint}/${CONFIG_URL}`, {
|
||||
const config_url = join_urls(endpoint, CONFIG_URL);
|
||||
const response = await this.fetch(config_url, {
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
@ -173,7 +174,7 @@ export function determine_protocol(endpoint: string): {
|
||||
host: string;
|
||||
} {
|
||||
if (endpoint.startsWith("http")) {
|
||||
const { protocol, host } = new URL(endpoint);
|
||||
const { protocol, host, pathname } = new URL(endpoint);
|
||||
|
||||
if (host.endsWith("hf.space")) {
|
||||
return {
|
||||
@ -185,7 +186,7 @@ export function determine_protocol(endpoint: string): {
|
||||
return {
|
||||
ws_protocol: protocol === "https:" ? "wss" : "ws",
|
||||
http_protocol: protocol as "http:" | "https:",
|
||||
host
|
||||
host: host + (pathname !== "/" ? pathname : "")
|
||||
};
|
||||
} else if (endpoint.startsWith("file:")) {
|
||||
// This case is only expected to be used for the Wasm mode (Gradio-lite),
|
||||
|
@ -1,16 +1,22 @@
|
||||
import { QUEUE_FULL_MSG, SPACE_METADATA_ERROR_MSG } from "../constants";
|
||||
import {
|
||||
INVALID_URL_MSG,
|
||||
QUEUE_FULL_MSG,
|
||||
SPACE_METADATA_ERROR_MSG
|
||||
} from "../constants";
|
||||
import { beforeAll, afterEach, afterAll, it, expect, describe } from "vitest";
|
||||
import {
|
||||
handle_message,
|
||||
get_description,
|
||||
get_type,
|
||||
process_endpoint,
|
||||
join_urls,
|
||||
map_data_to_params
|
||||
} from "../helpers/api_info";
|
||||
import { initialise_server } from "./server";
|
||||
import { transformed_api_info } from "./test_data";
|
||||
|
||||
const server = initialise_server();
|
||||
const IS_NODE = process.env.TEST_MODE === "node";
|
||||
|
||||
beforeAll(() => server.listen());
|
||||
afterEach(() => server.resetHandlers());
|
||||
@ -453,6 +459,67 @@ describe("process_endpoint", () => {
|
||||
const result = await process_endpoint("hmb/hello_world");
|
||||
expect(result).toEqual(expected);
|
||||
});
|
||||
|
||||
it("processes local server URLs correctly", async () => {
|
||||
const local_url = "http://localhost:7860/gradio";
|
||||
const response_local_url = await process_endpoint(local_url);
|
||||
expect(response_local_url.space_id).toBe(false);
|
||||
expect(response_local_url.host).toBe("localhost:7860/gradio");
|
||||
|
||||
const local_url_2 = "http://localhost:7860/gradio/";
|
||||
const response_local_url_2 = await process_endpoint(local_url_2);
|
||||
expect(response_local_url_2.space_id).toBe(false);
|
||||
expect(response_local_url_2.host).toBe("localhost:7860/gradio");
|
||||
});
|
||||
|
||||
it("handles hugging face space references", async () => {
|
||||
const space_id = "hmb/hello_world";
|
||||
|
||||
const response = await process_endpoint(space_id);
|
||||
expect(response.space_id).toBe(space_id);
|
||||
expect(response.host).toContain("hf.space");
|
||||
});
|
||||
|
||||
it("handles hugging face domain URLs", async () => {
|
||||
const app_reference = "https://hmb-hello-world.hf.space/";
|
||||
const response = await process_endpoint(app_reference);
|
||||
expect(response.space_id).toBe("hmb-hello-world");
|
||||
expect(response.host).toBe("hmb-hello-world.hf.space");
|
||||
});
|
||||
});
|
||||
|
||||
describe("join_urls", () => {
|
||||
it("joins URLs correctly", () => {
|
||||
expect(join_urls("http://localhost:7860", "/gradio")).toBe(
|
||||
"http://localhost:7860/gradio"
|
||||
);
|
||||
expect(join_urls("http://localhost:7860/", "/gradio")).toBe(
|
||||
"http://localhost:7860/gradio"
|
||||
);
|
||||
expect(join_urls("http://localhost:7860", "app/", "/gradio")).toBe(
|
||||
"http://localhost:7860/app/gradio"
|
||||
);
|
||||
expect(join_urls("http://localhost:7860/", "/app/", "/gradio/")).toBe(
|
||||
"http://localhost:7860/app/gradio/"
|
||||
);
|
||||
|
||||
expect(join_urls("http://127.0.0.1:8000/app", "/config")).toBe(
|
||||
"http://127.0.0.1:8000/app/config"
|
||||
);
|
||||
|
||||
expect(join_urls("http://127.0.0.1:8000/app/gradio", "/config")).toBe(
|
||||
"http://127.0.0.1:8000/app/gradio/config"
|
||||
);
|
||||
});
|
||||
it("throws an error when the URLs are not valid", () => {
|
||||
expect(() => join_urls("localhost:7860", "/gradio")).toThrowError(
|
||||
INVALID_URL_MSG
|
||||
);
|
||||
|
||||
expect(() => join_urls("localhost:7860", "/gradio", "app")).toThrowError(
|
||||
INVALID_URL_MSG
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("map_data_params", () => {
|
||||
|
@ -3,7 +3,7 @@ import semiver from "semiver";
|
||||
import { API_INFO_URL, BROKEN_CONNECTION_MSG } from "../constants";
|
||||
import { Client } from "../client";
|
||||
import { SPACE_FETCHER_URL } from "../constants";
|
||||
import { transform_api_info } from "../helpers/api_info";
|
||||
import { join_urls, transform_api_info } from "../helpers/api_info";
|
||||
|
||||
export async function view_api(this: Client): Promise<any> {
|
||||
if (this.api_info) return this.api_info;
|
||||
@ -38,7 +38,8 @@ export async function view_api(this: Client): Promise<any> {
|
||||
credentials: "include"
|
||||
});
|
||||
} else {
|
||||
response = await this.fetch(`${config?.root}/${API_INFO_URL}`, {
|
||||
const url = join_urls(config.root, API_INFO_URL);
|
||||
response = await this.fetch(url, {
|
||||
headers,
|
||||
credentials: "include"
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user