Allow multiple instances of Gradio with authentication to run on different ports (#5588)

* Allow gradio auth to work across different ports

* lint

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-09-18 12:25:47 -07:00 committed by GitHub
parent 1d8c4de962
commit acdeff57ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 12 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Allow multiple instances of Gradio with authentication to run on different ports

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
import json
import os
import pkgutil
import threading
import urllib.parse
import warnings

View File

@ -117,6 +117,7 @@ class App(FastAPI):
self.iterators = defaultdict(dict)
self.iterators_to_reset = defaultdict(set)
self.lock = utils.safe_get_lock()
self.cookie_id = secrets.token_urlsafe(32)
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
@ -188,9 +189,9 @@ class App(FastAPI):
@app.get("/user")
@app.get("/user/")
def get_current_user(request: fastapi.Request) -> Optional[str]:
token = request.cookies.get("access-token") or request.cookies.get(
"access-token-unsecure"
)
token = request.cookies.get(
f"access-token-{app.cookie_id}"
) or request.cookies.get(f"access-token-unsecure-{app.cookie_id}")
return app.tokens.get(token)
@app.get("/login_check")
@ -203,15 +204,15 @@ class App(FastAPI):
)
async def ws_login_check(websocket: WebSocket) -> Optional[str]:
token = websocket.cookies.get("access-token") or websocket.cookies.get(
"access-token-unsecure"
)
token = websocket.cookies.get(
f"access-token-{app.cookie_id}"
) or websocket.cookies.get(f"access-token-unsecure-{app.cookie_id}")
return token # token is returned to authenticate the websocket connection in the endpoint handler.
@app.get("/token")
@app.get("/token/")
def get_token(request: fastapi.Request) -> dict:
token = request.cookies.get("access-token")
token = request.cookies.get(f"access-token-{app.cookie_id}")
return {"token": token, "user": app.tokens.get(token)}
@app.get("/app_id")
@ -267,14 +268,16 @@ class App(FastAPI):
app.tokens[token] = username
response = JSONResponse(content={"success": True})
response.set_cookie(
key="access-token",
key=f"access-token-{app.cookie_id}",
value=token,
httponly=True,
samesite="none",
secure=True,
)
response.set_cookie(
key="access-token-unsecure", value=token, httponly=True
key=f"access-token-unsecure-{app.cookie_id}",
value=token,
httponly=True,
)
return response
else:

View File

@ -1680,7 +1680,7 @@ async def test_queue_when_using_auth():
follow_redirects=False,
)
assert resp.status_code == 200
token = resp.cookies.get("access-token")
token = resp.cookies.get(f"access-token-{demo.app.cookie_id}")
assert token
with pytest.raises(Exception) as e:
@ -1693,7 +1693,7 @@ async def test_queue_when_using_auth():
async def run_ws(i):
async with websockets.connect(
f"{demo.local_url.replace('http', 'ws')}queue/join",
extra_headers={"Cookie": f"access-token={token}"},
extra_headers={"Cookie": f"access-token-{demo.app.cookie_id}={token}"},
) as ws:
while True:
try: