Allow passing FastAPI app options (#4282)

* App: don't force docs_url and redoc_url to None

* App.create_app: allow passing in app_kwargs

* start_server + launch: allow passing in app_kwargs

* Changelog

* Apply suggestions from code review

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Use .launch for tests

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Aarni Koskela 2023-05-20 12:10:00 +03:00 committed by GitHub
parent 1d0f0a9db0
commit f0b8862475
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 15 deletions

View File

@ -16,6 +16,7 @@ No changes to highlight.
- Refactor web component `initial_height` attribute by [@whitphx](https://github.com/whitphx) in [PR 4223](https://github.com/gradio-app/gradio/pull/4223)
- Relocate `mount_css` fn to remove circular dependency [@whitphx](https://github.com/whitphx) in [PR 4222](https://github.com/gradio-app/gradio/pull/4222)
- Upgrade Black to 23.3 by [@akx](https://github.com/akx) in [PR 4259](https://github.com/gradio-app/gradio/pull/4259)
- `Interface.launch()` and `Blocks.launch()` now accept an `app_kwargs` argument to allow customizing the configuration of the underlying FastAPI app, by [@akx](https://github.com/akx) in [PR 4282](https://github.com/gradio-app/gradio/pull/4282)
## Breaking Changes:

View File

@ -1615,6 +1615,7 @@ Received outputs:
blocked_paths: list[str] | None = None,
root_path: str = "",
_frontend: bool = True,
app_kwargs: dict[str, Any] | None = None,
) -> tuple[FastAPI, str, str]:
"""
Launches a simple web server that serves the demo. Can also be used to create a
@ -1648,6 +1649,7 @@ Received outputs:
allowed_paths: List of complete filepaths or parent directories that gradio is allowed to serve (in addition to the directory containing the gradio python file). Must be absolute paths. Warning: if you provide directories, any files in these directories or their subdirectories are accessible to all users of your app.
blocked_paths: List of complete filepaths or parent directories that gradio is not allowed to serve (i.e. users of your app are not allowed to access). Must be absolute paths. Warning: takes precedence over `allowed_paths` and all other directories exposed by Gradio by default.
root_path: The root path (or "mount point") of the application, if it's not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application. For example, if the application is served at "https://example.com/myapp", the `root_path` should be set to "/myapp".
app_kwargs: Additional keyword arguments to pass to the underlying FastAPI app as a dictionary of parameter keys and argument values. For example, `{"docs_url": "/docs"}`
Returns:
app: FastAPI app object that is running the demo
local_url: Locally accessible link to the demo
@ -1747,6 +1749,7 @@ Received outputs:
ssl_keyfile,
ssl_certfile,
ssl_keyfile_password,
app_kwargs=app_kwargs,
)
self.server_name = server_name
self.local_url = local_url

View File

@ -89,6 +89,7 @@ def start_server(
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
ssl_keyfile_password: str | None = None,
app_kwargs: dict | None = None,
) -> tuple[str, int, str, App, Server]:
"""Launches a local server running the provided Interface
Parameters:
@ -99,6 +100,8 @@ def start_server(
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor.
Returns:
port: the port number the server is running on
path_to_local_server: the complete address that the local server can be accessed at
@ -143,7 +146,7 @@ def start_server(
else:
host = server_name
app = App.create_app(blocks)
app = App.create_app(blocks, app_kwargs=app_kwargs)
if blocks.save_to is not None: # Used for selenium tests
blocks.save_to["port"] = port

View File

@ -115,7 +115,11 @@ class App(FastAPI):
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)
super().__init__(**kwargs, docs_url=None, redoc_url=None)
# Allow user to manually set `docs_url` and `redoc_url`
# when instantiating an App; when they're not set, disable docs and redoc.
kwargs.setdefault("docs_url", None)
kwargs.setdefault("redoc_url", None)
super().__init__(**kwargs)
def configure_app(self, blocks: gradio.Blocks) -> None:
auth = blocks.auth
@ -141,8 +145,12 @@ class App(FastAPI):
return self.blocks
@staticmethod
def create_app(blocks: gradio.Blocks) -> App:
app = App(default_response_class=ORJSONResponse)
def create_app(
blocks: gradio.Blocks, app_kwargs: Dict[str, Any] | None = None
) -> App:
app_kwargs = app_kwargs or {}
app_kwargs.setdefault("default_response_class", ORJSONResponse)
app = App(**app_kwargs)
app.configure_app(blocks)
app.add_middleware(

View File

@ -81,3 +81,20 @@ class TestURLs:
def test_url_ok(self):
res = networking.url_ok("https://www.gradio.app")
assert res
def test_start_server_app_kwargs():
"""
Test that start_server accepts app_kwargs and they're propagated to FastAPI.
"""
io = Interface(lambda x: x, "number", "number")
app, _, _ = io.launch(
show_error=True,
prevent_thread_lock=True,
app_kwargs={
"docs_url": "/docs",
},
)
client = TestClient(app)
assert client.get("/docs").status_code == 200
io.close()