diff --git a/gradio.egg-info/requires.txt b/gradio.egg-info/requires.txt index fe3c779285..3f23758fda 100644 --- a/gradio.egg-info/requires.txt +++ b/gradio.egg-info/requires.txt @@ -1,5 +1,6 @@ analytics-python fastapi +fastapi-login ffmpy markdown2 matplotlib diff --git a/gradio/app.py b/gradio/app.py index c5c241cb6d..82848ec274 100644 --- a/gradio/app.py +++ b/gradio/app.py @@ -11,7 +11,7 @@ import os import posixpath import pkg_resources import secrets -from starlette.responses import RedirectResponse +from starlette.responses import Response, RedirectResponse import traceback from typing import List, Optional, Type, TYPE_CHECKING import urllib @@ -45,16 +45,14 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) ########### -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False) +@app.manager.user_loader() +def get_current_user(username: str) -> Optional[str]: + if username in app.users: + return username -def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[str]: - if token in app.tokens: - return app.tokens[token] - - -def is_authenticated(token: str = Depends(oauth2_scheme)) -> bool: - return get_current_user(token) is not None +def is_authenticated(username: str = Depends(app.manager)) -> bool: + return get_current_user(username) is not None def login_check(is_authenticated: bool = Depends(is_authenticated)): @@ -63,20 +61,24 @@ def login_check(is_authenticated: bool = Depends(is_authenticated)): detail="Not authenticated") -@app.get('/token') -def get_token(token: str = Depends(oauth2_scheme)): - return {"token": token} +# @app.get('/token') +# def get_token(token: str = Depends(app.manager)): +# return {"token": token} @app.post('/login') -def login(form_data: OAuth2PasswordRequestForm = Depends()): +def login(response: Response, + form_data: OAuth2PasswordRequestForm = Depends()): username, password = form_data.username, form_data.password if ((not callable(app.auth) and username in app.auth and app.auth[username] == password) or (callable(app.auth) and app.auth.__call__(username, password))): - token = secrets.token_urlsafe(16) - app.tokens[token] = username - return {"access_token": token, "token_type": "bearer"} + token = app.manager.create_access_token( + data=dict(sub=username) + ) + app.users.add(username) + app.manager.set_cookie(response, token) + return response else: raise HTTPException(status_code=400, detail="Incorrect credentials.") @@ -317,5 +319,5 @@ if __name__ == '__main__': # Run directly for debugging: python app.py app.interface.auth_message = None else: app.auth = None - app.tokens = {} + app.users = [] uvicorn.run(app) diff --git a/gradio/interface.py b/gradio/interface.py index dbdc682b38..e6eb8cf106 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -11,7 +11,6 @@ import markdown2 # type: ignore import os import random import sys -import threading import time from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING import warnings @@ -169,7 +168,7 @@ class Interface: state_param_index = [isinstance(i, i_State) for i in self.input_components].index(True) state_init_value = utils.get_default_args(fn[0])[state_param_index] - except ValueError: + except ValueError: # No default value for the state parameter state_init_value = None self.state_init_value = state_init_value diff --git a/gradio/networking.py b/gradio/networking.py index 106419a531..b7dbf4f616 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -3,8 +3,8 @@ Defines helper methods useful for setting up ports, launching servers, and creating tunnels. """ from __future__ import annotations -import contextlib import fastapi +from fastapi_login import LoginManager import http import json import os @@ -110,7 +110,8 @@ def start_server( auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login. """ server_name = server_name or LOCALHOST_NAME - if server_port is None: # if port is not specified, search for first available port + # if port is not specified, search for first available port + if server_port is None: port = get_first_available_port( INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS ) @@ -132,6 +133,11 @@ def start_server( app.auth = auth else: app.auth = None + app.secret = os.urandom(24).hex() + app.login_manager = LoginManager( + app.secret, token_url="/login", use_cookie=True) + app.users = set() + app.interface = interface app.cwd = os.getcwd() if app.interface.enable_queue: @@ -142,7 +148,7 @@ def start_server( app.queue_thread.start() if interface.save_to is not None: # Used for selenium tests interface.save_to["port"] = port - app.tokens = {} + config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning") server = Server(config=config) diff --git a/setup.py b/setup.py index 58eadbe061..bad32898ad 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ setup( install_requires=[ 'analytics-python', 'fastapi', + 'fastapi-login', 'ffmpy', 'markdown2', 'matplotlib',