mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Sign in with Hugging Face (OAuth support) (#4943)
* first draft * debug * add print * working oauth * inject OAuth profile + enable OAuth when expected + some doc * add changeset * mypy * opt * open in a new tab only from iframe * msg * add changeset * Apply suggestions from code review Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * fix injection + gr.Error * allow third party cookie when possible * add button to sign in/sign out button * feedback changes * oauth as optional dependency * disable login/logout buttons locally * nothing * a bit of documentation * Add tests for Login/Logout buttons * Apply suggestions from code review Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * mention required dependencies * fix package * fix tests * fix windows tests as well * Fake profile on local debug * doc * fix tets * lint * fix test * test buttons * login button fix * lint * fix final tests --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co> Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
This commit is contained in:
parent
987725cf6a
commit
947d615db6
5
.changeset/hot-worms-type.md
Normal file
5
.changeset/hot-worms-type.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Sign in with Hugging Face (OAuth support)
|
5
.github/workflows/backend.yml
vendored
5
.github/workflows/backend.yml
vendored
@ -128,6 +128,7 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
client/python/requirements.txt
|
||||
requirements.txt
|
||||
requirements-oauth.txt
|
||||
test/requirements.txt
|
||||
- name: Create env
|
||||
run: |
|
||||
@ -138,7 +139,7 @@ jobs:
|
||||
with:
|
||||
path: |
|
||||
venv/*
|
||||
key: gradio-lib-${{ runner.os }}-pip-${{ hashFiles('client/python/requirements.txt') }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('test/requirements.txt') }}
|
||||
key: gradio-lib-${{ runner.os }}-pip-${{ hashFiles('client/python/requirements.txt') }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-oauth.txt') }}-${{ hashFiles('test/requirements.txt') }}
|
||||
- uses: actions/cache@v3
|
||||
id: frontend-cache
|
||||
with:
|
||||
@ -198,7 +199,7 @@ jobs:
|
||||
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Windows'
|
||||
run: |
|
||||
venv\Scripts\activate
|
||||
python -m pip install -e . -r test/requirements.txt
|
||||
python -m pip install -e . -r test/requirements.txt -r requirements-oauth.txt
|
||||
- name: Run tests (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
run: |
|
||||
|
@ -39,6 +39,8 @@ from gradio.components import (
|
||||
Json,
|
||||
Label,
|
||||
LinePlot,
|
||||
LoginButton,
|
||||
LogoutButton,
|
||||
Markdown,
|
||||
Model3D,
|
||||
Number,
|
||||
@ -82,6 +84,7 @@ from gradio.interface import Interface, TabbedInterface, close_all
|
||||
from gradio.ipython_ext import load_ipython_extension
|
||||
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
|
||||
from gradio.mix import Parallel, Series
|
||||
from gradio.oauth import OAuthProfile
|
||||
from gradio.routes import Request, mount_gradio_app
|
||||
from gradio.templates import (
|
||||
Files,
|
||||
|
@ -918,6 +918,14 @@ class Blocks(BlockContext):
|
||||
repr += f"\n |-{block}"
|
||||
return repr
|
||||
|
||||
@property
|
||||
def expects_oauth(self):
|
||||
"""Return whether the app expects user to authenticate via OAuth."""
|
||||
return any(
|
||||
isinstance(block, (components.LoginButton, components.LogoutButton))
|
||||
for block in self.blocks.values()
|
||||
)
|
||||
|
||||
def render(self):
|
||||
if Context.root_block is not None:
|
||||
if self._id in Context.root_block.blocks:
|
||||
|
@ -33,6 +33,8 @@ from gradio.components.interpretation import Interpretation
|
||||
from gradio.components.json_component import JSON
|
||||
from gradio.components.label import Label
|
||||
from gradio.components.line_plot import LinePlot
|
||||
from gradio.components.login_button import LoginButton
|
||||
from gradio.components.logout_button import LogoutButton
|
||||
from gradio.components.markdown import Markdown
|
||||
from gradio.components.model3d import Model3D
|
||||
from gradio.components.number import Number
|
||||
@ -87,6 +89,8 @@ __all__ = [
|
||||
"Json",
|
||||
"Label",
|
||||
"LinePlot",
|
||||
"LoginButton",
|
||||
"LogoutButton",
|
||||
"Markdown",
|
||||
"Textbox",
|
||||
"Dropdown",
|
||||
|
95
gradio/components/login_button.py
Normal file
95
gradio/components/login_button.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""Predefined button to sign in with Hugging Face in a Gradio Space."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Literal
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio.components import Button
|
||||
from gradio.context import Context
|
||||
from gradio.routes import Request
|
||||
|
||||
set_documentation_group("component")
|
||||
|
||||
|
||||
@document()
|
||||
class LoginButton(Button):
|
||||
"""
|
||||
Button that redirects the user to Sign with Hugging Face using OAuth.
|
||||
"""
|
||||
|
||||
is_template = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
value: str = "Sign in with Hugging Face",
|
||||
variant: Literal["primary", "secondary", "stop"] = "secondary",
|
||||
size: Literal["sm", "lg"] | None = None,
|
||||
icon: str
|
||||
| None = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
|
||||
link: str | None = None,
|
||||
visible: bool = True,
|
||||
interactive: bool = True,
|
||||
elem_id: str | None = None,
|
||||
elem_classes: list[str] | str | None = None,
|
||||
scale: int | None = 0,
|
||||
min_width: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
value,
|
||||
variant=variant,
|
||||
size=size,
|
||||
icon=icon,
|
||||
link=link,
|
||||
visible=visible,
|
||||
interactive=interactive,
|
||||
elem_id=elem_id,
|
||||
elem_classes=elem_classes,
|
||||
scale=scale,
|
||||
min_width=min_width,
|
||||
**kwargs,
|
||||
)
|
||||
if Context.root_block is not None:
|
||||
self.activate()
|
||||
else:
|
||||
warnings.warn(
|
||||
"LoginButton created outside of a Blocks context. May not work unless you call its `activate()` method manually."
|
||||
)
|
||||
|
||||
def activate(self):
|
||||
# Taken from https://cmgdo.com/external-link-in-gradio-button/
|
||||
# Taking `self` as input to check if user is logged in
|
||||
# ('self' value will be either "Sign in with Hugging Face" or "Signed in as ...")
|
||||
self.click(fn=None, inputs=[self], outputs=None, _js=_js_open_if_not_logged_in)
|
||||
|
||||
self.attach_load_event(self._check_login_status, None)
|
||||
|
||||
def _check_login_status(self, request: Request) -> dict[str, Any]:
|
||||
# Each time the page is refreshed or loaded, check if the user is logged in and adapt label
|
||||
session = getattr(request, "session", None) or getattr(
|
||||
request.request, "session", None
|
||||
)
|
||||
if session is None or "oauth_profile" not in session:
|
||||
return self.update("Sign in with Hugging Face", interactive=True)
|
||||
else:
|
||||
username = session["oauth_profile"]["preferred_username"]
|
||||
return self.update(f"Signed in as {username}", interactive=False)
|
||||
|
||||
|
||||
# JS code to redirects to /login/huggingface if user is not logged in.
|
||||
# If the app is opened in an iframe, open the login page in a new tab.
|
||||
# Otherwise, redirects locally. Taken from https://stackoverflow.com/a/61596084.
|
||||
_js_open_if_not_logged_in = """
|
||||
(buttonValue) => {
|
||||
if (!buttonValue.includes("Signed in")) {
|
||||
if ( window !== window.parent ) {
|
||||
window.open('/login/huggingface', '_blank');
|
||||
} else {
|
||||
window.location.assign('/login/huggingface');
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
52
gradio/components/logout_button.py
Normal file
52
gradio/components/logout_button.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Predefined button to sign out from Hugging Face in a Gradio Space."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio.components import Button
|
||||
|
||||
set_documentation_group("component")
|
||||
|
||||
|
||||
@document()
|
||||
class LogoutButton(Button):
|
||||
"""
|
||||
Button to log out a user from a Space.
|
||||
"""
|
||||
|
||||
is_template = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
value: str = "Logout",
|
||||
variant: Literal["primary", "secondary", "stop"] = "secondary",
|
||||
size: Literal["sm", "lg"] | None = None,
|
||||
icon: str
|
||||
| None = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
|
||||
# Link to logout page (which will delete the session cookie and redirect to landing page).
|
||||
link: str | None = "/logout",
|
||||
visible: bool = True,
|
||||
interactive: bool = True,
|
||||
elem_id: str | None = None,
|
||||
elem_classes: list[str] | str | None = None,
|
||||
scale: int | None = 0,
|
||||
min_width: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
value,
|
||||
variant=variant,
|
||||
size=size,
|
||||
icon=icon,
|
||||
link=link,
|
||||
visible=visible,
|
||||
interactive=interactive,
|
||||
elem_id=elem_id,
|
||||
elem_classes=elem_classes,
|
||||
scale=scale,
|
||||
min_width=min_width,
|
||||
**kwargs,
|
||||
)
|
@ -13,7 +13,7 @@ import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
@ -23,8 +23,9 @@ from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
from matplotlib import animation
|
||||
|
||||
from gradio import components, processing_utils, routes, utils
|
||||
from gradio import components, oauth, processing_utils, routes, utils
|
||||
from gradio.context import Context
|
||||
from gradio.exceptions import Error
|
||||
from gradio.flagging import CSVLogger
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
@ -690,6 +691,30 @@ def special_args(
|
||||
elif type_hint == routes.Request:
|
||||
if inputs is not None:
|
||||
inputs.insert(i, request)
|
||||
elif (
|
||||
type_hint == Optional[oauth.OAuthProfile]
|
||||
or type_hint == oauth.OAuthProfile
|
||||
# Note: "OAuthProfile | None" is equals to Optional[OAuthProfile] in Python
|
||||
# => it is automatically handled as well by the above condition
|
||||
# (adding explicit "OAuthProfile | None" would break in Python3.9)
|
||||
):
|
||||
if inputs is not None:
|
||||
# Retrieve session from gr.Request, if it exists (i.e. if user is logged in)
|
||||
session = (
|
||||
# request.session (if fastapi.Request obj i.e. direct call)
|
||||
getattr(request, "session", {})
|
||||
or
|
||||
# or request.request.session (if gr.Request obj i.e. websocket call)
|
||||
getattr(getattr(request, "request", None), "session", {})
|
||||
)
|
||||
oauth_profile = (
|
||||
session["oauth_profile"] if "oauth_profile" in session else None
|
||||
)
|
||||
if type_hint == oauth.OAuthProfile and oauth_profile is None:
|
||||
raise Error(
|
||||
"This action requires a logged in user. Please sign in and retry."
|
||||
)
|
||||
inputs.insert(i, oauth_profile)
|
||||
elif (
|
||||
type_hint
|
||||
and inspect.isclass(type_hint)
|
||||
|
186
gradio/oauth.py
Normal file
186
gradio/oauth.py
Normal file
@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
import fastapi
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from .utils import get_space
|
||||
|
||||
OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
|
||||
OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
|
||||
OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
|
||||
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
|
||||
|
||||
|
||||
def attach_oauth(app: fastapi.FastAPI):
|
||||
try:
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
|
||||
"`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
|
||||
) from e
|
||||
|
||||
# Add `/login/huggingface`, `/login/callback` and `/logout` routes to enable OAuth in the Gradio app.
|
||||
# If the app is running in a Space, OAuth is enabled normally. Otherwise, we mock the "real" routes to make the
|
||||
# user log in with a fake user profile - without any calls to hf.co.
|
||||
if get_space() is not None:
|
||||
_add_oauth_routes(app)
|
||||
else:
|
||||
_add_mocked_oauth_routes(app)
|
||||
|
||||
# Session Middleware requires a secret key to sign the cookies. Let's use a hash
|
||||
# of the OAuth secret key to make it unique to the Space + updated in case OAuth
|
||||
# config gets updated.
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=hashlib.sha256((OAUTH_CLIENT_SECRET or "").encode()).hexdigest(),
|
||||
same_site="none",
|
||||
https_only=True,
|
||||
)
|
||||
|
||||
|
||||
def _add_oauth_routes(app: fastapi.FastAPI) -> None:
|
||||
"""Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
|
||||
try:
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot initialize OAuth to due a missing library. Please run `pip install gradio[oauth]` or add "
|
||||
"`gradio[oauth]` to your requirements.txt file in order to install the required dependencies."
|
||||
) from e
|
||||
|
||||
# Check environment variables
|
||||
msg = (
|
||||
"OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
|
||||
" setting `hf_oauth: true` in the Space metadata."
|
||||
)
|
||||
if OAUTH_CLIENT_ID is None:
|
||||
raise ValueError(msg.format("OAUTH_CLIENT_ID"))
|
||||
if OAUTH_CLIENT_SECRET is None:
|
||||
raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
|
||||
if OAUTH_SCOPES is None:
|
||||
raise ValueError(msg.format("OAUTH_SCOPES"))
|
||||
if OPENID_PROVIDER_URL is None:
|
||||
raise ValueError(msg.format("OPENID_PROVIDER_URL"))
|
||||
|
||||
# Register OAuth server
|
||||
oauth = OAuth()
|
||||
oauth.register(
|
||||
name="huggingface",
|
||||
client_id=OAUTH_CLIENT_ID,
|
||||
client_secret=OAUTH_CLIENT_SECRET,
|
||||
client_kwargs={"scope": OAUTH_SCOPES},
|
||||
server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
|
||||
)
|
||||
|
||||
# Define OAuth routes
|
||||
@app.get("/login/huggingface")
|
||||
async def oauth_login(request: fastapi.Request):
|
||||
"""Endpoint that redirects to HF OAuth page."""
|
||||
redirect_uri = str(request.url_for("oauth_redirect_callback"))
|
||||
if ".hf.space" in redirect_uri:
|
||||
# In Space, FastAPI redirect as http but we want https
|
||||
redirect_uri = redirect_uri.replace("http://", "https://")
|
||||
return await oauth.huggingface.authorize_redirect(request, redirect_uri)
|
||||
|
||||
@app.get("/login/callback")
|
||||
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
|
||||
"""Endpoint that handles the OAuth callback."""
|
||||
token = await oauth.huggingface.authorize_access_token(request)
|
||||
request.session["oauth_profile"] = token["userinfo"]
|
||||
request.session["oauth_token"] = token
|
||||
return RedirectResponse("/")
|
||||
|
||||
@app.get("/logout")
|
||||
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
|
||||
"""Endpoint that logs out the user (e.g. delete cookie session)."""
|
||||
request.session.pop("oauth_profile", None)
|
||||
request.session.pop("oauth_token", None)
|
||||
return RedirectResponse("/")
|
||||
|
||||
|
||||
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
|
||||
"""Add fake oauth routes if Gradio is run locally and OAuth is enabled.
|
||||
|
||||
Clicking on a gr.LoginButton will have the same behavior as in a Space (i.e. gets redirected in a new tab) but
|
||||
instead of authenticating with HF, a mocked user profile is added to the session.
|
||||
"""
|
||||
warnings.warn(
|
||||
"Gradio does not support OAuth features outside of a Space environment. "
|
||||
"To help you debug your app locally, the login and logout buttons are mocked with a fake user profile."
|
||||
)
|
||||
|
||||
# Define OAuth routes
|
||||
@app.get("/login/huggingface")
|
||||
async def oauth_login(request: fastapi.Request):
|
||||
"""Fake endpoint that redirects to HF OAuth page."""
|
||||
return RedirectResponse("/login/callback")
|
||||
|
||||
@app.get("/login/callback")
|
||||
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
|
||||
"""Endpoint that handles the OAuth callback."""
|
||||
request.session["oauth_profile"] = MOCKED_OAUTH_TOKEN["userinfo"]
|
||||
request.session["oauth_token"] = MOCKED_OAUTH_TOKEN
|
||||
return RedirectResponse("/")
|
||||
|
||||
@app.get("/logout")
|
||||
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
|
||||
"""Endpoint that logs out the user (e.g. delete cookie session)."""
|
||||
request.session.pop("oauth_profile", None)
|
||||
request.session.pop("oauth_token", None)
|
||||
return RedirectResponse("/")
|
||||
|
||||
|
||||
class OAuthProfile(typing.Dict):
|
||||
"""
|
||||
A Gradio OAuthProfile object that can be used to inject the profile of a user in a
|
||||
function. If a function expects `OAuthProfile` or `Optional[OAuthProfile]` as input,
|
||||
the value will be injected from the FastAPI session if the user is logged in. If the
|
||||
user is not logged in and the function expects `OAuthProfile`, an error will be
|
||||
raised.
|
||||
|
||||
Example:
|
||||
import gradio as gr
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def hello(profile: Optional[gr.OAuthProfile]) -> str:
|
||||
if profile is None:
|
||||
return "I don't know you."
|
||||
return f"Hello {profile.name}"
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.LoginButton()
|
||||
gr.LogoutButton()
|
||||
gr.Markdown().attach_load_event(hello, None)
|
||||
"""
|
||||
|
||||
|
||||
MOCKED_OAUTH_TOKEN = {
|
||||
"access_token": "hf_oauth_AAAAAAAAAAAAAAAAAAAAAAAAAA",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 3600,
|
||||
"id_token": "AAAAAAAAAAAAAAAAAAAAAAAAAA",
|
||||
"scope": "openid profile",
|
||||
"expires_at": 1691676444,
|
||||
"userinfo": {
|
||||
"sub": "11111111111111111111111",
|
||||
"name": "Fake Gradio User",
|
||||
"preferred_username": "FakeGradioUser",
|
||||
"profile": "https://huggingface.co/FakeGradioUser",
|
||||
"picture": "https://huggingface.co/front/assets/huggingface_logo-noborder.svg",
|
||||
"website": "",
|
||||
"aud": "00000000-0000-0000-0000-000000000000",
|
||||
"auth_time": 1691672844,
|
||||
"nonce": "aaaaaaaaaaaaaaaaaaa",
|
||||
"iat": 1691672844,
|
||||
"exp": 1691676444,
|
||||
"iss": "https://huggingface.co",
|
||||
},
|
||||
}
|
@ -338,13 +338,20 @@ class Queue:
|
||||
)
|
||||
|
||||
def get_request_params(self, websocket: fastapi.WebSocket) -> dict[str, Any]:
|
||||
return {
|
||||
params = {
|
||||
"url": str(websocket.url),
|
||||
"headers": dict(websocket.headers),
|
||||
"query_params": dict(websocket.query_params),
|
||||
"path_params": dict(websocket.path_params),
|
||||
"client": {"host": websocket.client.host, "port": websocket.client.port}, # type: ignore
|
||||
}
|
||||
try:
|
||||
params[
|
||||
"session"
|
||||
] = websocket.session # forward OAuth information if available
|
||||
except Exception:
|
||||
pass
|
||||
return params
|
||||
|
||||
async def call_prediction(self, events: list[Event], batch: bool):
|
||||
data = events[0].data
|
||||
|
@ -53,6 +53,7 @@ from gradio.context import Context
|
||||
from gradio.data_classes import PredictBody, ResetBody
|
||||
from gradio.exceptions import Error
|
||||
from gradio.helpers import EventData
|
||||
from gradio.oauth import attach_oauth
|
||||
from gradio.queueing import Estimation, Event
|
||||
from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
|
||||
|
||||
@ -243,6 +244,15 @@ class App(FastAPI):
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Incorrect credentials.")
|
||||
|
||||
###############
|
||||
# OAuth Routes
|
||||
###############
|
||||
|
||||
# Define OAuth routes if the app expects it (i.e. a LoginButton is defined).
|
||||
# It allows users to "Sign in with HuggingFace".
|
||||
if app.blocks is not None and app.blocks.expects_oauth:
|
||||
attach_oauth(app)
|
||||
|
||||
###############
|
||||
# Main Routes
|
||||
###############
|
||||
|
@ -26,6 +26,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -760,7 +761,7 @@ def get_cancel_function(
|
||||
def get_type_hints(fn):
|
||||
# Importing gradio with the canonical abbreviation. Used in typing._eval_type.
|
||||
import gradio as gr # noqa: F401
|
||||
from gradio import Request # noqa: F401
|
||||
from gradio import OAuthProfile, Request # noqa: F401
|
||||
|
||||
if inspect.isfunction(fn) or inspect.ismethod(fn):
|
||||
pass
|
||||
@ -780,6 +781,9 @@ def get_type_hints(fn):
|
||||
for name, param in sig.parameters.items():
|
||||
if param.annotation is inspect.Parameter.empty:
|
||||
continue
|
||||
if param.annotation == "gr.OAuthProfile | None":
|
||||
# Special case: we want to inject the OAuthProfile value even on Python 3.9
|
||||
type_hints[name] = Optional[OAuthProfile]
|
||||
if "|" in str(param.annotation):
|
||||
continue
|
||||
# To convert the string annotation to a class, we use the
|
||||
@ -797,15 +801,17 @@ def get_type_hints(fn):
|
||||
|
||||
def is_special_typed_parameter(name, parameter_types):
|
||||
from gradio.helpers import EventData
|
||||
from gradio.oauth import OAuthProfile
|
||||
from gradio.routes import Request
|
||||
|
||||
"""Checks if parameter has a type hint designating it as a gr.Request or gr.EventData"""
|
||||
"""Checks if parameter has a type hint designating it as a gr.Request, gr.EventData or gr.OAuthProfile."""
|
||||
hint = parameter_types.get(name)
|
||||
if not hint:
|
||||
return False
|
||||
is_request = hint == Request
|
||||
is_oauth_arg = hint in (OAuthProfile, Optional[OAuthProfile])
|
||||
is_event_data = inspect.isclass(hint) and issubclass(hint, EventData)
|
||||
return is_request or is_event_data
|
||||
return is_request or is_event_data or is_oauth_arg
|
||||
|
||||
|
||||
def check_function_inputs_match(fn: Callable, inputs: list, inputs_as_dict: bool):
|
||||
|
@ -147,6 +147,8 @@ This will add and document the endpoint `/api/addition/` to the automatically ge
|
||||
|
||||
## Authentication
|
||||
|
||||
### Password-protected app
|
||||
|
||||
You may wish to put an authentication page in front of your app to limit who can open your app. With the `auth=` keyword argument in the `launch()` method, you can provide a tuple with a username and password, or a list of acceptable username/password tuples; Here's an example that provides password-based authentication for a single user named "admin":
|
||||
|
||||
```python
|
||||
@ -166,6 +168,47 @@ demo.launch(auth=same_auth)
|
||||
For authentication to work properly, third party cookies must be enabled in your browser.
|
||||
This is not the case by default for Safari, Chrome Incognito Mode.
|
||||
|
||||
### OAuth (Login via Hugging Face)
|
||||
|
||||
Gradio supports OAuth login via Hugging Face. This feature is currently **experimental** and only available on Spaces.
|
||||
If allows to add a *"Sign in with Hugging Face"* button to your demo. Check out [this Space](https://huggingface.co/spaces/Wauplin/gradio-oauth-demo)
|
||||
for a live demo.
|
||||
|
||||
To enable OAuth, you must set `hf_oauth: true` as a Space metadata in your README.md file. This will register your Space
|
||||
as an OAuth application on Hugging Face. You also need to include `itsdangerous` and `authlib` in a separate
|
||||
`requirements.txt` file. Next, you can use `gr.LoginButton` and `gr.LogoutButton` to add login and logout buttons to
|
||||
your Gradio app. Once a user is logged in with their HF account, you can retrieve their profile. To do so, you only
|
||||
have to add a parameter of type `gr.OAuthProfile` to any Gradio function. The user profile will be automatically
|
||||
injected as a parameter value.
|
||||
|
||||
Here is a short example:
|
||||
|
||||
```py
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def hello(profile: gr.OAuthProfile | None) -> str:
|
||||
if profile is None:
|
||||
return "I don't know you."
|
||||
return f"Hello {profile.name}"
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.LoginButton()
|
||||
gr.LogoutButton()
|
||||
gr.Markdown().attach_load_event(hello, None)
|
||||
```
|
||||
|
||||
When the user clicks on the login button, they get redirected in a new page to authorize your Space.
|
||||
|
||||
![Allow Space app](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guides/oauth_sign_in.png)
|
||||
|
||||
Users can revoke access to their profile at any time in their [settings](https://huggingface.co/settings/connected-applications).
|
||||
|
||||
As seen above, OAuth features are available only when your app runs in a Space. However, you often need to test your app
|
||||
locally before deploying it. To help with that, the `gr.LoginButton` is mocked. When a user clicks on it, they are
|
||||
automatically logged in with a fake user profile. This allows you to debug your app before deploying it to a Space.
|
||||
|
||||
## Accessing the Network Request Directly
|
||||
|
||||
When a user makes a prediction to your app, you may need the underlying network request, in order to get the request headers (e.g. for advanced authentication), log the client's IP address, or for other reasons. Gradio supports this in a similar manner to FastAPI: simply add a function parameter whose type hint is `gr.Request` and Gradio will pass in the network request as that parameter. Here is an example:
|
||||
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "gradio"
|
||||
dynamic = ["version", "dependencies", "readme"]
|
||||
dynamic = ["version", "dependencies", "optional-dependencies", "readme"]
|
||||
description = "Python library for easily interacting with trained machine learning models"
|
||||
license = "Apache-2.0"
|
||||
requires-python = ">=3.8"
|
||||
@ -48,6 +48,9 @@ pattern = "(?P<version>.+)"
|
||||
[tool.hatch.metadata.hooks.requirements_txt]
|
||||
filename = "requirements.txt"
|
||||
|
||||
[tool.hatch.metadata.hooks.requirements_txt.optional-dependencies]
|
||||
oauth = ["requirements-oauth.txt"]
|
||||
|
||||
[tool.hatch.metadata.hooks.fancy-pypi-readme]
|
||||
content-type = "text/markdown"
|
||||
fragments = [
|
||||
|
2
requirements-oauth.txt
Normal file
2
requirements-oauth.txt
Normal file
@ -0,0 +1,2 @@
|
||||
authlib
|
||||
itsdangerous # required for starlette SessionMiddleware
|
@ -25,4 +25,4 @@ requests~=2.0
|
||||
semantic_version~=2.0
|
||||
typing_extensions~=4.0
|
||||
uvicorn>=0.14.0
|
||||
websockets>=10.0,<12.0
|
||||
websockets>=10.0,<12.0
|
@ -8,4 +8,5 @@ pip_required
|
||||
echo "Installing requirements before running tests..."
|
||||
pip install --upgrade pip
|
||||
pip install -r test/requirements.txt
|
||||
pip install -r requirements-oauth.txt
|
||||
pip install -e client/python
|
||||
|
@ -594,7 +594,7 @@ class TestComponentsInBlocks:
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
components.append(component(value=lambda: None, every=1))
|
||||
assert [comp.load_event for comp in components] == demo.dependencies
|
||||
assert all(comp.load_event in demo.dependencies for comp in components)
|
||||
|
||||
|
||||
class TestBlocksPostprocessing:
|
||||
|
@ -1,3 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
@ -13,3 +17,28 @@ class TestClearButton:
|
||||
assert not clear_event_trigger["backend_fn"]
|
||||
assert clear_event_trigger["js"]
|
||||
assert clear_event_trigger["outputs"] == [textbox._id, chatbot._id]
|
||||
|
||||
|
||||
class TestOAuthButtons:
|
||||
def test_login_button_warns_when_not_on_spaces(self):
|
||||
with pytest.warns(UserWarning):
|
||||
with gr.Blocks():
|
||||
gr.LoginButton()
|
||||
|
||||
def test_logout_button_warns_when_not_on_spaces(self):
|
||||
with pytest.warns(UserWarning):
|
||||
with gr.Blocks():
|
||||
gr.LogoutButton()
|
||||
|
||||
@patch("gradio.oauth.get_space", lambda: "fake_space")
|
||||
@patch("gradio.oauth._add_oauth_routes")
|
||||
def test_login_button_setup_correctly(self, mock_add_oauth_routes):
|
||||
with gr.Blocks() as demo:
|
||||
button = gr.LoginButton()
|
||||
|
||||
login_event = demo.dependencies[0]
|
||||
assert login_event["trigger"] == "click"
|
||||
assert not login_event["backend_fn"] # No Python code
|
||||
assert login_event["js"] # But JS code instead
|
||||
assert login_event["inputs"] == [button._id]
|
||||
assert login_event["outputs"] == []
|
||||
|
Loading…
Reference in New Issue
Block a user