working on auth

This commit is contained in:
Abubakar Abid 2021-12-29 23:57:50 -06:00
parent d9e3afc263
commit 683c1aa024
4 changed files with 49 additions and 220 deletions

View File

@ -5,12 +5,12 @@ from fastapi import FastAPI, Form, Request, Depends, HTTPException, status
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
import inspect import inspect
import os import os
import posixpath import posixpath
import pkg_resources import pkg_resources
from secrets import compare_digest import secrets
import traceback import traceback
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
import urllib import urllib
@ -35,8 +35,34 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
secure = HTTPBasic()
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
########
# Auth
########
def get_username_from_token(token: str = Depends(oauth2_scheme)):
print('token2', token)
if token in app.tokens:
return app.tokens[token]
def is_authenticated(token: str = Depends(oauth2_scheme)):
print('token', token)
return get_username_from_token(token) is not None
@app.post('/login')
def login(form_data: OAuth2PasswordRequestForm = Depends()):
username, password = form_data.username, form_data.password
if ((not callable(app.auth) and username in app.auth
and app.auth[username] == password)
or (callable(app.auth) and app.auth.__call__(username, password))):
token = secrets.token_urlsafe(16)
app.tokens[token] = username
return {"access_token": token, "token_type": "bearer"}
else:
raise HTTPException(status_code=400, detail="Incorrect credentials.")
############### ###############
@ -46,12 +72,19 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
@app.head('/', response_class=HTMLResponse) @app.head('/', response_class=HTMLResponse)
@app.get('/', response_class=HTMLResponse) @app.get('/', response_class=HTMLResponse)
# @login_check # TODO def main(request: Request, is_authenticated=Depends(is_authenticated)):
def main(request: Request): print('app.tokens', app.tokens)
# session["state"] = None # TODO print('is_authenticated', is_authenticated)
print('token>>>', request.headers.get("Authorization", "missing"))
if app.auth is None or is_authenticated:
config = app.interface.config
else:
config = {"auth_required": True,
"auth_message": app.interface.auth_message}
return templates.TemplateResponse( return templates.TemplateResponse(
"frontend/index.html", "frontend/index.html",
{"request": request, "config": app.interface.config} {"request": request, "config": config}
) )
@ -66,7 +99,6 @@ def static_resource(path: str):
raise HTTPException(status_code=404, detail="Static file not found") raise HTTPException(status_code=404, detail="Static file not found")
@app.get("/config/")
def get_config(): def get_config():
# if app.interface.auth is None or current_user.is_authenticated: # if app.interface.auth is None or current_user.is_authenticated:
return app.interface.config return app.interface.config
@ -172,6 +204,7 @@ async def interpret(request: Request):
} }
# @app.route("/shutdown", methods=['GET']) # @app.route("/shutdown", methods=['GET'])
# def shutdown(): # def shutdown():
# shutdown_func = request.environ.get('werkzeug.server.shutdown') # shutdown_func = request.environ.get('werkzeug.server.shutdown')
@ -198,8 +231,7 @@ async def interpret(request: Request):
# return {"status": status, "data": data} # return {"status": status, "data": data}
# def get_current_user(token: str = Depends(oauth2_scheme))
######## ########
# Helper functions # Helper functions
@ -262,4 +294,9 @@ if __name__ == '__main__': # Run directly for debugging: python app.py
app.interface.config = app.interface.get_config_file() app.interface.config = app.interface.get_config_file()
app.interface.show_error = True app.interface.show_error = True
app.interface.flagging_callback.setup(app.interface.flagging_dir) app.interface.flagging_callback.setup(app.interface.flagging_dir)
# app.auth = None
app.interface.auth = ("a", "b")
app.auth = {"a": "b"}
app.interface.auth_message = None
app.tokens = {}
uvicorn.run(app) uvicorn.run(app)

View File

@ -1,209 +0,0 @@
"""Authentication via cookies for Fast API
Credit: https://medium.com/data-rebels/fastapi-how-to-add-basic-and-cookie-authentication-a45c85ef47d3
"""
from typing import Optional
import base64
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
from jwt import PyJWTError
from pydantic import BaseModel
from fastapi import Depends, FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.security import OAuth2PasswordRequestForm, OAuth2
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from starlette.responses import RedirectResponse, Response
from starlette.requests import Request
import uvicorn
SECRET_KEY = "56c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e709d25e094faa6ca25"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str = None
class OAuth2PasswordBearerCookie(OAuth2):
def __init__(
self,
tokenUrl: str,
scheme_name: str = None,
scopes: dict = None,
auto_error: bool = True,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes})
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
async def __call__(self, request: Request) -> Optional[str]:
header_authorization: str = request.headers.get("Authorization")
cookie_authorization: str = request.cookies.get("Authorization")
header_scheme, header_param = get_authorization_scheme_param(
header_authorization
)
cookie_scheme, cookie_param = get_authorization_scheme_param(
cookie_authorization
)
if header_scheme.lower() == "bearer":
authorization = True
scheme = header_scheme
param = header_param
elif cookie_scheme.lower() == "bearer":
authorization = True
scheme = cookie_scheme
param = cookie_param
else:
authorization = False
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=403, detail="Not authenticated"
)
else:
return None
return param
class BasicAuth(SecurityBase):
def __init__(self, scheme_name: str = None, auto_error: bool = True):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic":
if self.auto_error:
raise HTTPException(
status_code=403, detail="Not authenticated"
)
else:
return None
return param
basic_auth = BasicAuth(auto_error=False)
oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token")
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
def get_user(username: str):
if username == "johndoe4":
return username
def authenticate_user(username: str, password: str):
if username == "johndoe4" and password == "secret":
return username
return False
def create_access_token(*, data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=403, detail="Could not validate credentials"
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except PyJWTError:
raise credentials_exception
username = get_user(username=token_data.username)
if username is None:
raise credentials_exception
return username
@app.get("/")
async def homepage():
return "Welcome to the security test!"
@app.post("/token", response_model=Token)
async def route_login_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
username = authenticate_user(form_data.username, form_data.password)
if not username:
raise HTTPException(status_code=400, detail="Incorrect username or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}
@app.get("/login_basic")
async def login_basic(auth: BasicAuth = Depends(basic_auth)):
if not auth:
response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)
return response
try:
decoded = base64.b64decode(auth).decode("ascii")
username, _, password = decoded.partition(":")
username = authenticate_user(username, password)
if not username:
raise HTTPException(status_code=400, detail="Incorrect email or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": username}, expires_delta=access_token_expires
)
token = jsonable_encoder(access_token)
response = RedirectResponse(url="/docs")
response.set_cookie(
"Authorization",
value=f"Bearer {token}",
httponly=True,
max_age=1800,
expires=1800,
)
return response
except:
response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401)
return response
@app.get("/docs")
async def get_documentation(current_user: str = Depends(get_current_user)):
return {"hi": "you're in"}
if __name__ == '__main__': # Run directly for debugging: python auth.py
uvicorn.run(app)

View File

@ -180,6 +180,7 @@ def start_server(
# queueing.init() # queueing.init()
# app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,)) # app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
# app.queue_thread.start() # app.queue_thread.start()
app.tokens = {}
app_kwargs = {"app": app, "port": port, "host": server_name, app_kwargs = {"app": app, "port": port, "host": server_name,
"log_level": "warning"} "log_level": "warning"}
thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs) thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs)

View File

@ -25,8 +25,8 @@ setup(
'paramiko', 'paramiko',
'pillow', 'pillow',
'pycryptodome', 'pycryptodome',
'python-multipart',
'pydub', 'pydub',
'requests', 'requests',
# python-multipart
], ],
) )