Adding ability to access in Fastapi request object into your function (#2641)

* fastapi requests

* formatting

* implement

* fix

* formatting

* formatting

* changelog

* added demo

* remove print

* added to guide

* changes for queuing

* changes to gr.Request

* formatting

* formatting

* fixes

* lint

* fixed tests

* fix batching

* fixing tests

* cleanup

* lint

* added tests; fixed review

* improve docs
This commit is contained in:
Abubakar Abid 2022-11-19 00:52:06 -08:00 committed by GitHub
parent dbfa4dced1
commit 5e148c3752
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 343 additions and 114 deletions

View File

@ -1,7 +1,23 @@
# Upcoming Release
## New Features:
No changes to highlight.
### Accessing the Requests Object Directly
You can now access the Request object directly in your Python function by [@abidlabs](https://github.com/abidlabs) in [PR 2641](https://github.com/gradio-app/gradio/pull/2641). This means that you can access request headers, the client IP address, and so on. In order to use it, add a parameter to your function and set its type hint to be `gr.Request`. Here's a simple example:
```py
import gradio as gr
def echo(name, request: gr.Request):
if request:
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
return name
io = gr.Interface(echo, "textbox", "textbox").launch()
```
## Bug Fixes:
No changes to highlight.
@ -46,8 +62,6 @@ No changes to highlight.
# 3.10.0
## New Features:
* Add support for `'password'` and `'email'` types to `Textbox`. [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653)
* `gr.Textbox` component will now raise an exception if `type` is not "text", "email", or "password" [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653). This will cause demos using the deprecated `gr.Textbox(type="number")` to raise an exception.
@ -69,7 +83,7 @@ No changes to highlight.
No changes to highlight.
## Full Changelog:
No changes to highlight.
* Add support for `'password'` and `'email'` types to `Textbox`. [@pngwn](https://github.com/pngwn) in [PR 2653](https://github.com/gradio-app/gradio/pull/2653)
## Contributors Shoutout:
No changes to highlight.

View File

@ -11,7 +11,14 @@ with gr.Blocks() as demo:
with gr.Tab("Interface"):
gr.Interface(lambda x:x, "audio", "audio", examples=[audio_file])
with gr.Tab("console"):
ip = gr.Textbox(label="User IP Address")
gr.Interface(lambda cmd:subprocess.run([cmd], capture_output=True, shell=True).stdout.decode('utf-8').strip(), "text", "text")
def get_ip(request: gr.Request):
return request.client.host
demo.load(get_ip, None, ip)
if __name__ == "__main__":
demo.queue()
demo.launch()

View File

@ -59,7 +59,7 @@ from gradio.interface import Interface, TabbedInterface, close_all
from gradio.ipython_ext import load_ipython_extension
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
from gradio.mix import Parallel, Series
from gradio.routes import mount_gradio_app
from gradio.routes import Request, mount_gradio_app
from gradio.templates import (
Files,
ImageMask,

View File

@ -9,6 +9,7 @@ import pkgutil
import random
import sys
import time
import typing
import warnings
import webbrowser
from types import ModuleType
@ -58,6 +59,7 @@ from gradio.utils import (
set_documentation_group("blocks")
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import comet_ml
import mlflow
@ -457,6 +459,23 @@ def convert_component_dict_to_list(outputs_ids: List[int], predictions: Dict) ->
return predictions
def add_request_to_inputs(
fn: Callable, inputs: List[Any], request: routes.Request | List[routes.Request]
):
"""
Adds the FastAPI Request object to the inputs of a function if the type of the parameter is FastAPI.Request.
"""
param_names = inspect.getfullargspec(fn)[0]
try:
parameter_types = typing.get_type_hints(fn)
for idx, param_name in enumerate(param_names):
if parameter_types.get(param_name, "") == routes.Request:
inputs.insert(idx, request)
except TypeError: # A TypeError is raised if the function is a partial or other rare cases.
pass
return inputs
@document("load")
class Blocks(BlockContext):
"""
@ -796,7 +815,12 @@ class Blocks(BlockContext):
if batch:
processed_inputs = [[inp] for inp in processed_inputs]
outputs = utils.synchronize_async(self.process_api, fn_index, processed_inputs)
outputs = utils.synchronize_async(
self.process_api,
fn_index=fn_index,
inputs=processed_inputs,
request=None,
)
outputs = outputs["data"]
if batch:
@ -812,11 +836,11 @@ class Blocks(BlockContext):
fn_index: int,
processed_input: List[Any],
iterator: Iterator[Any] | None = None,
request: routes.Request | List[routes.Request] | None = None,
):
"""Calls and times function with given index and preprocessed input."""
block_fn = self.fns[fn_index]
is_generating = False
start = time.time()
if block_fn.inputs_as_dict:
processed_input = [
@ -826,6 +850,12 @@ class Blocks(BlockContext):
}
]
processed_input = add_request_to_inputs(
block_fn.fn, list(processed_input), request
)
start = time.time()
if iterator is None: # If not a generator function that has already run
if inspect.iscoroutinefunction(block_fn.fn):
prediction = await block_fn.fn(*processed_input)
@ -944,6 +974,7 @@ class Blocks(BlockContext):
self,
fn_index: int,
inputs: List[Any],
request: routes.Request | List[routes.Request] | None = None,
username: str = None,
state: Dict[int, Any] | List[Dict[int, Any]] | None = None,
iterators: Dict[int, Any] | None = None,
@ -980,7 +1011,7 @@ class Blocks(BlockContext):
)
inputs = [self.preprocess_data(fn_index, i, state) for i in zip(*inputs)]
result = await self.call_function(fn_index, zip(*inputs), None)
result = await self.call_function(fn_index, zip(*inputs), None, request)
preds = result["prediction"]
data = [self.postprocess_data(fn_index, o, state) for o in zip(*preds)]
data = list(zip(*data))
@ -988,7 +1019,7 @@ class Blocks(BlockContext):
else:
inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None) if iterators else None
result = await self.call_function(fn_index, inputs, iterator)
result = await self.call_function(fn_index, inputs, iterator, request)
data = self.postprocess_data(fn_index, result["prediction"], state)
is_generating, iterator = result["is_generating"], result["iterator"]

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
@ -10,6 +10,9 @@ class PredictBody(BaseModel):
batched: Optional[
bool
] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
request: Optional[
Union[Dict, List[Dict]]
] = None # dictionary of request headers, query parameters, url, etc. (used to to pass in request for queuing)
class ResetBody(BaseModel):

View File

@ -290,7 +290,7 @@ class Examples:
if self.batch:
processed_input = [[value] for value in processed_input]
prediction = await Context.root_block.process_api(
fn_index, processed_input
fn_index=fn_index, inputs=processed_input, request=None
)
output = prediction["data"]
if self.batch:

View File

@ -5,14 +5,13 @@ import copy
import sys
import time
from collections import deque
from itertools import islice
from typing import Deque, Dict, List, Optional, Tuple
from typing import Any, Deque, Dict, List, Optional, Tuple
import fastapi
from pydantic import BaseModel
from gradio.dataclasses import PredictBody
from gradio.utils import Request, run_coro_in_background, set_task_name
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name
class Estimation(BaseModel):
@ -26,7 +25,11 @@ class Estimation(BaseModel):
class Event:
def __init__(self, websocket: fastapi.WebSocket, fn_index: int | None = None):
def __init__(
self,
websocket: fastapi.WebSocket,
fn_index: int | None = None,
):
self.websocket = websocket
self.data: PredictBody | None = None
self.lost_connection_time: float | None = None
@ -157,18 +160,6 @@ class Queue:
if self.live_updates:
await self.broadcast_estimations()
async def gather_data_for_first_ranks(self) -> None:
"""
Gather data for the first x events.
"""
# Send all messages concurrently
await asyncio.gather(
*[
self.gather_event_data(event)
for event in islice(self.event_queue, self.data_gathering_start)
]
)
async def gather_event_data(self, event: Event) -> bool:
"""
Gather data for the event
@ -253,14 +244,34 @@ class Queue:
queue_eta=self.queue_duration,
)
def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
return {
"url": str(websocket.url),
"headers": dict(websocket.headers),
"query_params": dict(websocket.query_params),
"path_params": dict(websocket.path_params),
"client": dict(host=websocket.client.host, port=websocket.client.port),
}
async def call_prediction(self, events: List[Event], batch: bool):
data = events[0].data
token = events[0].token
try:
data.request = self.get_request_params(events[0].websocket)
except ValueError:
pass
if batch:
data.data = list(zip(*[event.data.data for event in events if event.data]))
data.request = [
self.get_request_params(event.websocket)
for event in events
if event.data
]
data.batched = True
response = await Request(
method=Request.Method.POST,
response = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=f"{self.server_path}api/predict",
json=dict(data),
headers={"Authorization": f"Bearer {self.access_token}"},
@ -370,8 +381,8 @@ class Queue:
return None
async def reset_iterators(self, session_hash: str, fn_index: int):
await Request(
method=Request.Method.POST,
await AsyncRequest(
method=AsyncRequest.Method.POST,
url=f"{self.server_path}reset",
json={
"session_hash": session_hash,

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import inspect
import io
import json
import mimetypes
import os
import posixpath
@ -13,20 +14,20 @@ import traceback
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Any, List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from urllib.parse import urlparse
import fastapi
import orjson
import pkg_resources
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi import Depends, FastAPI, HTTPException, WebSocket, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from starlette.responses import RedirectResponse
from starlette.websockets import WebSocket, WebSocketState
from starlette.websockets import WebSocketState
import gradio
from gradio import encryptor, utils
@ -108,7 +109,7 @@ class App(FastAPI):
@app.get("/user")
@app.get("/user/")
def get_current_user(request: Request) -> Optional[str]:
def get_current_user(request: fastapi.Request) -> Optional[str]:
token = request.cookies.get("access-token")
return app.tokens.get(token)
@ -127,13 +128,13 @@ class App(FastAPI):
@app.get("/token")
@app.get("/token/")
def get_token(request: Request) -> dict:
def get_token(request: fastapi.Request) -> dict:
token = request.cookies.get("access-token")
return {"token": token, "user": app.tokens.get(token)}
@app.get("/app_id")
@app.get("/app_id/")
def app_id(request: Request) -> int:
def app_id(request: fastapi.Request) -> int:
return {"app_id": app.blocks.app_id}
@app.post("/login")
@ -159,7 +160,7 @@ class App(FastAPI):
@app.head("/", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse)
def main(request: Request, user: str = Depends(get_current_user)):
def main(request: fastapi.Request, user: str = Depends(get_current_user)):
mimetypes.add_type("application/javascript", ".js")
if app.auth is None or not (user is None):
@ -258,7 +259,9 @@ class App(FastAPI):
return {"success": True}
async def run_predict(
body: PredictBody, username: str = Depends(get_current_user)
body: PredictBody,
request: Request,
username: str = Depends(get_current_user),
):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
@ -287,7 +290,12 @@ class App(FastAPI):
raw_input = [raw_input]
try:
output = await app.blocks.process_api(
fn_index, raw_input, username, session_state, iterators
fn_index=fn_index,
inputs=raw_input,
request=request,
username=username,
state=session_state,
iterators=iterators,
)
iterator = output.pop("iterator", None)
if hasattr(body, "session_hash"):
@ -317,7 +325,7 @@ class App(FastAPI):
async def predict(
api_name: str,
body: PredictBody,
request: Request,
request: fastapi.Request,
username: str = Depends(get_current_user),
):
if body.fn_index is None:
@ -345,12 +353,20 @@ class App(FastAPI):
# current session hash
if app.blocks.dependencies[body.fn_index]["cancels"]:
body.data = [body.session_hash]
result = await run_predict(body=body, username=username)
if body.request:
if body.batched:
request = [Request(**req) for req in body.request]
else:
request = Request(**body.request)
else:
request = Request(request)
result = await run_predict(body=body, username=username, request=request)
return result
@app.websocket("/queue/join")
async def join_queue(
websocket: WebSocket, token: str = Depends(ws_login_check)
websocket: WebSocket,
token: str = Depends(ws_login_check),
):
if app.auth is not None and token is None:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
@ -462,6 +478,64 @@ def get_server_url_from_ws_url(ws_url: str):
set_documentation_group("routes")
class Obj:
"""
Using a class to convert dictionaries into objects. Used by the `Request` class.
Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
"""
def __init__(self, dict1):
self.__dict__.update(dict1)
def __str__(self) -> str:
return str(self.__dict__)
def __repr__(self) -> str:
return str(self.__dict__)
@document()
class Request:
"""
A Gradio request object that can be used to access the request headers, cookies,
query parameters and other information about the request from within the prediction
function. The class is a thin wrapper around the fastapi.Request class. Attributes
of this class include: `headers`, `client`, `query_params`, and `path_params`,
Example:
import gradio as gr
def echo(name, request: gr.Request):
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
return name
io = gr.Interface(echo, "textbox", "textbox").launch()
"""
def __init__(self, request: fastapi.Request | None = None, **kwargs):
"""
Can be instantiated with either a fastapi.Request or by manually passing in
attributes (needed for websocket-based queueing).
"""
self.request: fastapi.Request = request
self.kwargs: Dict = kwargs
def dict_to_obj(self, d):
if isinstance(d, dict):
return json.loads(json.dumps(d), object_hook=Obj)
else:
return d
def __getattr__(self, name):
if self.request:
return self.dict_to_obj(getattr(self.request, name))
else:
try:
obj = self.kwargs[name]
except KeyError:
raise AttributeError(f"'Request' object has no attribute '{name}'")
return self.dict_to_obj(obj)
@document()
def mount_gradio_app(
app: fastapi.FastAPI,

View File

@ -12,6 +12,7 @@ import pkgutil
import random
import sys
import time
import typing
import warnings
from contextlib import contextmanager
from distutils.version import StrictVersion
@ -395,19 +396,19 @@ def async_iteration(iterator):
raise StopAsyncIteration()
class Request:
class AsyncRequest:
"""
The Request class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
Compared to making calls by using httpx directly, Request offers more flexibility and control over:
The AsyncRequest class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
Compared to making calls by using httpx directly, AsyncRequest offers more flexibility and control over:
(1) Includes response validation functionality both using validation models and functions.
(2) Since we're still using httpx.Request class by wrapping it, we have all it's functionalities.
(3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one
individually in the case of multiple asynchronous request calls and some of them failing.
(4) Provides HTTP request types with Request.Method Enum class for ease of usage
Request also offers some util functions such as has_exception, is_valid and status to inspect get detailed
(4) Provides HTTP request types with AsyncRequest.Method Enum class for ease of usage
AsyncRequest also offers some util functions such as has_exception, is_valid and status to inspect get detailed
information about executed request call.
The basic usage of Request is as follows: create a Request object with inputs(method, url etc.). Then use it
The basic usage of AsyncRequest is as follows: create a AsyncRequest object with inputs(method, url etc.). Then use it
with the "await" statement, and then you can use util functions to do some post request checks depending on your use-case.
Finally, call the get_validated_data function to get the response data.
@ -466,13 +467,13 @@ class Request:
# Create request
self._request = self._create_request(method, url, **kwargs)
def __await__(self) -> Generator[None, Any, "Request"]:
def __await__(self) -> Generator[None, Any, "AsyncRequest"]:
"""
Wrap Request's __await__ magic function to create request calls which are executed in one line.
"""
return self.__run().__await__()
async def __run(self) -> Request:
async def __run(self) -> AsyncRequest:
"""
Manage the request call lifecycle.
Execute the request by sending it through the client, then check its status.
@ -486,7 +487,9 @@ class Request:
"""
try:
# Send the request and get the response.
self._response: httpx.Response = await Request.client.send(self._request)
self._response: httpx.Response = await AsyncRequest.client.send(
self._request
)
# Raise for _status
self._status = self._response.status_code
if self._raise_for_status:
@ -503,7 +506,7 @@ class Request:
return self
@staticmethod
def _create_request(method: Method, url: str, **kwargs) -> Request:
def _create_request(method: Method, url: str, **kwargs) -> AsyncRequest:
"""
Create a request. This is a httpx request wrapper function.
Args:
@ -743,15 +746,24 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
Checks if the input component set matches the function
Returns: None if valid, a string error message if mismatch
"""
def is_special_typed_parameter(name):
from gradio.routes import Request
"""Checks if parameter has a type hint designating it as a gr.Request"""
return parameter_types.get(name, "") == Request
signature = inspect.signature(fn)
parameter_types = typing.get_type_hints(fn) if inspect.isfunction(fn) else {}
min_args = 0
max_args = 0
for param in signature.parameters.values():
for name, param in signature.parameters.items():
has_default = param.default != param.empty
if param.kind in [param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD]:
if not has_default:
min_args += 1
max_args += 1
if not (is_special_typed_parameter(name)):
if not has_default:
min_args += 1
max_args += 1
elif param.kind == param.VAR_POSITIONAL:
max_args = "infinity"
elif param.kind == param.KEYWORD_ONLY:

View File

@ -106,6 +106,27 @@ def same_auth(username, password):
demo.launch(auth=same_auth)
```
## Accessing the Network Request Directly
When a user makes a prediction to your app, you may need the underlying network request, in order to get the request headers (e.g. for advanced authentication), log the client's IP address, or for other reasons. Gradio supports this in a similar manner to FastAPI: simply add a function parameter whose type hint is `gr.Request` and Gradio will pass in the network request as that parameter. Here is an example:
```python
import gradio as gr
def echo(name, request: gr.Request):
if request:
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
return name
io = gr.Interface(echo, "textbox", "textbox").launch()
```
Note: if your function is called directly instead of through the UI (this happens, for
example, when examples are cached), then `request` will be `None`. You should handle
this case explicitly to ensure that your app does not throw any errors. That is why
we have the explicit check `if request`.
## Mounting Within Another FastAPI App
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo.

View File

@ -1,4 +1,5 @@
import asyncio
import copy
import io
import json
import os
@ -945,6 +946,70 @@ class TestEvery:
break
class TestAddRequests:
def test_no_type_hints(self):
def moo(a, b):
return a + b
inputs = [1, 2]
request = gr.Request()
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
assert inputs_ == inputs
boo = partial(moo, a=1)
inputs = [2]
inputs_ = gr.blocks.add_request_to_inputs(boo, copy.deepcopy(inputs), request)
assert inputs_ == inputs
def test_no_type_hints_with_request(self):
def moo(a: str, b: int):
return a + str(b)
inputs = ["abc", 2]
request = gr.Request()
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
assert inputs_ == inputs
boo = partial(moo, a="def")
inputs = [2]
inputs_ = gr.blocks.add_request_to_inputs(boo, copy.deepcopy(inputs), request)
assert inputs_ == inputs
def test_type_hints_with_request(self):
def moo(a: str, b: gr.Request):
return a
inputs = ["abc"]
request = gr.Request()
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
assert inputs_ == inputs + [request]
def moo(a: gr.Request, b, c: int):
return c
inputs = ["abc", 5]
request = gr.Request()
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
assert inputs_ == [request] + inputs
def test_type_hints_with_multiple_requests(self):
def moo(a: str, b: gr.Request, c: gr.Request):
return a
inputs = ["abc"]
request = gr.Request()
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
assert inputs_ == inputs + [request, request]
def moo(a: gr.Request, b, c: int, d: gr.Request):
return c
inputs = ["abc", 5]
request = gr.Request()
inputs_ = gr.blocks.add_request_to_inputs(moo, copy.deepcopy(inputs), request)
assert inputs_ == [request] + inputs + [request]
def test_queue_enabled_for_fn():
with gr.Blocks() as demo:
input = gr.Textbox()

View File

@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from gradio.queue import Event, Queue
from gradio.utils import Request
from gradio.utils import AsyncRequest
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -44,17 +44,6 @@ class TestQueueMethods:
assert queue.stopped is False
assert queue.get_active_worker_count() == 0
@pytest.mark.asyncio
async def test_dont_gather_data_while_broadcasting(self, queue: Queue):
queue.broadcast_live_estimations = AsyncMock()
queue.gather_data_for_first_ranks = AsyncMock()
await queue.broadcast_live_estimations()
# Should not gather data while broadcasting estimations
# Have seen weird race conditions come up in very viral
# spaces
queue.gather_data_for_first_ranks.assert_not_called()
@pytest.mark.asyncio
async def test_stop_resume(self, queue: Queue):
await queue.start()
@ -107,21 +96,6 @@ class TestQueueMethods:
assert await queue.gather_event_data(mock_event)
assert not (queue.send_message.called)
@pytest.mark.asyncio
async def test_gather_data_for_first_ranks(self, queue: Queue, mock_event: Event):
websocket = MagicMock()
mock_event2 = Event(websocket=websocket, fn_index=0)
queue.send_message = AsyncMock()
queue.get_message = AsyncMock()
queue.send_message.return_value = True
queue.get_message.return_value = {"data": ["test"], "fn": 0}
queue.push(mock_event)
queue.push(mock_event2)
await queue.gather_data_for_first_ranks()
assert mock_event.data is not None
assert mock_event2.data is None
class TestQueueEstimation:
def test_get_update_estimation(self, queue: Queue):
@ -170,7 +144,7 @@ class TestQueueProcessEvents:
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queue.Request", new_callable=AsyncMock)
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
queue.gather_event_data = AsyncMock()
queue.gather_event_data.return_value = True
@ -190,7 +164,7 @@ class TestQueueProcessEvents:
mock_event.disconnect.assert_called_once()
queue.clean_event.assert_called_once()
mock_request.assert_called_with(
method=Request.Method.POST,
method=AsyncRequest.Method.POST,
url=f"{queue.server_path}reset",
json={
"session_hash": mock_event.session_hash,
@ -280,7 +254,7 @@ class TestQueueProcessEvents:
reason="Mocks of async context manager don't work for 3.7",
)
@pytest.mark.asyncio
@patch("gradio.queue.Request", new_callable=AsyncMock)
@patch("gradio.queue.AsyncRequest", new_callable=AsyncMock)
async def test_process_event_handles_exception_during_disconnect(
self, mock_request, queue: Queue, mock_event: Event
):
@ -295,7 +269,7 @@ class TestQueueProcessEvents:
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
mock_request.assert_called_with(
method=Request.Method.POST,
method=AsyncRequest.Method.POST,
url=f"{queue.server_path}reset",
json={
"session_hash": mock_event.session_hash,

View File

@ -311,6 +311,23 @@ class TestDevMode:
assert not gradio_fast_api.app.blocks.dev_mode
class TestPassingRequest:
def test_request_included_with_regular_function(self):
def identity(name, request: gr.Request):
assert isinstance(request.client.host, str)
return name
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
prevent_thread_lock=True,
)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": ["test"]})
assert response.status_code == 200
output = dict(response.json())
assert output["data"] == ["test"]
def test_predict_route_is_blocked_if_api_open_false():
io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(
api_open=False

View File

@ -18,7 +18,7 @@ from gradio.test_data.blocks_configs import (
XRAY_CONFIG_WITH_MISTAKE,
)
from gradio.utils import (
Request,
AsyncRequest,
append_unique_suffix,
assert_configs_are_equivalent_besides_ids,
colab_check,
@ -256,15 +256,15 @@ async def client():
A fixture to mock the async client object.
"""
async with AsyncClient() as mock_client:
with mock.patch("gradio.utils.Request.client", mock_client):
with mock.patch("gradio.utils.AsyncRequest.client", mock_client):
yield
class TestRequest:
@pytest.mark.asyncio
async def test_get(self):
client_response: Request = await Request(
method=Request.Method.GET,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.GET,
url="http://headers.jsontest.com/",
)
validated_data = client_response.get_validated_data()
@ -273,8 +273,8 @@ class TestRequest:
@pytest.mark.asyncio
async def test_post(self):
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url="https://reqres.in/api/users",
json={"name": "morpheus", "job": "leader"},
)
@ -291,8 +291,8 @@ class TestRequest:
id: str
createdAt: str
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url="https://reqres.in/api/users",
json={"name": "morpheus", "job": "leader"},
validation_model=TestModel,
@ -305,8 +305,8 @@ class TestRequest:
name: Literal["John"] = "John"
job: str
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url="https://reqres.in/api/users",
json={"name": "morpheus", "job": "leader"},
validation_model=TestModel,
@ -330,8 +330,8 @@ async def test_get(respx_mock):
make_mock_response({"Host": "headers.jsontest.com"})
)
client_response: Request = await Request(
method=Request.Method.GET,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.GET,
url=MOCK_REQUEST_URL,
)
validated_data = client_response.get_validated_data()
@ -345,8 +345,8 @@ async def test_post(respx_mock):
payload = {"name": "morpheus", "job": "leader"}
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response(payload))
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=MOCK_REQUEST_URL,
json=payload,
)
@ -375,8 +375,8 @@ async def test_validate_with_model(respx_mock):
id: str
createdAt: str
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=MOCK_REQUEST_URL,
json={"name": "morpheus", "job": "leader"},
validation_model=TestModel,
@ -387,14 +387,14 @@ async def test_validate_with_model(respx_mock):
@pytest.mark.asyncio
async def test_validate_and_fail_with_model(respx_mock):
class TestModel(BaseModel):
name: Literal[str] = "John"
name: Literal["John"]
job: str
payload = {"name": "morpheus", "job": "leader"}
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response(payload))
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=MOCK_REQUEST_URL,
json=payload,
validation_model=TestModel,
@ -405,7 +405,7 @@ async def test_validate_and_fail_with_model(respx_mock):
assert isinstance(client_response.exception, Exception)
@mock.patch("gradio.utils.Request._validate_response_data")
@mock.patch("gradio.utils.AsyncRequest._validate_response_data")
@pytest.mark.asyncio
async def test_exception_type(validate_response_data, respx_mock):
class ResponseValidationException(Exception):
@ -415,8 +415,8 @@ async def test_exception_type(validate_response_data, respx_mock):
respx_mock.get(MOCK_REQUEST_URL).mock(Response(201))
client_response: Request = await Request(
method=Request.Method.GET,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.GET,
url=MOCK_REQUEST_URL,
exception_type=ResponseValidationException,
)
@ -435,8 +435,8 @@ async def test_validate_with_function(respx_mock):
return response
raise Exception
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=MOCK_REQUEST_URL,
json={"name": "morpheus", "job": "leader"},
validation_function=has_name,
@ -457,8 +457,8 @@ async def test_validate_and_fail_with_function(respx_mock):
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response({"name": "morpheus"}))
client_response: Request = await Request(
method=Request.Method.POST,
client_response: AsyncRequest = await AsyncRequest(
method=AsyncRequest.Method.POST,
url=MOCK_REQUEST_URL,
json={"name": "morpheus", "job": "leader"},
validation_function=has_name,