From d626c21e91df026b04fdb3ee5c7dba74a261cfd3 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 27 Sep 2023 12:58:31 -0700 Subject: [PATCH] Fully resolve generated filepaths when running on Hugging Face Spaces with multiple replicas (#5668) * print * add changeset * url * routes * routes * test * test * add to / route * comment * root_url approach * replica url * print * print * test * revert * fixes * changes * replica url fix * lint * routes * routes * fix * docstring * add changeset * add changeset * add changeset * modify in place * add test * unit tests * fix copy --------- Co-authored-by: gradio-pr-bot --- .changeset/some-forks-raise.md | 5 +++++ gradio/route_utils.py | 24 ++++++++++++++++++++++++ gradio/routes.py | 29 ++++++++++++++++++++++++----- test/test_route_utils.py | 27 +++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 5 deletions(-) create mode 100644 .changeset/some-forks-raise.md create mode 100644 test/test_route_utils.py diff --git a/.changeset/some-forks-raise.md b/.changeset/some-forks-raise.md new file mode 100644 index 0000000000..5d3a64aa6e --- /dev/null +++ b/.changeset/some-forks-raise.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Fully resolve generated filepaths when running on Hugging Face Spaces with multiple replicas diff --git a/gradio/route_utils.py b/gradio/route_utils.py index b79058820d..1b3d846edc 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -1,9 +1,11 @@ from __future__ import annotations +import copy import json from typing import TYPE_CHECKING, Optional, Union import fastapi +import httpx from gradio_client.documentation import document, set_documentation_group from gradio import utils @@ -244,3 +246,25 @@ async def call_process_api( output["data"] = output["data"][0] return output + + +def set_replica_url_in_config(config: dict, replica_url: str) -> dict: + """ + If the Gradio app is running on Hugging Face Spaces and the machine has multiple replicas, + we pass in the direct URL to the replica so that we have the fully resolved path to any files + on that machine. This direct URL can be shared with other users and the path will still work. + """ + parsed_url = httpx.URL(replica_url) + stripped_url = parsed_url.copy_with(query=None) + stripped_url = str(stripped_url) + if not stripped_url.endswith("/"): + stripped_url += "/" + + config_ = copy.deepcopy(config) + for component in config_["components"]: + if ( + component.get("props") is not None + and component["props"].get("root_url") is None + ): + component["props"]["root_url"] = stripped_url + return config_ diff --git a/gradio/routes.py b/gradio/routes.py index ff6886a4e8..1035c3fe4a 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -53,7 +53,7 @@ from gradio.deprecation import warn_deprecation from gradio.exceptions import Error from gradio.oauth import attach_oauth from gradio.queueing import Estimation, Event -from gradio.route_utils import Request # noqa: F401 +from gradio.route_utils import Request, set_replica_url_in_config # noqa: F401 from gradio.state_holder import StateHolder from gradio.utils import ( cancel_tasks, @@ -124,6 +124,9 @@ class App(FastAPI): self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str( Path(tempfile.gettempdir()) / "gradio" ) + self.replica_urls = ( + set() + ) # these are the full paths to the replicas if running on a Hugging Face Space with multiple replicas self.change_event: None | threading.Event = None # Allow user to manually set `docs_url` and `redoc_url` # when instantiating an App; when they're not set, disable docs and redoc. @@ -158,9 +161,10 @@ class App(FastAPI): assert self.blocks # Don't proxy a URL unless it's a URL specifically loaded by the user using # gr.load() to prevent SSRF or harvesting of HF tokens by malicious Spaces. - is_safe_url = any( - url.host == httpx.URL(root).host for root in self.blocks.root_urls - ) + safe_urls = {httpx.URL(root).host for root in self.blocks.root_urls} | { + httpx.URL(root).host for root in self.replica_urls + } + is_safe_url = url.host in safe_urls if not is_safe_url: raise PermissionError("This URL cannot be proxied.") is_hf_url = url.host.endswith(".hf.space") @@ -307,6 +311,13 @@ class App(FastAPI): if app.auth is None or user is not None: config = app.get_blocks().config config["root"] = root_path + + # Handles the case where the app is running on Hugging Face Spaces with + # multiple replicas. See `set_replica_url_in_config` for more details. + replica_url = request.headers.get("X-Direct-Url") + if utils.get_space() and replica_url: + app.replica_urls.add(replica_url) + config = set_replica_url_in_config(config, replica_url) else: config = { "auth_required": True, @@ -344,8 +355,16 @@ class App(FastAPI): @app.get("/config/", dependencies=[Depends(login_check)]) @app.get("/config", dependencies=[Depends(login_check)]) def get_config(request: fastapi.Request): - root_path = request.scope.get("root_path", "") config = app.get_blocks().config + + # Handles the case where the app is running on Hugging Face Spaces with + # multiple replicas. See `set_replica_url_in_config` for more details. + replica_url = request.headers.get("X-Direct-Url") + if utils.get_space() and replica_url: + app.replica_urls.add(replica_url) + config = set_replica_url_in_config(config, replica_url) + + root_path = request.scope.get("root_path", "") config["root"] = root_path return config diff --git a/test/test_route_utils.py b/test/test_route_utils.py new file mode 100644 index 0000000000..66ebd751ba --- /dev/null +++ b/test/test_route_utils.py @@ -0,0 +1,27 @@ +from gradio.route_utils import set_replica_url_in_config + + +def test_set_replica_url(): + config = { + "components": [{"props": {}}, {"props": {"root_url": "existing_url/"}}, {}] + } + replica_url = "https://abidlabs-test-client-replica--fttzk.hf.space?__theme=light" + + config = set_replica_url_in_config(config, replica_url) + assert ( + config["components"][0]["props"]["root_url"] + == "https://abidlabs-test-client-replica--fttzk.hf.space/" + ) + assert config["components"][1]["props"]["root_url"] == "existing_url/" + assert "props" not in config["components"][2] + + +def test_url_without_trailing_slash(): + config = {"components": [{"props": {}}]} + replica_url = "https://abidlabs-test-client-replica--fttzk.hf.space" + + config = set_replica_url_in_config(config, replica_url) + assert ( + config["components"][0]["props"]["root_url"] + == "https://abidlabs-test-client-replica--fttzk.hf.space/" + )