diff --git a/.changeset/wild-planets-sin.md b/.changeset/wild-planets-sin.md new file mode 100644 index 0000000000..b501860f98 --- /dev/null +++ b/.changeset/wild-planets-sin.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Mount on a FastAPI app with lifespan manager diff --git a/gradio/routes.py b/gradio/routes.py index 825c1f5def..f8f24eb0d4 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -4,6 +4,7 @@ module use the Optional/Union notation so that they work correctly with pydantic from __future__ import annotations import asyncio +import contextlib import sys if sys.version_info >= (3, 9): @@ -892,9 +893,17 @@ def mount_gradio_app( blocks.validate_queue_settings() gradio_app = App.create_app(blocks, app_kwargs=app_kwargs) - @app.on_event("startup") - async def start_queue(): - gradio_app.get_blocks().startup_events() + old_lifespan = app.router.lifespan_context + + @contextlib.asynccontextmanager + async def new_lifespan(app: FastAPI): + async with old_lifespan( + app + ): # Instert the startup events inside the FastAPI context manager + gradio_app.get_blocks().startup_events() + yield + + app.router.lifespan_context = new_lifespan app.mount(path, gradio_app) return app diff --git a/test/test_routes.py b/test/test_routes.py index 564767cbaa..08bd596d5e 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -2,7 +2,7 @@ import functools import os import tempfile -from contextlib import closing +from contextlib import asynccontextmanager, closing from unittest import mock as mock import gradio_client as grc @@ -339,6 +339,50 @@ class TestRoutes: with TestClient(app) as client: assert client.get("/echo/docs-custom").is_success + def test_mount_gradio_app_with_lifespan(self): + @asynccontextmanager + async def empty_lifespan(app: FastAPI): + yield + + app = FastAPI(lifespan=empty_lifespan) + + demo = gr.Interface( + lambda s: f"Hello from ps, {s}!", "textbox", "textbox" + ).queue() + demo1 = gr.Interface( + lambda s: f"Hello from py, {s}!", "textbox", "textbox" + ).queue() + + app = gr.mount_gradio_app(app, demo, path="/ps") + app = gr.mount_gradio_app(app, demo1, path="/py") + + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get("/ps").is_success + assert client.get("/py").is_success + + def test_mount_gradio_app_with_startup(self): + app = FastAPI() + + @app.on_event("startup") + async def empty_startup(): + return + + demo = gr.Interface( + lambda s: f"Hello from ps, {s}!", "textbox", "textbox" + ).queue() + demo1 = gr.Interface( + lambda s: f"Hello from py, {s}!", "textbox", "textbox" + ).queue() + + app = gr.mount_gradio_app(app, demo, path="/ps") + app = gr.mount_gradio_app(app, demo1, path="/py") + + # Use context manager to trigger start up events + with TestClient(app) as client: + assert client.get("/ps").is_success + assert client.get("/py").is_success + def test_static_file_missing(self, test_client): response = test_client.get(r"/static/not-here.js") assert response.status_code == 404