mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Allow passing FastAPI app options (#4282)
* App: don't force docs_url and redoc_url to None * App.create_app: allow passing in app_kwargs * start_server + launch: allow passing in app_kwargs * Changelog * Apply suggestions from code review Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Use .launch for tests --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
1d0f0a9db0
commit
f0b8862475
@ -16,6 +16,7 @@ No changes to highlight.
|
||||
- Refactor web component `initial_height` attribute by [@whitphx](https://github.com/whitphx) in [PR 4223](https://github.com/gradio-app/gradio/pull/4223)
|
||||
- Relocate `mount_css` fn to remove circular dependency [@whitphx](https://github.com/whitphx) in [PR 4222](https://github.com/gradio-app/gradio/pull/4222)
|
||||
- Upgrade Black to 23.3 by [@akx](https://github.com/akx) in [PR 4259](https://github.com/gradio-app/gradio/pull/4259)
|
||||
- `Interface.launch()` and `Blocks.launch()` now accept an `app_kwargs` argument to allow customizing the configuration of the underlying FastAPI app, by [@akx](https://github.com/akx) in [PR 4282](https://github.com/gradio-app/gradio/pull/4282)
|
||||
|
||||
## Breaking Changes:
|
||||
|
||||
|
@ -1615,6 +1615,7 @@ Received outputs:
|
||||
blocked_paths: list[str] | None = None,
|
||||
root_path: str = "",
|
||||
_frontend: bool = True,
|
||||
app_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[FastAPI, str, str]:
|
||||
"""
|
||||
Launches a simple web server that serves the demo. Can also be used to create a
|
||||
@ -1648,6 +1649,7 @@ Received outputs:
|
||||
allowed_paths: List of complete filepaths or parent directories that gradio is allowed to serve (in addition to the directory containing the gradio python file). Must be absolute paths. Warning: if you provide directories, any files in these directories or their subdirectories are accessible to all users of your app.
|
||||
blocked_paths: List of complete filepaths or parent directories that gradio is not allowed to serve (i.e. users of your app are not allowed to access). Must be absolute paths. Warning: takes precedence over `allowed_paths` and all other directories exposed by Gradio by default.
|
||||
root_path: The root path (or "mount point") of the application, if it's not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application. For example, if the application is served at "https://example.com/myapp", the `root_path` should be set to "/myapp".
|
||||
app_kwargs: Additional keyword arguments to pass to the underlying FastAPI app as a dictionary of parameter keys and argument values. For example, `{"docs_url": "/docs"}`
|
||||
Returns:
|
||||
app: FastAPI app object that is running the demo
|
||||
local_url: Locally accessible link to the demo
|
||||
@ -1747,6 +1749,7 @@ Received outputs:
|
||||
ssl_keyfile,
|
||||
ssl_certfile,
|
||||
ssl_keyfile_password,
|
||||
app_kwargs=app_kwargs,
|
||||
)
|
||||
self.server_name = server_name
|
||||
self.local_url = local_url
|
||||
|
@ -89,6 +89,7 @@ def start_server(
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_certfile: str | None = None,
|
||||
ssl_keyfile_password: str | None = None,
|
||||
app_kwargs: dict | None = None,
|
||||
) -> tuple[str, int, str, App, Server]:
|
||||
"""Launches a local server running the provided Interface
|
||||
Parameters:
|
||||
@ -99,6 +100,8 @@ def start_server(
|
||||
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
|
||||
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
|
||||
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
|
||||
app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor.
|
||||
|
||||
Returns:
|
||||
port: the port number the server is running on
|
||||
path_to_local_server: the complete address that the local server can be accessed at
|
||||
@ -143,7 +146,7 @@ def start_server(
|
||||
else:
|
||||
host = server_name
|
||||
|
||||
app = App.create_app(blocks)
|
||||
app = App.create_app(blocks, app_kwargs=app_kwargs)
|
||||
|
||||
if blocks.save_to is not None: # Used for selenium tests
|
||||
blocks.save_to["port"] = port
|
||||
|
@ -115,7 +115,11 @@ class App(FastAPI):
|
||||
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
|
||||
Path(tempfile.gettempdir()) / "gradio"
|
||||
)
|
||||
super().__init__(**kwargs, docs_url=None, redoc_url=None)
|
||||
# Allow user to manually set `docs_url` and `redoc_url`
|
||||
# when instantiating an App; when they're not set, disable docs and redoc.
|
||||
kwargs.setdefault("docs_url", None)
|
||||
kwargs.setdefault("redoc_url", None)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def configure_app(self, blocks: gradio.Blocks) -> None:
|
||||
auth = blocks.auth
|
||||
@ -141,8 +145,12 @@ class App(FastAPI):
|
||||
return self.blocks
|
||||
|
||||
@staticmethod
|
||||
def create_app(blocks: gradio.Blocks) -> App:
|
||||
app = App(default_response_class=ORJSONResponse)
|
||||
def create_app(
|
||||
blocks: gradio.Blocks, app_kwargs: Dict[str, Any] | None = None
|
||||
) -> App:
|
||||
app_kwargs = app_kwargs or {}
|
||||
app_kwargs.setdefault("default_response_class", ORJSONResponse)
|
||||
app = App(**app_kwargs)
|
||||
app.configure_app(blocks)
|
||||
|
||||
app.add_middleware(
|
||||
|
@ -81,3 +81,20 @@ class TestURLs:
|
||||
def test_url_ok(self):
|
||||
res = networking.url_ok("https://www.gradio.app")
|
||||
assert res
|
||||
|
||||
|
||||
def test_start_server_app_kwargs():
|
||||
"""
|
||||
Test that start_server accepts app_kwargs and they're propagated to FastAPI.
|
||||
"""
|
||||
io = Interface(lambda x: x, "number", "number")
|
||||
app, _, _ = io.launch(
|
||||
show_error=True,
|
||||
prevent_thread_lock=True,
|
||||
app_kwargs={
|
||||
"docs_url": "/docs",
|
||||
},
|
||||
)
|
||||
client = TestClient(app)
|
||||
assert client.get("/docs").status_code == 200
|
||||
io.close()
|
||||
|
Loading…
Reference in New Issue
Block a user