mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
cookie-based auth
This commit is contained in:
parent
58604029ee
commit
e54576de85
@ -1,5 +1,6 @@
|
||||
analytics-python
|
||||
fastapi
|
||||
fastapi-login
|
||||
ffmpy
|
||||
markdown2
|
||||
matplotlib
|
||||
|
@ -11,7 +11,7 @@ import os
|
||||
import posixpath
|
||||
import pkg_resources
|
||||
import secrets
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.responses import Response, RedirectResponse
|
||||
import traceback
|
||||
from typing import List, Optional, Type, TYPE_CHECKING
|
||||
import urllib
|
||||
@ -45,16 +45,14 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
###########
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login", auto_error=False)
|
||||
@app.manager.user_loader()
|
||||
def get_current_user(username: str) -> Optional[str]:
|
||||
if username in app.users:
|
||||
return username
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[str]:
|
||||
if token in app.tokens:
|
||||
return app.tokens[token]
|
||||
|
||||
|
||||
def is_authenticated(token: str = Depends(oauth2_scheme)) -> bool:
|
||||
return get_current_user(token) is not None
|
||||
def is_authenticated(username: str = Depends(app.manager)) -> bool:
|
||||
return get_current_user(username) is not None
|
||||
|
||||
|
||||
def login_check(is_authenticated: bool = Depends(is_authenticated)):
|
||||
@ -63,20 +61,24 @@ def login_check(is_authenticated: bool = Depends(is_authenticated)):
|
||||
detail="Not authenticated")
|
||||
|
||||
|
||||
@app.get('/token')
|
||||
def get_token(token: str = Depends(oauth2_scheme)):
|
||||
return {"token": token}
|
||||
# @app.get('/token')
|
||||
# def get_token(token: str = Depends(app.manager)):
|
||||
# return {"token": token}
|
||||
|
||||
|
||||
@app.post('/login')
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
def login(response: Response,
|
||||
form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username, password = form_data.username, form_data.password
|
||||
if ((not callable(app.auth) and username in app.auth
|
||||
and app.auth[username] == password)
|
||||
or (callable(app.auth) and app.auth.__call__(username, password))):
|
||||
token = secrets.token_urlsafe(16)
|
||||
app.tokens[token] = username
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
token = app.manager.create_access_token(
|
||||
data=dict(sub=username)
|
||||
)
|
||||
app.users.add(username)
|
||||
app.manager.set_cookie(response, token)
|
||||
return response
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Incorrect credentials.")
|
||||
|
||||
@ -317,5 +319,5 @@ if __name__ == '__main__': # Run directly for debugging: python app.py
|
||||
app.interface.auth_message = None
|
||||
else:
|
||||
app.auth = None
|
||||
app.tokens = {}
|
||||
app.users = []
|
||||
uvicorn.run(app)
|
||||
|
@ -11,7 +11,6 @@ import markdown2 # type: ignore
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
import warnings
|
||||
@ -169,7 +168,7 @@ class Interface:
|
||||
state_param_index = [isinstance(i, i_State)
|
||||
for i in self.input_components].index(True)
|
||||
state_init_value = utils.get_default_args(fn[0])[state_param_index]
|
||||
except ValueError:
|
||||
except ValueError: # No default value for the state parameter
|
||||
state_init_value = None
|
||||
self.state_init_value = state_init_value
|
||||
|
||||
|
@ -3,8 +3,8 @@ Defines helper methods useful for setting up ports, launching servers, and
|
||||
creating tunnels.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import fastapi
|
||||
from fastapi_login import LoginManager
|
||||
import http
|
||||
import json
|
||||
import os
|
||||
@ -110,7 +110,8 @@ def start_server(
|
||||
auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
|
||||
"""
|
||||
server_name = server_name or LOCALHOST_NAME
|
||||
if server_port is None: # if port is not specified, search for first available port
|
||||
# if port is not specified, search for first available port
|
||||
if server_port is None:
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
@ -132,6 +133,11 @@ def start_server(
|
||||
app.auth = auth
|
||||
else:
|
||||
app.auth = None
|
||||
app.secret = os.urandom(24).hex()
|
||||
app.login_manager = LoginManager(
|
||||
app.secret, token_url="/login", use_cookie=True)
|
||||
app.users = set()
|
||||
|
||||
app.interface = interface
|
||||
app.cwd = os.getcwd()
|
||||
if app.interface.enable_queue:
|
||||
@ -142,7 +148,7 @@ def start_server(
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None: # Used for selenium tests
|
||||
interface.save_to["port"] = port
|
||||
app.tokens = {}
|
||||
|
||||
config = uvicorn.Config(app=app, port=port, host=server_name,
|
||||
log_level="warning")
|
||||
server = Server(config=config)
|
||||
|
Loading…
Reference in New Issue
Block a user