Redirect with query params after oauth (#7034)

* Redirect with query params after oauth

* add changeset

* fix redirect

* align online and offline

* location

* Update gradio/oauth.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/oauth.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Lucain 2024-01-23 17:28:24 +01:00 committed by GitHub
parent 2cdcf4a890
commit 82fe73d042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 48 additions and 12 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Redirect with query params after oauth

View File

@ -104,13 +104,15 @@ class LoginButton(Button):
_js_handle_redirect = """
(buttonValue) => {
if (buttonValue === BUTTON_DEFAULT_VALUE) {
url = '/login/huggingface' + window.location.search;
if ( window !== window.parent ) {
window.open('/login/huggingface', '_blank');
window.open(url, '_blank');
} else {
window.location.assign('/login/huggingface');
window.location.assign(url);
}
} else {
window.location.assign('/logout');
url = '/logout' + window.location.search
window.location.assign(url);
}
}
"""

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import hashlib
import os
import typing
import urllib.parse
import warnings
from dataclasses import dataclass, field
@ -87,10 +88,8 @@ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
@app.get("/login/huggingface")
async def oauth_login(request: fastapi.Request):
"""Endpoint that redirects to HF OAuth page."""
redirect_uri = str(request.url_for("oauth_redirect_callback"))
if ".hf.space" in redirect_uri:
# In Space, FastAPI redirect as http but we want https
redirect_uri = redirect_uri.replace("http://", "https://")
# Define target (where to redirect after login)
redirect_uri = _generate_redirect_uri(request)
return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore
@app.get("/login/callback")
@ -98,13 +97,13 @@ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
"""Endpoint that handles the OAuth callback."""
oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
request.session["oauth_info"] = oauth_info
return RedirectResponse("/")
return _redirect_to_target(request)
@app.get("/logout")
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that logs out the user (e.g. delete cookie session)."""
request.session.pop("oauth_info", None)
return RedirectResponse("/")
return _redirect_to_target(request)
def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
@ -124,19 +123,49 @@ def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
@app.get("/login/huggingface")
async def oauth_login(request: fastapi.Request):
"""Fake endpoint that redirects to HF OAuth page."""
return RedirectResponse("/login/callback")
# Define target (where to redirect after login)
redirect_uri = _generate_redirect_uri(request)
return RedirectResponse(
"/login/callback?" + urllib.parse.urlencode({"_target_url": redirect_uri})
)
@app.get("/login/callback")
async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that handles the OAuth callback."""
request.session["oauth_info"] = mocked_oauth_info
return RedirectResponse("/")
return _redirect_to_target(request)
@app.get("/logout")
async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
"""Endpoint that logs out the user (e.g. delete cookie session)."""
request.session.pop("oauth_info", None)
return RedirectResponse("/")
logout_url = str(request.url).replace("/logout", "/") # preserve query params
return RedirectResponse(url=logout_url)
def _generate_redirect_uri(request: fastapi.Request) -> str:
if "_target_url" in request.query_params:
# if `_target_url` already in query params => respect it
target = request.query_params["_target_url"]
else:
# otherwise => keep query params
target = "/?" + urllib.parse.urlencode(request.query_params)
redirect_uri = request.url_for("oauth_redirect_callback").include_query_params(
_target_url=target
)
redirect_uri_as_str = str(redirect_uri)
if redirect_uri.netloc.endswith(".hf.space"):
# In Space, FastAPI redirect as http but we want https
redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://")
return redirect_uri_as_str
def _redirect_to_target(
request: fastapi.Request, default_target: str = "/"
) -> RedirectResponse:
target = request.query_params.get("_target_url", default_target)
return RedirectResponse(target)
@dataclass