Access username from gr.Request class (#3296)

* updates to the gr.Request class

* auth

* adds username

* revert utils

* changelog
This commit is contained in:
Abubakar Abid 2023-02-24 08:40:34 -08:00 committed by GitHub
parent f36211050c
commit f5e7b57ceb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 7 deletions

View File

@ -24,6 +24,9 @@ By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3297](https://git
- Adds a disabled mode to the `gr.Button` component by setting `interactive=False` by [@abidlabs](https://github.com/abidlabs) in [PR 3266](https://github.com/gradio-app/gradio/pull/3266) and [PR 3288](https://github.com/gradio-app/gradio/pull/3288)
- Allow the setting of `brush_radius` for the `Image` component both as a default and via `Image.update()` by [@pngwn](https://github.com/pngwn) in [PR 3277](https://github.com/gradio-app/gradio/pull/3277)
- Added `info=` argument to form components to enable extra context provided to users, by [@aliabid94](https://github.com/aliabid94) in [PR 3291](https://github.com/gradio-app/gradio/pull/3291)
- Allow developers to access the username of a logged-in user from the `gr.Request()` object using the `.username` attribute by [@abidlabs](https://github.com/abidlabs) in [PR 3296](https://github.com/gradio-app/gradio/pull/3296)
## Bug Fixes:
- Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245)

View File

@ -353,7 +353,6 @@ class App(FastAPI):
body: PredictBody,
request: Request | List[Request],
fn_index_inferred: int,
username: str = Depends(get_current_user),
):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
@ -451,16 +450,17 @@ class App(FastAPI):
body.data = [body.session_hash]
if body.request:
if body.batched:
gr_request = [Request(**req) for req in body.request]
gr_request = [
Request(username=username, **req) for req in body.request
]
else:
assert isinstance(body.request, dict)
gr_request = Request(**body.request)
gr_request = Request(username=username, **body.request)
else:
gr_request = Request(request)
gr_request = Request(username=username, request=request)
result = await run_predict(
body=body,
fn_index_inferred=fn_index_inferred,
username=username,
request=gr_request,
)
return result
@ -637,7 +637,8 @@ class Request:
A Gradio request object that can be used to access the request headers, cookies,
query parameters and other information about the request from within the prediction
function. The class is a thin wrapper around the fastapi.Request class. Attributes
of this class include: `headers`, `client`, `query_params`, and `path_params`,
of this class include: `headers`, `client`, `query_params`, and `path_params`. If
auth is enabled, the `username` attribute can be used to get the logged in user.
Example:
import gradio as gr
def echo(name, request: gr.Request):
@ -647,7 +648,12 @@ class Request:
io = gr.Interface(echo, "textbox", "textbox").launch()
"""
def __init__(self, request: fastapi.Request | None = None, **kwargs):
def __init__(
self,
request: fastapi.Request | None = None,
username: str | None = None,
**kwargs,
):
"""
Can be instantiated with either a fastapi.Request or by manually passing in
attributes (needed for websocket-based queueing).
@ -655,6 +661,7 @@ class Request:
request: A fastapi.Request
"""
self.request = request
self.username = username
self.kwargs: Dict = kwargs
def dict_to_obj(self, d):

View File

@ -439,6 +439,40 @@ class TestPassingRequest:
output = dict(response.json())
assert output["data"] == ["test"]
def test_request_includes_username_as_none_if_no_auth(self):
def identity(name, request: gr.Request):
assert request.username is None
return name
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
prevent_thread_lock=True,
)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": ["test"]})
assert response.status_code == 200
output = dict(response.json())
assert output["data"] == ["test"]
def test_request_includes_username_with_auth(self):
def identity(name, request: gr.Request):
assert request.username == "admin"
return name
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
prevent_thread_lock=True, auth=("admin", "password")
)
client = TestClient(app)
client.post(
"/login",
data=dict(username="admin", password="password"),
)
response = client.post("/api/predict/", json={"data": ["test"]})
assert response.status_code == 200
output = dict(response.json())
assert output["data"] == ["test"]
def test_predict_route_is_blocked_if_api_open_false():
io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(