mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
working on auth
This commit is contained in:
parent
d9e3afc263
commit
683c1aa024
@ -5,12 +5,12 @@ from fastapi import FastAPI, Form, Request, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
import inspect
|
||||
import os
|
||||
import posixpath
|
||||
import pkg_resources
|
||||
from secrets import compare_digest
|
||||
import secrets
|
||||
import traceback
|
||||
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
import urllib
|
||||
@ -35,8 +35,34 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
secure = HTTPBasic()
|
||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
|
||||
|
||||
|
||||
########
|
||||
# Auth
|
||||
########
|
||||
|
||||
def get_username_from_token(token: str = Depends(oauth2_scheme)):
|
||||
print('token2', token)
|
||||
if token in app.tokens:
|
||||
return app.tokens[token]
|
||||
|
||||
def is_authenticated(token: str = Depends(oauth2_scheme)):
|
||||
print('token', token)
|
||||
return get_username_from_token(token) is not None
|
||||
|
||||
@app.post('/login')
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username, password = form_data.username, form_data.password
|
||||
if ((not callable(app.auth) and username in app.auth
|
||||
and app.auth[username] == password)
|
||||
or (callable(app.auth) and app.auth.__call__(username, password))):
|
||||
token = secrets.token_urlsafe(16)
|
||||
app.tokens[token] = username
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Incorrect credentials.")
|
||||
|
||||
|
||||
###############
|
||||
@ -46,12 +72,19 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
|
||||
@app.head('/', response_class=HTMLResponse)
|
||||
@app.get('/', response_class=HTMLResponse)
|
||||
# @login_check # TODO
|
||||
def main(request: Request):
|
||||
# session["state"] = None # TODO
|
||||
def main(request: Request, is_authenticated=Depends(is_authenticated)):
|
||||
print('app.tokens', app.tokens)
|
||||
print('is_authenticated', is_authenticated)
|
||||
print('token>>>', request.headers.get("Authorization", "missing"))
|
||||
if app.auth is None or is_authenticated:
|
||||
config = app.interface.config
|
||||
else:
|
||||
config = {"auth_required": True,
|
||||
"auth_message": app.interface.auth_message}
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"frontend/index.html",
|
||||
{"request": request, "config": app.interface.config}
|
||||
{"request": request, "config": config}
|
||||
)
|
||||
|
||||
|
||||
@ -66,7 +99,6 @@ def static_resource(path: str):
|
||||
raise HTTPException(status_code=404, detail="Static file not found")
|
||||
|
||||
|
||||
@app.get("/config/")
|
||||
def get_config():
|
||||
# if app.interface.auth is None or current_user.is_authenticated:
|
||||
return app.interface.config
|
||||
@ -172,6 +204,7 @@ async def interpret(request: Request):
|
||||
}
|
||||
|
||||
|
||||
|
||||
# @app.route("/shutdown", methods=['GET'])
|
||||
# def shutdown():
|
||||
# shutdown_func = request.environ.get('werkzeug.server.shutdown')
|
||||
@ -198,8 +231,7 @@ async def interpret(request: Request):
|
||||
# return {"status": status, "data": data}
|
||||
|
||||
|
||||
|
||||
|
||||
# def get_current_user(token: str = Depends(oauth2_scheme))
|
||||
|
||||
########
|
||||
# Helper functions
|
||||
@ -262,4 +294,9 @@ if __name__ == '__main__': # Run directly for debugging: python app.py
|
||||
app.interface.config = app.interface.get_config_file()
|
||||
app.interface.show_error = True
|
||||
app.interface.flagging_callback.setup(app.interface.flagging_dir)
|
||||
# app.auth = None
|
||||
app.interface.auth = ("a", "b")
|
||||
app.auth = {"a": "b"}
|
||||
app.interface.auth_message = None
|
||||
app.tokens = {}
|
||||
uvicorn.run(app)
|
||||
|
209
gradio/auth.py
209
gradio/auth.py
@ -1,209 +0,0 @@
|
||||
"""Authentication via cookies for Fast API
|
||||
Credit: https://medium.com/data-rebels/fastapi-how-to-add-basic-and-cookie-authentication-a45c85ef47d3
|
||||
"""
|
||||
from typing import Optional
|
||||
import base64
|
||||
from passlib.context import CryptContext
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWTError
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.security import OAuth2PasswordRequestForm, OAuth2
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
||||
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
from starlette.requests import Request
|
||||
|
||||
import uvicorn
|
||||
|
||||
SECRET_KEY = "56c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e709d25e094faa6ca25"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str = None
|
||||
|
||||
|
||||
class OAuth2PasswordBearerCookie(OAuth2):
|
||||
def __init__(
|
||||
self,
|
||||
tokenUrl: str,
|
||||
scheme_name: str = None,
|
||||
scopes: dict = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
|
||||
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
header_authorization: str = request.headers.get("Authorization")
|
||||
cookie_authorization: str = request.cookies.get("Authorization")
|
||||
|
||||
header_scheme, header_param = get_authorization_scheme_param(
|
||||
header_authorization
|
||||
)
|
||||
cookie_scheme, cookie_param = get_authorization_scheme_param(
|
||||
cookie_authorization
|
||||
)
|
||||
|
||||
if header_scheme.lower() == "bearer":
|
||||
authorization = True
|
||||
scheme = header_scheme
|
||||
param = header_param
|
||||
|
||||
elif cookie_scheme.lower() == "bearer":
|
||||
authorization = True
|
||||
scheme = cookie_scheme
|
||||
param = cookie_param
|
||||
|
||||
else:
|
||||
authorization = False
|
||||
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return param
|
||||
|
||||
|
||||
class BasicAuth(SecurityBase):
|
||||
def __init__(self, scheme_name: str = None, auto_error: bool = True):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization: str = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "basic":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return param
|
||||
|
||||
|
||||
basic_auth = BasicAuth(auto_error=False)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token")
|
||||
|
||||
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
|
||||
|
||||
def get_user(username: str):
|
||||
if username == "johndoe4":
|
||||
return username
|
||||
|
||||
|
||||
def authenticate_user(username: str, password: str):
|
||||
if username == "johndoe4" and password == "secret":
|
||||
return username
|
||||
return False
|
||||
|
||||
|
||||
def create_access_token(*, data: dict, expires_delta: timedelta = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=403, detail="Could not validate credentials"
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except PyJWTError:
|
||||
raise credentials_exception
|
||||
username = get_user(username=token_data.username)
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
return username
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def homepage():
|
||||
return "Welcome to the security test!"
|
||||
|
||||
|
||||
@app.post("/token", response_model=Token)
|
||||
async def route_login_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username = authenticate_user(form_data.username, form_data.password)
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Incorrect username or password")
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": username}, expires_delta=access_token_expires
|
||||
)
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@app.get("/login_basic")
|
||||
async def login_basic(auth: BasicAuth = Depends(basic_auth)):
|
||||
if not auth:
|
||||
response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)
|
||||
return response
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth).decode("ascii")
|
||||
username, _, password = decoded.partition(":")
|
||||
username = authenticate_user(username, password)
|
||||
if not username:
|
||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": username}, expires_delta=access_token_expires
|
||||
)
|
||||
|
||||
token = jsonable_encoder(access_token)
|
||||
|
||||
response = RedirectResponse(url="/docs")
|
||||
response.set_cookie(
|
||||
"Authorization",
|
||||
value=f"Bearer {token}",
|
||||
httponly=True,
|
||||
max_age=1800,
|
||||
expires=1800,
|
||||
)
|
||||
return response
|
||||
|
||||
except:
|
||||
response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/docs")
|
||||
async def get_documentation(current_user: str = Depends(get_current_user)):
|
||||
return {"hi": "you're in"}
|
||||
|
||||
|
||||
if __name__ == '__main__': # Run directly for debugging: python auth.py
|
||||
uvicorn.run(app)
|
@ -180,6 +180,7 @@ def start_server(
|
||||
# queueing.init()
|
||||
# app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
|
||||
# app.queue_thread.start()
|
||||
app.tokens = {}
|
||||
app_kwargs = {"app": app, "port": port, "host": server_name,
|
||||
"log_level": "warning"}
|
||||
thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user