Fully resolve generated filepaths when running on Hugging Face Spaces with multiple replicas (#5668)

* print

* add changeset

* url

* routes

* routes

* test

* test

* add to / route

* comment

* root_url approach

* replica url

* print

* print

* test

* revert

* fixes

* changes

* replica url fix

* lint

* routes

* routes

* fix

* docstring

* add changeset

* add changeset

* add changeset

* modify in place

* add test

* unit tests

* fix copy

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-09-27 12:58:31 -07:00 committed by GitHub
parent c2b31c396f
commit d626c21e91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 5 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Fully resolve generated filepaths when running on Hugging Face Spaces with multiple replicas

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import copy
import json
from typing import TYPE_CHECKING, Optional, Union
import fastapi
import httpx
from gradio_client.documentation import document, set_documentation_group
from gradio import utils
@ -244,3 +246,25 @@ async def call_process_api(
output["data"] = output["data"][0]
return output
def set_replica_url_in_config(config: dict, replica_url: str) -> dict:
"""
If the Gradio app is running on Hugging Face Spaces and the machine has multiple replicas,
we pass in the direct URL to the replica so that we have the fully resolved path to any files
on that machine. This direct URL can be shared with other users and the path will still work.
"""
parsed_url = httpx.URL(replica_url)
stripped_url = parsed_url.copy_with(query=None)
stripped_url = str(stripped_url)
if not stripped_url.endswith("/"):
stripped_url += "/"
config_ = copy.deepcopy(config)
for component in config_["components"]:
if (
component.get("props") is not None
and component["props"].get("root_url") is None
):
component["props"]["root_url"] = stripped_url
return config_

View File

@ -53,7 +53,7 @@ from gradio.deprecation import warn_deprecation
from gradio.exceptions import Error
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation, Event
from gradio.route_utils import Request # noqa: F401
from gradio.route_utils import Request, set_replica_url_in_config # noqa: F401
from gradio.state_holder import StateHolder
from gradio.utils import (
cancel_tasks,
@ -124,6 +124,9 @@ class App(FastAPI):
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)
self.replica_urls = (
set()
) # these are the full paths to the replicas if running on a Hugging Face Space with multiple replicas
self.change_event: None | threading.Event = None
# Allow user to manually set `docs_url` and `redoc_url`
# when instantiating an App; when they're not set, disable docs and redoc.
@ -158,9 +161,10 @@ class App(FastAPI):
assert self.blocks
# Don't proxy a URL unless it's a URL specifically loaded by the user using
# gr.load() to prevent SSRF or harvesting of HF tokens by malicious Spaces.
is_safe_url = any(
url.host == httpx.URL(root).host for root in self.blocks.root_urls
)
safe_urls = {httpx.URL(root).host for root in self.blocks.root_urls} | {
httpx.URL(root).host for root in self.replica_urls
}
is_safe_url = url.host in safe_urls
if not is_safe_url:
raise PermissionError("This URL cannot be proxied.")
is_hf_url = url.host.endswith(".hf.space")
@ -307,6 +311,13 @@ class App(FastAPI):
if app.auth is None or user is not None:
config = app.get_blocks().config
config["root"] = root_path
# Handles the case where the app is running on Hugging Face Spaces with
# multiple replicas. See `set_replica_url_in_config` for more details.
replica_url = request.headers.get("X-Direct-Url")
if utils.get_space() and replica_url:
app.replica_urls.add(replica_url)
config = set_replica_url_in_config(config, replica_url)
else:
config = {
"auth_required": True,
@ -344,8 +355,16 @@ class App(FastAPI):
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config(request: fastapi.Request):
root_path = request.scope.get("root_path", "")
config = app.get_blocks().config
# Handles the case where the app is running on Hugging Face Spaces with
# multiple replicas. See `set_replica_url_in_config` for more details.
replica_url = request.headers.get("X-Direct-Url")
if utils.get_space() and replica_url:
app.replica_urls.add(replica_url)
config = set_replica_url_in_config(config, replica_url)
root_path = request.scope.get("root_path", "")
config["root"] = root_path
return config

27
test/test_route_utils.py Normal file
View File

@ -0,0 +1,27 @@
from gradio.route_utils import set_replica_url_in_config
def test_set_replica_url():
config = {
"components": [{"props": {}}, {"props": {"root_url": "existing_url/"}}, {}]
}
replica_url = "https://abidlabs-test-client-replica--fttzk.hf.space?__theme=light"
config = set_replica_url_in_config(config, replica_url)
assert (
config["components"][0]["props"]["root_url"]
== "https://abidlabs-test-client-replica--fttzk.hf.space/"
)
assert config["components"][1]["props"]["root_url"] == "existing_url/"
assert "props" not in config["components"][2]
def test_url_without_trailing_slash():
config = {"components": [{"props": {}}]}
replica_url = "https://abidlabs-test-client-replica--fttzk.hf.space"
config = set_replica_url_in_config(config, replica_url)
assert (
config["components"][0]["props"]["root_url"]
== "https://abidlabs-test-client-replica--fttzk.hf.space/"
)