Allow mounted Gradio apps to work with external / arbitrary authentication providers (#7557)

* add parameter

* format

* add changeset

* docstrings

* changes

* changes

* docs

* mark flaky

* test

* docs

* docs

* push

* docs

* Update guides/01_getting-started/03_sharing-your-app.md

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* fix typecheck

* error

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>
This commit is contained in:
Abubakar Abid 2024-03-05 12:41:39 -08:00 committed by GitHub
parent f26aba00a7
commit 4d5789e905
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 211 additions and 54 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Allow mounted Gradio apps to work with external / arbitrary authentication providers

View File

@ -53,6 +53,7 @@ def connect(
class TestClientInitialization:
@pytest.mark.flaky
def test_headers_constructed_correctly(self):
client = Client("gradio-tests/titanic-survival", hf_token=HF_TOKEN)
assert {"authorization": f"Bearer {HF_TOKEN}"}.items() <= client.headers.items()
@ -527,6 +528,7 @@ class TestClientPredictions:
finally:
server.thread.join(timeout=1)
@pytest.mark.flaky
def test_predict_with_space_with_api_name_false(self):
client = Client("gradio-tests/client-bool-api-name-error")
assert client.predict("Hello!", api_name="/run") == "Hello!"

View File

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, Sequenc
from urllib.parse import urlparse, urlunparse
import anyio
import fastapi
import httpx
from anyio import CapacityLimiter
from gradio_client import utils as client_utils
@ -1901,6 +1902,7 @@ Received outputs:
state_session_capacity: int = 10000,
share_server_address: str | None = None,
share_server_protocol: Literal["http", "https"] | None = None,
auth_dependency: Callable[[fastapi.Request], str | None] | None = None,
_frontend: bool = True,
) -> tuple[FastAPI, str, str]:
"""
@ -1935,6 +1937,7 @@ Received outputs:
state_session_capacity: The maximum number of sessions whose information to store in memory. If the number of sessions exceeds this number, the oldest sessions will be removed. Reduce capacity to reduce memory usage when using gradio.State or returning updated components from functions. Defaults to 10000.
share_server_address: Use this to specify a custom FRP server and port for sharing Gradio apps (only applies if share=True). If not provided, will use the default FRP server at https://gradio.live. See https://github.com/huggingface/frp for more information.
share_server_protocol: Use this to specify the protocol to use for the share links. Defaults to "https", unless a custom share_server_address is provided, in which case it defaults to "http". If you are using a custom share_server_address and want to use https, you must set this to "https".
auth_dependency: A function that takes a FastAPI request and returns a string user ID or None. If the function returns None for a specific request, that user is not authorized to access the app (they will see a 401 Unauthorized response). To be used with external authentication systems like OAuth.
Returns:
app: FastAPI app object that is running the demo
local_url: Locally accessible link to the demo
@ -1961,6 +1964,10 @@ Received outputs:
if not self.exited:
self.__exit__()
if auth is not None and auth_dependency is not None:
raise ValueError(
"You cannot provide both `auth` and `auth_dependency` in launch(). Please choose one."
)
if (
auth
and not callable(auth)
@ -2019,7 +2026,9 @@ Received outputs:
# and avoid using `networking.start_server` that would start a server that don't work in the Wasm env.
from gradio.routes import App
app = App.create_app(self, app_kwargs=app_kwargs)
app = App.create_app(
self, auth_dependency=auth_dependency, app_kwargs=app_kwargs
)
wasm_utils.register_app(app)
else:
(

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import tempfile
from pathlib import Path
from typing import Callable, Literal
from gradio_client.documentation import document
@ -25,7 +26,7 @@ class DownloadButton(Component):
def __init__(
self,
label: str = "Download",
value: str | list[str] | Callable | None = None,
value: str | Path | Callable | None = None,
*,
every: float | None = None,
variant: Literal["primary", "secondary", "stop"] = "secondary",
@ -41,8 +42,8 @@ class DownloadButton(Component):
):
"""
Parameters:
label: Text to display on the button. Defaults to "Upload a File".
value: File or list of files to upload by default.
label: Text to display on the button. Defaults to "Download".
value: A str or pathlib.Path filepath or URL to download, or a Callable that returns a str or pathlib.Path filepath or URL to download.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
variant: 'primary' for main call-to-action, 'secondary' for a more subdued style, 'stop' for a stop button.
visible: If False, component will be hidden.
@ -87,16 +88,16 @@ class DownloadButton(Component):
file.name = file_name
return file_name
def postprocess(self, value: str | None) -> FileData | None:
def postprocess(self, value: str | Path | None) -> FileData | None:
"""
Parameters:
value: Expects a `str` filepath
value: Expects a `str` or `pathlib.Path` filepath
Returns:
File information as a FileData object
"""
if value is None:
return None
return FileData(path=value)
return FileData(path=str(value))
def example_inputs(self) -> str:
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"

View File

@ -144,7 +144,11 @@ class App(FastAPI):
FastAPI App Wrapper
"""
def __init__(self, **kwargs):
def __init__(
self,
auth_dependency: Callable[[fastapi.Request], str | None] | None = None,
**kwargs,
):
self.tokens = {}
self.auth = None
self.blocks: gradio.Blocks | None = None
@ -158,6 +162,7 @@ class App(FastAPI):
self.uploaded_file_dir = get_upload_folder()
self.change_event: None | threading.Event = None
self._asyncio_tasks: list[asyncio.Task] = []
self.auth_dependency = auth_dependency
# Allow user to manually set `docs_url` and `redoc_url`
# when instantiating an App; when they're not set, disable docs and redoc.
kwargs.setdefault("docs_url", None)
@ -210,7 +215,9 @@ class App(FastAPI):
@staticmethod
def create_app(
blocks: gradio.Blocks, app_kwargs: Dict[str, Any] | None = None
blocks: gradio.Blocks,
app_kwargs: Dict[str, Any] | None = None,
auth_dependency: Callable[[fastapi.Request], str | None] | None = None,
) -> App:
app_kwargs = app_kwargs or {}
app_kwargs.setdefault("default_response_class", ORJSONResponse)
@ -218,7 +225,7 @@ class App(FastAPI):
app_kwargs["lifespan"] = create_lifespan_handler(
app_kwargs.get("lifespan", None), *blocks.delete_cache
)
app = App(**app_kwargs)
app = App(auth_dependency=auth_dependency, **app_kwargs)
app.configure_app(blocks)
if not wasm_utils.IS_WASM:
@ -227,6 +234,8 @@ class App(FastAPI):
@app.get("/user")
@app.get("/user/")
def get_current_user(request: fastapi.Request) -> Optional[str]:
if app.auth_dependency is not None:
return app.auth_dependency(request)
token = request.cookies.get(
f"access-token-{app.cookie_id}"
) or request.cookies.get(f"access-token-unsecure-{app.cookie_id}")
@ -235,7 +244,7 @@ class App(FastAPI):
@app.get("/login_check")
@app.get("/login_check/")
def login_check(user: str = Depends(get_current_user)):
if app.auth is None or user is not None:
if (app.auth is None and app.auth_dependency is None) or user is not None:
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
@ -344,9 +353,13 @@ class App(FastAPI):
root = route_utils.get_root_url(
request=request, route_path="/", root_path=app.root_path
)
if app.auth is None or user is not None:
if (app.auth is None and app.auth_dependency is None) or user is not None:
config = app.get_blocks().config
config = route_utils.update_root_in_config(config, root)
elif app.auth_dependency:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)
else:
config = {
"auth_required": True,
@ -1002,6 +1015,8 @@ def mount_gradio_app(
blocks: gradio.Blocks,
path: str,
app_kwargs: dict[str, Any] | None = None,
*,
auth_dependency: Callable[[fastapi.Request], str | None] | None = None,
) -> fastapi.FastAPI:
"""Mount a gradio.Blocks to an existing FastAPI application.
@ -1010,6 +1025,7 @@ def mount_gradio_app(
blocks: The blocks object we want to mount to the parent app.
path: The path at which the gradio application will be mounted.
app_kwargs: Additional keyword arguments to pass to the underlying FastAPI app as a dictionary of parameter keys and argument values. For example, `{"docs_url": "/docs"}`
auth_dependency: A function that takes a FastAPI request and returns a string user ID or None. If the function returns None for a specific request, that user is not authorized to access the app (they will see a 401 Unauthorized response). To be used with external authentication systems like OAuth.
Example:
from fastapi import FastAPI
import gradio as gr
@ -1024,7 +1040,9 @@ def mount_gradio_app(
blocks.dev_mode = False
blocks.config = blocks.get_config_file()
blocks.validate_queue_settings()
gradio_app = App.create_app(blocks, app_kwargs=app_kwargs)
gradio_app = App.create_app(
blocks, app_kwargs=app_kwargs, auth_dependency=auth_dependency
)
old_lifespan = app.router.lifespan_context

View File

@ -6,9 +6,9 @@ How to share your Gradio app:
2. [Hosting on HF Spaces](#hosting-on-hf-spaces)
3. [Embedding hosted spaces](#embedding-hosted-spaces)
4. [Using the API page](#api-page)
5. [Authentication](#authentication)
6. [Accessing network requests](#accessing-the-network-request-directly)
7. [Mounting within FastAPI](#mounting-within-another-fast-api-app)
5. [Accessing network requests](#accessing-the-network-request-directly)
6. [Mounting within FastAPI](#mounting-within-another-fast-api-app)
7. [Authentication](#authentication)
8. [Security and file access](#security-and-file-access)
## Sharing Demos
@ -168,6 +168,39 @@ btn.click(add, [num1, num2], output, api_name="addition")
This will add and document the endpoint `/api/addition/` to the automatically generated API page. Otherwise, your API endpoints will appear as "unnamed" endpoints.
## Accessing the Network Request Directly
When a user makes a prediction to your app, you may need the underlying network request, in order to get the request headers (e.g. for advanced authentication), log the client's IP address, getting the query parameters, or for other reasons. Gradio supports this in a similar manner to FastAPI: simply add a function parameter whose type hint is `gr.Request` and Gradio will pass in the network request as that parameter. Here is an example:
```python
import gradio as gr
def echo(text, request: gr.Request):
if request:
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
print("Query parameters:", dict(request.query_params))
return text
io = gr.Interface(echo, "textbox", "textbox").launch()
```
Note: if your function is called directly instead of through the UI (this happens, for
example, when examples are cached, or when the Gradio app is called via API), then `request` will be `None`.
You should handle this case explicitly to ensure that your app does not throw any errors. That is why
we have the explicit check `if request`.
## Mounting Within Another FastAPI App
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo.
You can easily do this with `gradio.mount_gradio_app()`.
Here's a complete example:
$code_custom_path
Note that this approach also allows you run your Gradio apps on custom paths (`http://localhost:8000/gradio` in the example above).
## Authentication
@ -179,7 +212,7 @@ You may wish to put an authentication page in front of your app to limit who can
demo.launch(auth=("admin", "pass1234"))
```
For more complex authentication handling, you can even pass a function that takes a username and password as arguments, and returns True to allow authentication, False otherwise. This can be used for, among other things, making requests to 3rd-party authentication services.
For more complex authentication handling, you can even pass a function that takes a username and password as arguments, and returns `True` to allow access, `False` otherwise.
Here's an example of a function that accepts any login where the username and password are the same:
@ -189,7 +222,7 @@ def same_auth(username, password):
demo.launch(auth=same_auth)
```
If you have multiple users, you may wish to customize the content that is shown depending on the user that is logged in. You can retrieve the logged in user by [accessing the network request directly](#accessing-the-network-request-directly) and then reading the `.username` attribute of the request. Here's an example:
If you have multiple users, you may wish to customize the content that is shown depending on the user that is logged in. You can retrieve the logged in user by [accessing the network request directly](#accessing-the-network-request-directly) as discussed above, and then reading the `.username` attribute of the request. Here's an example:
```python
@ -227,8 +260,7 @@ Note: Gradio's built-in authentication provides a straightforward and basic laye
### OAuth (Login via Hugging Face)
Gradio supports OAuth login via Hugging Face. This feature is currently **experimental** and only available on Spaces.
It allows you to add a _"Sign in with Hugging Face"_ button to your demo. Check out [this Space](https://huggingface.co/spaces/Wauplin/gradio-oauth-demo) for a live demo.
Gradio natively supports OAuth login via Hugging Face. In other words, you can easily add a _"Sign in with Hugging Face"_ button to your demo, which allows you to get a user's HF username as well as other information from their HF profile. Check out [this Space](https://huggingface.co/spaces/Wauplin/gradio-oauth-demo) for a live demo.
To enable OAuth, you must set `hf_oauth: true` as a Space metadata in your README.md file. This will register your Space
as an OAuth application on Hugging Face. Next, you can use `gr.LoginButton` and `gr.LogoutButton` to add login and logout buttons to
@ -275,38 +307,119 @@ Users can revoke access to their profile at any time in their [settings](https:/
As seen above, OAuth features are available only when your app runs in a Space. However, you often need to test your app
locally before deploying it. To test OAuth features locally, your machine must be logged in to Hugging Face. Please run `huggingface-cli login` or set `HF_TOKEN` as environment variable with one of your access token. You can generate a new token in your settings page (https://huggingface.co/settings/tokens). Then, clicking on the `gr.LoginButton` will login your local Hugging Face profile, allowing you to debug your app with your Hugging Face account before deploying it to a Space.
## Accessing the Network Request Directly
When a user makes a prediction to your app, you may need the underlying network request, in order to get the request headers (e.g. for advanced authentication), log the client's IP address, getting the query parameters, or for other reasons. Gradio supports this in a similar manner to FastAPI: simply add a function parameter whose type hint is `gr.Request` and Gradio will pass in the network request as that parameter. Here is an example:
### OAuth (with external providers)
It is also possible to authenticate with external OAuth providers (e.g. Google OAuth) in your Gradio apps. To do this, first mount your Gradio app within a FastAPI app ([as discussed above](#mounting-within-another-fast-api-app)). Then, you must write an *authentication function*, which gets the user's username from the OAuth provider and returns it. This function should be passed to the `auth_dependency` parameter in `gr.mount_gradio_app`.
Similar to [FastAPI dependency functions](https://fastapi.tiangolo.com/tutorial/dependencies/), the function specified by `auth_dependency` will run before any Gradio-related route in your FastAPI app. The function should accept a single parameter: the FastAPI `Request` and return either a string (representing a user's username) or `None`. If a string is returned, the user will be able to access the Gradio-related routes in your FastAPI app.
First, let's show a simplistic example to illustrate the `auth_dependency` parameter:
```python
from fastapi import FastAPI, Request
import gradio as gr
def echo(text, request: gr.Request):
if request:
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
print("Query parameters:", dict(request.query_params))
return text
app = FastAPI()
io = gr.Interface(echo, "textbox", "textbox").launch()
def get_user(request: Request):
return request.headers.get("user")
demo = gr.Interface(lambda s: f"Hello {s}!", "textbox", "textbox")
app = gr.mount_gradio_app(app, demo, path="/demo", auth_dependency=get_user)
if __name__ == '__main__':
uvicorn.run(app)
```
Note: if your function is called directly instead of through the UI (this happens, for
example, when examples are cached, or when the Gradio app is called via API), then `request` will be `None`.
You should handle this case explicitly to ensure that your app does not throw any errors. That is why
we have the explicit check `if request`.
In this example, only requests that include a "user" header will be allowed to access the Gradio app. Of course, this does not add much security, since any user can add this header in their request.
## Mounting Within Another FastAPI App
Here's a more complete example showing how to add Google OAuth to a Gradio app (assuming you've already created OAuth Credentials on the [Google Developer Console](https://console.cloud.google.com/project)):
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo.
You can easily do this with `gradio.mount_gradio_app()`.
```python
import os
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Depends, Request
from starlette.config import Config
from starlette.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
import uvicorn
import gradio as gr
Here's a complete example:
app = FastAPI()
$code_custom_path
# Replace these with your own OAuth settings
GOOGLE_CLIENT_ID = "..."
GOOGLE_CLIENT_SECRET = "..."
SECRET_KEY = "..."
config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
oauth.register(
name='google',
server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
client_kwargs={'scope': 'openid email profile'},
)
SECRET_KEY = os.environ.get('SECRET_KEY') or "a_very_secret_key"
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
# Dependency to get the current user
def get_user(request: Request):
user = request.session.get('user')
if user:
return user['name']
return None
@app.get('/')
def public(user: dict = Depends(get_user)):
if user:
return RedirectResponse(url='/gradio')
else:
return RedirectResponse(url='/login-demo')
@app.route('/logout')
async def logout(request: Request):
request.session.pop('user', None)
return RedirectResponse(url='/')
@app.route('/login')
async def login(request: Request):
redirect_uri = request.url_for('auth')
return await oauth.google.authorize_redirect(request, redirect_uri)
@app.route('/auth')
async def auth(request: Request):
try:
access_token = await oauth.google.authorize_access_token(request)
except OAuthError:
return RedirectResponse(url='/')
request.session['user'] = dict(access_token)["userinfo"]
return RedirectResponse(url='/')
with gr.Blocks() as login_demo:
gr.Button("Login", link="/login")
app = gr.mount_gradio_app(app, login_demo, path="/login-demo")
def greet(request: gr.Request):
return f"Welcome to Gradio, {request.username}"
with gr.Blocks() as main_demo:
m = gr.Markdown("Welcome to Gradio!")
gr.Button("Logout", link="/logout")
main_demo.load(greet, None, m)
app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user)
if __name__ == '__main__':
uvicorn.run(app)
```
There are actually two separate Gradio apps in this example! One that simply displays a log in button (this demo is accessible to any user), while the other main demo is only accessible to users that are logged in. You can try this example out on [this Space](https://huggingface.co/spaces/gradio/oauth-example).
Note that this approach also allows you run your Gradio apps on custom paths (`http://localhost:8000/gradio` in the example above).
## Security and File Access

View File

@ -1,13 +0,0 @@
import { test, expect } from "@gradio/tootils";
test("test that the submit and clear buttons in a loaded space work", async ({
page
}) => {
await page.getByLabel("x").click();
await page.getByLabel("Pakistan", { exact: true }).click();
await page.getByRole("button", { name: "Submit" }).click();
await expect(await page.getByLabel("Output")).toHaveValue("Pakistan");
await page.getByRole("button", { name: "Clear" }).click();
await expect(await page.getByLabel("Output")).toHaveValue("");
});

View File

@ -26,7 +26,12 @@ def io_components():
while classes_to_check:
subclass = classes_to_check.pop()
if subclass in [gr.components.FormComponent, gr.State]:
if subclass in [
gr.components.FormComponent,
gr.State,
gr.LoginButton,
gr.LogoutButton,
]:
continue
children = subclass.__subclasses__()

View File

@ -498,7 +498,6 @@ class TestComponentsInBlocks:
assert all(dep["queue"] is False for dep in demo.config["dependencies"])
def test_io_components_attach_load_events_when_value_is_fn(self, io_components):
io_components = [comp for comp in io_components if comp not in [gr.State]]
interface = gr.Interface(
lambda *args: None,
inputs=[comp(value=lambda: None, every=1) for comp in io_components],

View File

@ -29,11 +29,13 @@ class TestClearButton:
class TestOAuthButtons:
@pytest.mark.flaky
def test_login_button_warns_when_not_on_spaces(self):
with pytest.warns(UserWarning):
with gr.Blocks():
gr.LoginButton()
@pytest.mark.flaky
def test_logout_button_warns_when_not_on_spaces(self):
with pytest.warns(UserWarning):
with gr.Blocks():

View File

@ -4,6 +4,7 @@ import transformers
import gradio as gr
@pytest.mark.flaky
def test_text_to_text_model_from_pipeline():
pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
io = gr.Interface.from_pipeline(pipe)

View File

@ -390,6 +390,20 @@ class TestRoutes:
assert client.get("/ps").is_success
assert client.get("/py").is_success
def test_mount_gradio_app_with_auth_dependency(self):
app = FastAPI()
def get_user(request: Request):
return request.headers.get("user")
demo = gr.Interface(lambda s: f"Hello from ps, {s}!", "textbox", "textbox")
app = gr.mount_gradio_app(app, demo, path="/demo", auth_dependency=get_user)
with TestClient(app) as client:
assert client.get("/demo", headers={"user": "abubakar"}).is_success
assert not client.get("/demo").is_success
def test_static_file_missing(self, test_client):
response = test_client.get(r"/static/not-here.js")
assert response.status_code == 404

View File

@ -297,6 +297,7 @@ class TestThemeUploadDownload:
# assert demo.theme.to_dict() == dracula.to_dict()
# assert demo.theme.name == "gradio/dracula_test"
@pytest.mark.flaky
def test_theme_download_raises_error_if_theme_does_not_exist(self):
with pytest.raises(
ValueError, match="The space freddyaboulton/nonexistent does not exist"