cookie-based auth

This commit is contained in:
Abubakar Abid 2022-01-05 08:21:11 -05:00
parent 58604029ee
commit e54576de85
5 changed files with 31 additions and 22 deletions

View File

@ -1,5 +1,6 @@
analytics-python
fastapi
fastapi-login
ffmpy
markdown2
matplotlib

View File

@ -11,7 +11,7 @@ import os
import posixpath
import pkg_resources
import secrets
from starlette.responses import RedirectResponse
from starlette.responses import Response, RedirectResponse
import traceback
from typing import List, Optional, Type, TYPE_CHECKING
import urllib
@ -45,16 +45,14 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
@app.manager.user_loader()
def get_current_user(username: str) -> Optional[str]:
if username in app.users:
return username
def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[str]:
if token in app.tokens:
return app.tokens[token]
def is_authenticated(token: str = Depends(oauth2_scheme)) -> bool:
return get_current_user(token) is not None
def is_authenticated(username: str = Depends(app.manager)) -> bool:
return get_current_user(username) is not None
def login_check(is_authenticated: bool = Depends(is_authenticated)):
@ -63,20 +61,24 @@ def login_check(is_authenticated: bool = Depends(is_authenticated)):
detail="Not authenticated")
@app.get('/token')
def get_token(token: str = Depends(oauth2_scheme)):
return {"token": token}
# @app.get('/token')
# def get_token(token: str = Depends(app.manager)):
# return {"token": token}
@app.post('/login')
def login(form_data: OAuth2PasswordRequestForm = Depends()):
def login(response: Response,
form_data: OAuth2PasswordRequestForm = Depends()):
username, password = form_data.username, form_data.password
if ((not callable(app.auth) and username in app.auth
and app.auth[username] == password)
or (callable(app.auth) and app.auth.__call__(username, password))):
token = secrets.token_urlsafe(16)
app.tokens[token] = username
return {"access_token": token, "token_type": "bearer"}
token = app.manager.create_access_token(
data=dict(sub=username)
)
app.users.add(username)
app.manager.set_cookie(response, token)
return response
else:
raise HTTPException(status_code=400, detail="Incorrect credentials.")
@ -317,5 +319,5 @@ if __name__ == '__main__': # Run directly for debugging: python app.py
app.interface.auth_message = None
else:
app.auth = None
app.tokens = {}
app.users = []
uvicorn.run(app)

View File

@ -11,7 +11,6 @@ import markdown2 # type: ignore
import os
import random
import sys
import threading
import time
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
import warnings
@ -169,7 +168,7 @@ class Interface:
state_param_index = [isinstance(i, i_State)
for i in self.input_components].index(True)
state_init_value = utils.get_default_args(fn[0])[state_param_index]
except ValueError:
except ValueError: # No default value for the state parameter
state_init_value = None
self.state_init_value = state_init_value

View File

@ -3,8 +3,8 @@ Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import contextlib
import fastapi
from fastapi_login import LoginManager
import http
import json
import os
@ -110,7 +110,8 @@ def start_server(
auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
"""
server_name = server_name or LOCALHOST_NAME
if server_port is None: # if port is not specified, search for first available port
# if port is not specified, search for first available port
if server_port is None:
port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
)
@ -132,6 +133,11 @@ def start_server(
app.auth = auth
else:
app.auth = None
app.secret = os.urandom(24).hex()
app.login_manager = LoginManager(
app.secret, token_url="/login", use_cookie=True)
app.users = set()
app.interface = interface
app.cwd = os.getcwd()
if app.interface.enable_queue:
@ -142,7 +148,7 @@ def start_server(
app.queue_thread.start()
if interface.save_to is not None: # Used for selenium tests
interface.save_to["port"] = port
app.tokens = {}
config = uvicorn.Config(app=app, port=port, host=server_name,
log_level="warning")
server = Server(config=config)

View File

@ -17,6 +17,7 @@ setup(
install_requires=[
'analytics-python',
'fastapi',
'fastapi-login',
'ffmpy',
'markdown2',
'matplotlib',