From 7f19ba272c5a46d7f2c11ad5c9f0ab15d4e27083 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 9 Feb 2024 15:26:27 -0800 Subject: [PATCH] Stop caching root url (#7374) * stop caching root * add changeset * add changeset * add changeset * cleanup * config * changes * routes * routes --------- Co-authored-by: gradio-pr-bot --- .changeset/soft-lies-carry.md | 5 +++++ gradio/route_utils.py | 16 ++++++++++------ gradio/routes.py | 29 ++++++++++++++--------------- 3 files changed, 29 insertions(+), 21 deletions(-) create mode 100644 .changeset/soft-lies-carry.md diff --git a/.changeset/soft-lies-carry.md b/.changeset/soft-lies-carry.md new file mode 100644 index 0000000000..6620f6eb62 --- /dev/null +++ b/.changeset/soft-lies-carry.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Stop caching root url diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 56b1e5d3ce..6e663e5ff4 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -261,14 +261,18 @@ async def call_process_api( return output -def strip_url(orig_url: str) -> str: +def get_root_url(request: fastapi.Request) -> str: """ - Strips the query parameters and trailing slash from a URL. + Gets the root url of the request, stripping off any query parameters and trailing slashes. + Also ensures that the root url is https if the request is https. """ - parsed_url = httpx.URL(orig_url) - stripped_url = parsed_url.copy_with(query=None) - stripped_url = str(stripped_url) - return stripped_url.rstrip("/") + root_url = str(request.url) + root_url = httpx.URL(root_url) + root_url = root_url.copy_with(query=None) + root_url = str(root_url) + if request.headers.get("x-forwarded-proto") == "https": + root_url = root_url.replace("http://", "https://") + return root_url.rstrip("/") def _user_safe_decode(src: bytes, codec: str) -> str: diff --git a/gradio/routes.py b/gradio/routes.py index 970fc3c0ca..45af1a2bb6 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import contextlib +import copy import sys if sys.version_info >= (3, 9): @@ -310,18 +311,17 @@ class App(FastAPI): def main(request: fastapi.Request, user: str = Depends(get_current_user)): mimetypes.add_type("application/javascript", ".js") blocks = app.get_blocks() - root_path = route_utils.strip_url(str(request.url)) + root_path = route_utils.get_root_url(request) if app.auth is None or user is not None: - config = app.get_blocks().config - if "root" not in config: - config["root"] = root_path - config = add_root_url(config, root_path) + config = copy.deepcopy(app.get_blocks().config) + config["root"] = root_path + config = add_root_url(config, root_path) else: config = { "auth_required": True, "auth_message": blocks.auth_message, "space_id": app.get_blocks().space_id, - "root": route_utils.strip_url(root_path), + "root": root_path, } try: @@ -352,11 +352,10 @@ class App(FastAPI): @app.get("/config/", dependencies=[Depends(login_check)]) @app.get("/config", dependencies=[Depends(login_check)]) def get_config(request: fastapi.Request): - root_path = route_utils.strip_url(str(request.url))[:-7] - config = app.get_blocks().config - if "root" not in config: - config["root"] = route_utils.strip_url(root_path) - config = add_root_url(config, root_path) + config = copy.deepcopy(app.get_blocks().config) + root_path = route_utils.get_root_url(request)[: -len("/config")] + config["root"] = root_path + config = add_root_url(config, root_path) return config @app.get("/static/{path:path}") @@ -571,8 +570,8 @@ class App(FastAPI): content={"error": str(error) if show_error else None}, status_code=500, ) - root_path = app.get_blocks().config.get("root", "") - output = add_root_url(output, route_utils.strip_url(root_path)) + root_path = route_utils.get_root_url(request)[: -len(f"/api/{api_name}")] + output = add_root_url(output, root_path) return output @app.get("/queue/data", dependencies=[Depends(login_check)]) @@ -581,7 +580,7 @@ class App(FastAPI): session_hash: str, ): blocks = app.get_blocks() - root_path = app.get_blocks().config.get("root", "") + root_path = route_utils.get_root_url(request)[: -len("/queue/data")] async def sse_stream(request: fastapi.Request): try: @@ -627,7 +626,7 @@ class App(FastAPI): "success": False, } if message: - add_root_url(message, route_utils.strip_url(root_path)) + add_root_url(message, root_path) yield f"data: {json.dumps(message)}\n\n" if message["msg"] == ServerMessage.process_completed: blocks._queue.pending_event_ids_session[