diff --git a/gradio/app.py b/gradio/app.py index 89e481ebc5..07d4ced265 100644 --- a/gradio/app.py +++ b/gradio/app.py @@ -5,12 +5,12 @@ from fastapi import FastAPI, Form, Request, Depends, HTTPException, status from fastapi.responses import JSONResponse, HTMLResponse, FileResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware -from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm import inspect import os import posixpath import pkg_resources -from secrets import compare_digest +import secrets import traceback from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING import urllib @@ -35,8 +35,34 @@ app.add_middleware( allow_headers=["*"], ) -secure = HTTPBasic() templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False) + + +######## +# Auth +######## + +def get_username_from_token(token: str = Depends(oauth2_scheme)): + print('token2', token) + if token in app.tokens: + return app.tokens[token] + +def is_authenticated(token: str = Depends(oauth2_scheme)): + print('token', token) + return get_username_from_token(token) is not None + +@app.post('/login') +def login(form_data: OAuth2PasswordRequestForm = Depends()): + username, password = form_data.username, form_data.password + if ((not callable(app.auth) and username in app.auth + and app.auth[username] == password) + or (callable(app.auth) and app.auth.__call__(username, password))): + token = secrets.token_urlsafe(16) + app.tokens[token] = username + return {"access_token": token, "token_type": "bearer"} + else: + raise HTTPException(status_code=400, detail="Incorrect credentials.") ############### @@ -46,12 +72,19 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) @app.head('/', response_class=HTMLResponse) @app.get('/', response_class=HTMLResponse) -# @login_check # TODO -def main(request: Request): - # session["state"] = None # TODO +def main(request: Request, is_authenticated=Depends(is_authenticated)): + print('app.tokens', app.tokens) + print('is_authenticated', is_authenticated) + print('token>>>', request.headers.get("Authorization", "missing")) + if app.auth is None or is_authenticated: + config = app.interface.config + else: + config = {"auth_required": True, + "auth_message": app.interface.auth_message} + return templates.TemplateResponse( "frontend/index.html", - {"request": request, "config": app.interface.config} + {"request": request, "config": config} ) @@ -66,7 +99,6 @@ def static_resource(path: str): raise HTTPException(status_code=404, detail="Static file not found") -@app.get("/config/") def get_config(): # if app.interface.auth is None or current_user.is_authenticated: return app.interface.config @@ -172,6 +204,7 @@ async def interpret(request: Request): } + # @app.route("/shutdown", methods=['GET']) # def shutdown(): # shutdown_func = request.environ.get('werkzeug.server.shutdown') @@ -198,8 +231,7 @@ async def interpret(request: Request): # return {"status": status, "data": data} - - +# def get_current_user(token: str = Depends(oauth2_scheme)) ######## # Helper functions @@ -262,4 +294,9 @@ if __name__ == '__main__': # Run directly for debugging: python app.py app.interface.config = app.interface.get_config_file() app.interface.show_error = True app.interface.flagging_callback.setup(app.interface.flagging_dir) + # app.auth = None + app.interface.auth = ("a", "b") + app.auth = {"a": "b"} + app.interface.auth_message = None + app.tokens = {} uvicorn.run(app) diff --git a/gradio/auth.py b/gradio/auth.py deleted file mode 100644 index b76daa8110..0000000000 --- a/gradio/auth.py +++ /dev/null @@ -1,209 +0,0 @@ -"""Authentication via cookies for Fast API -Credit: https://medium.com/data-rebels/fastapi-how-to-add-basic-and-cookie-authentication-a45c85ef47d3 -""" -from typing import Optional -import base64 -from passlib.context import CryptContext -from datetime import datetime, timedelta - -import jwt -from jwt import PyJWTError - -from pydantic import BaseModel - -from fastapi import Depends, FastAPI, HTTPException -from fastapi.encoders import jsonable_encoder -from fastapi.security import OAuth2PasswordRequestForm, OAuth2 -from fastapi.security.base import SecurityBase -from fastapi.security.utils import get_authorization_scheme_param -from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel - -from starlette.responses import RedirectResponse, Response -from starlette.requests import Request - -import uvicorn - -SECRET_KEY = "56c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e709d25e094faa6ca25" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 - - -class Token(BaseModel): - access_token: str - token_type: str - - -class TokenData(BaseModel): - username: str = None - - -class OAuth2PasswordBearerCookie(OAuth2): - def __init__( - self, - tokenUrl: str, - scheme_name: str = None, - scopes: dict = None, - auto_error: bool = True, - ): - if not scopes: - scopes = {} - flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) - super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error) - - async def __call__(self, request: Request) -> Optional[str]: - header_authorization: str = request.headers.get("Authorization") - cookie_authorization: str = request.cookies.get("Authorization") - - header_scheme, header_param = get_authorization_scheme_param( - header_authorization - ) - cookie_scheme, cookie_param = get_authorization_scheme_param( - cookie_authorization - ) - - if header_scheme.lower() == "bearer": - authorization = True - scheme = header_scheme - param = header_param - - elif cookie_scheme.lower() == "bearer": - authorization = True - scheme = cookie_scheme - param = cookie_param - - else: - authorization = False - - if not authorization or scheme.lower() != "bearer": - if self.auto_error: - raise HTTPException( - status_code=403, detail="Not authenticated" - ) - else: - return None - return param - - -class BasicAuth(SecurityBase): - def __init__(self, scheme_name: str = None, auto_error: bool = True): - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error - - async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") - scheme, param = get_authorization_scheme_param(authorization) - if not authorization or scheme.lower() != "basic": - if self.auto_error: - raise HTTPException( - status_code=403, detail="Not authenticated" - ) - else: - return None - return param - - -basic_auth = BasicAuth(auto_error=False) - -oauth2_scheme = OAuth2PasswordBearerCookie(tokenUrl="/token") - -app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None) - -def get_user(username: str): - if username == "johndoe4": - return username - - -def authenticate_user(username: str, password: str): - if username == "johndoe4" and password == "secret": - return username - return False - - -def create_access_token(*, data: dict, expires_delta: timedelta = None): - to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt - - -async def get_current_user(token: str = Depends(oauth2_scheme)): - credentials_exception = HTTPException( - status_code=403, detail="Could not validate credentials" - ) - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - token_data = TokenData(username=username) - except PyJWTError: - raise credentials_exception - username = get_user(username=token_data.username) - if username is None: - raise credentials_exception - return username - - -@app.get("/") -async def homepage(): - return "Welcome to the security test!" - - -@app.post("/token", response_model=Token) -async def route_login_access_token(form_data: OAuth2PasswordRequestForm = Depends()): - username = authenticate_user(form_data.username, form_data.password) - if not username: - raise HTTPException(status_code=400, detail="Incorrect username or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": username}, expires_delta=access_token_expires - ) - return {"access_token": access_token, "token_type": "bearer"} - - -@app.get("/login_basic") -async def login_basic(auth: BasicAuth = Depends(basic_auth)): - if not auth: - response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401) - return response - - try: - decoded = base64.b64decode(auth).decode("ascii") - username, _, password = decoded.partition(":") - username = authenticate_user(username, password) - if not username: - raise HTTPException(status_code=400, detail="Incorrect email or password") - - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token = create_access_token( - data={"sub": username}, expires_delta=access_token_expires - ) - - token = jsonable_encoder(access_token) - - response = RedirectResponse(url="/docs") - response.set_cookie( - "Authorization", - value=f"Bearer {token}", - httponly=True, - max_age=1800, - expires=1800, - ) - return response - - except: - response = Response(headers={"WWW-Authenticate": "Basic"}, status_code=401) - return response - - -@app.get("/docs") -async def get_documentation(current_user: str = Depends(get_current_user)): - return {"hi": "you're in"} - - -if __name__ == '__main__': # Run directly for debugging: python auth.py - uvicorn.run(app) diff --git a/gradio/networking.py b/gradio/networking.py index 202b171c54..b13fa082fc 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -180,6 +180,7 @@ def start_server( # queueing.init() # app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,)) # app.queue_thread.start() + app.tokens = {} app_kwargs = {"app": app, "port": port, "host": server_name, "log_level": "warning"} thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs) diff --git a/setup.py b/setup.py index 7e0cf7b3e2..f7c4d7b7d4 100644 --- a/setup.py +++ b/setup.py @@ -25,8 +25,8 @@ setup( 'paramiko', 'pillow', 'pycryptodome', + 'python-multipart', 'pydub', 'requests', - # python-multipart ], )