mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
working on auth
This commit is contained in:
parent
d9e3afc263
commit
683c1aa024
@ -5,12 +5,12 @@ from fastapi import FastAPI, Form, Request, Depends, HTTPException, status
|
|||||||
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import posixpath
|
import posixpath
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from secrets import compare_digest
|
import secrets
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
|
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
|
||||||
import urllib
|
import urllib
|
||||||
@ -35,8 +35,34 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
secure = HTTPBasic()
|
|
||||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
########
|
||||||
|
# Auth
|
||||||
|
########
|
||||||
|
|
||||||
|
def get_username_from_token(token: str = Depends(oauth2_scheme)):
|
||||||
|
print('token2', token)
|
||||||
|
if token in app.tokens:
|
||||||
|
return app.tokens[token]
|
||||||
|
|
||||||
|
def is_authenticated(token: str = Depends(oauth2_scheme)):
|
||||||
|
print('token', token)
|
||||||
|
return get_username_from_token(token) is not None
|
||||||
|
|
||||||
|
@app.post('/login')
|
||||||
|
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||||
|
username, password = form_data.username, form_data.password
|
||||||
|
if ((not callable(app.auth) and username in app.auth
|
||||||
|
and app.auth[username] == password)
|
||||||
|
or (callable(app.auth) and app.auth.__call__(username, password))):
|
||||||
|
token = secrets.token_urlsafe(16)
|
||||||
|
app.tokens[token] = username
|
||||||
|
return {"access_token": token, "token_type": "bearer"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=400, detail="Incorrect credentials.")
|
||||||
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
@ -46,12 +72,19 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
|||||||
|
|
||||||
@app.head('/', response_class=HTMLResponse)
|
@app.head('/', response_class=HTMLResponse)
|
||||||
@app.get('/', response_class=HTMLResponse)
|
@app.get('/', response_class=HTMLResponse)
|
||||||
# @login_check # TODO
|
def main(request: Request, is_authenticated=Depends(is_authenticated)):
|
||||||
def main(request: Request):
|
print('app.tokens', app.tokens)
|
||||||
# session["state"] = None # TODO
|
print('is_authenticated', is_authenticated)
|
||||||
|
print('token>>>', request.headers.get("Authorization", "missing"))
|
||||||
|
if app.auth is None or is_authenticated:
|
||||||
|
config = app.interface.config
|
||||||
|
else:
|
||||||
|
config = {"auth_required": True,
|
||||||
|
"auth_message": app.interface.auth_message}
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"frontend/index.html",
|
"frontend/index.html",
|
||||||
{"request": request, "config": app.interface.config}
|
{"request": request, "config": config}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -66,7 +99,6 @@ def static_resource(path: str):
|
|||||||
raise HTTPException(status_code=404, detail="Static file not found")
|
raise HTTPException(status_code=404, detail="Static file not found")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/config/")
|
|
||||||
def get_config():
|
def get_config():
|
||||||
# if app.interface.auth is None or current_user.is_authenticated:
|
# if app.interface.auth is None or current_user.is_authenticated:
|
||||||
return app.interface.config
|
return app.interface.config
|
||||||
@ -172,6 +204,7 @@ async def interpret(request: Request):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# @app.route("/shutdown", methods=['GET'])
|
# @app.route("/shutdown", methods=['GET'])
|
||||||
# def shutdown():
|
# def shutdown():
|
||||||
# shutdown_func = request.environ.get('werkzeug.server.shutdown')
|
# shutdown_func = request.environ.get('werkzeug.server.shutdown')
|
||||||
@ -198,8 +231,7 @@ async def interpret(request: Request):
|
|||||||
# return {"status": status, "data": data}
|
# return {"status": status, "data": data}
|
||||||
|
|
||||||
|
|
||||||
|
# def get_current_user(token: str = Depends(oauth2_scheme))
|
||||||
|
|
||||||
|
|
||||||
########
|
########
|
||||||
# Helper functions
|
# Helper functions
|
||||||
@ -262,4 +294,9 @@ if __name__ == '__main__': # Run directly for debugging: python app.py
|
|||||||
app.interface.config = app.interface.get_config_file()
|
app.interface.config = app.interface.get_config_file()
|
||||||
app.interface.show_error = True
|
app.interface.show_error = True
|
||||||
app.interface.flagging_callback.setup(app.interface.flagging_dir)
|
app.interface.flagging_callback.setup(app.interface.flagging_dir)
|
||||||
|
# app.auth = None
|
||||||
|
app.interface.auth = ("a", "b")
|
||||||
|
app.auth = {"a": "b"}
|
||||||
|
app.interface.auth_message = None
|
||||||
|
app.tokens = {}
|
||||||
uvicorn.run(app)
|
uvicorn.run(app)
|
||||||
|
209
gradio/auth.py
209
gradio/auth.py
@ -1,209 +0,0 @@
|
|||||||
"""Authentication via cookies for Fast API
|
|
||||||
Credit: https://medium.com/data-rebels/fastapi-how-to-add-basic-and-cookie-authentication-a45c85ef47d3
|
|
||||||
"""
|
|
||||||
from typing import Optional
|
|
||||||
import base64
|
|
||||||
from passlib.context import CryptContext
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
from jwt import PyJWTError
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm, OAuth2
|
|
||||||
from fastapi.security.base import SecurityBase
|
|
||||||
from fastapi.security.utils import get_authorization_scheme_param
|
|
||||||
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
|
||||||
|
|
||||||
from starlette.responses import RedirectResponse, Response
|
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
SECRET_KEY = "56c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e709d25e094faa6ca25"
|
|
||||||
ALGORITHM = "HS256"
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
|
||||||
username: str = None
|
|
||||||
|
|
||||||
|
|
||||||
class OAuth2PasswordBearerCookie(OAuth2):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenUrl: str,
|
|
||||||
scheme_name: str = None,
|
|
||||||
scopes: dict = None,
|
|
||||||
auto_error: bool = True,
|
|
||||||
):
|
|
||||||
if not scopes:
|
|
||||||
scopes = {}
|
|
||||||
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
|
|
||||||
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
|
|
||||||
|
|
||||||
async def __call__(self, request: Request) -> Optional[str]:
|
|
||||||
header_authorization: str = request.headers.get("Authorization")
|
|
||||||
cookie_authorization: str = request.cookies.get("Authorization")
|
|
||||||
|
|
||||||
header_scheme, header_param = get_authorization_scheme_param(
|
|
||||||
header_authorization
|
|
||||||
)
|
|
||||||
cookie_scheme, cookie_param = get_authorization_scheme_param(
|
|
||||||
cookie_authorization
|
|
||||||
)
|
|
||||||
|
|
||||||
if header_scheme.lower() == "bearer":
|
|
||||||
authorization = True
|
|
||||||
scheme = header_scheme
|
|
||||||
param = header_param
|
|
||||||
|
|
||||||
elif cookie_scheme.lower() == "bearer":
|
|
||||||
authorization = True
|
|
||||||
scheme = cookie_scheme
|
|
||||||
param = cookie_param
|
|
||||||
|
|
||||||
else:
|
|
||||||
authorization = False
|
|
||||||
|
|
||||||
if not authorization or scheme.lower() != "bearer":
|
|
||||||
if self.auto_error:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="Not authenticated"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
return param
|
|
||||||
|
|
||||||
|
|
||||||
class BasicAuth(SecurityBase):
|
|
||||||
def __init__(self, scheme_name: str = None, auto_error: bool = True):
|
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
|
||||||
self.auto_error = auto_error
|
|
||||||
|
|
||||||
async def __call__(self, request: Request) -> Optional[str]:
|
|
||||||
authorization: str = request.headers.get("Authorization")
|
|
||||||
scheme, param = get_authorization_scheme_param(authorization)
|
|
||||||
if not authorization or scheme.lower() != "basic":
|
|
||||||
if self.auto_error:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403, detail="Not authenticated"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
return param
|
|
||||||
|
|
||||||
|
|
||||||
basic_auth = BasicAuth(auto_error=False)
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token")
|
|
||||||
|
|
||||||
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
|
|
||||||
|
|
||||||
def get_user(username: str):
|
|
||||||
if username == "johndoe4":
|
|
||||||
return username
|
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(username: str, password: str):
|
|
||||||
if username == "johndoe4" and password == "secret":
|
|
||||||
return username
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(*, data: dict, expires_delta: timedelta = None):
|
|
||||||
to_encode = data.copy()
|
|
||||||
if expires_delta:
|
|
||||||
expire = datetime.utcnow() + expires_delta
|
|
||||||
else:
|
|
||||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
|
||||||
to_encode.update({"exp": expire})
|
|
||||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
||||||
return encoded_jwt
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
|
||||||
credentials_exception = HTTPException(
|
|
||||||
status_code=403, detail="Could not validate credentials"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
||||||
username: str = payload.get("sub")
|
|
||||||
if username is None:
|
|
||||||
raise credentials_exception
|
|
||||||
token_data = TokenData(username=username)
|
|
||||||
except PyJWTError:
|
|
||||||
raise credentials_exception
|
|
||||||
username = get_user(username=token_data.username)
|
|
||||||
if username is None:
|
|
||||||
raise credentials_exception
|
|
||||||
return username
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def homepage():
|
|
||||||
return "Welcome to the security test!"
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/token", response_model=Token)
|
|
||||||
async def route_login_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
|
|
||||||
username = authenticate_user(form_data.username, form_data.password)
|
|
||||||
if not username:
|
|
||||||
raise HTTPException(status_code=400, detail="Incorrect username or password")
|
|
||||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
access_token = create_access_token(
|
|
||||||
data={"sub": username}, expires_delta=access_token_expires
|
|
||||||
)
|
|
||||||
return {"access_token": access_token, "token_type": "bearer"}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/login_basic")
|
|
||||||
async def login_basic(auth: BasicAuth = Depends(basic_auth)):
|
|
||||||
if not auth:
|
|
||||||
response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)
|
|
||||||
return response
|
|
||||||
|
|
||||||
try:
|
|
||||||
decoded = base64.b64decode(auth).decode("ascii")
|
|
||||||
username, _, password = decoded.partition(":")
|
|
||||||
username = authenticate_user(username, password)
|
|
||||||
if not username:
|
|
||||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
|
||||||
|
|
||||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
access_token = create_access_token(
|
|
||||||
data={"sub": username}, expires_delta=access_token_expires
|
|
||||||
)
|
|
||||||
|
|
||||||
token = jsonable_encoder(access_token)
|
|
||||||
|
|
||||||
response = RedirectResponse(url="/docs")
|
|
||||||
response.set_cookie(
|
|
||||||
"Authorization",
|
|
||||||
value=f"Bearer {token}",
|
|
||||||
httponly=True,
|
|
||||||
max_age=1800,
|
|
||||||
expires=1800,
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except:
|
|
||||||
response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/docs")
|
|
||||||
async def get_documentation(current_user: str = Depends(get_current_user)):
|
|
||||||
return {"hi": "you're in"}
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__': # Run directly for debugging: python auth.py
|
|
||||||
uvicorn.run(app)
|
|
@ -180,6 +180,7 @@ def start_server(
|
|||||||
# queueing.init()
|
# queueing.init()
|
||||||
# app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
|
# app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
|
||||||
# app.queue_thread.start()
|
# app.queue_thread.start()
|
||||||
|
app.tokens = {}
|
||||||
app_kwargs = {"app": app, "port": port, "host": server_name,
|
app_kwargs = {"app": app, "port": port, "host": server_name,
|
||||||
"log_level": "warning"}
|
"log_level": "warning"}
|
||||||
thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs)
|
thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user