mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Async Request Class (#1595)
* Implement Request class and its tests. * Add new requirements * Reformat codebase. * Fix formatting. * Add library level requirements. * Convert validated_data property to get_validated_data function. * Fix the client fixture. * Update test/test_utils.py * Update test/test_utils.py * Fix the client fixture. * Add missing initialization for Request._validated_data * Fix async test problem with test_tunneling.py * Update gradio/utils.py * Update gradio/utils.py * Fix formatting. Co-authored-by: Ömer Faruk Özdemir <farukozderim@gmail.com>
This commit is contained in:
parent
7a0f6b1dd2
commit
51c8c34486
@ -16,3 +16,5 @@ pydub
|
||||
requests
|
||||
uvicorn
|
||||
Jinja2
|
||||
httpx
|
||||
pydantic
|
||||
|
224
gradio/utils.py
224
gradio/utils.py
@ -12,13 +12,16 @@ import sys
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from distutils.version import StrictVersion
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NewType, Type
|
||||
|
||||
import aiohttp
|
||||
import analytics
|
||||
import fsspec.asyn
|
||||
import httpx
|
||||
import pkg_resources
|
||||
import requests
|
||||
from pydantic import BaseModel, Json, parse_obj_as
|
||||
|
||||
import gradio
|
||||
|
||||
@ -346,3 +349,222 @@ def run_coro_in_background(func: Callable, *args, **kwargs):
|
||||
"""
|
||||
event_loop = asyncio.get_event_loop()
|
||||
_ = event_loop.create_task(func(*args, **kwargs))
|
||||
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
|
||||
|
||||
class Request:
|
||||
"""
|
||||
The Request class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
|
||||
Compared to making calls by using httpx directly, Request offers more flexibility and control over:
|
||||
(1) Includes response validation functionality both using validation models and functions.
|
||||
(2) Since we're still using httpx.Request class by wrapping it, we have all it's functionalities.
|
||||
(3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one
|
||||
individually in the case of multiple asynchronous request calls and some of them failing.
|
||||
(4) Provides HTTP request types with Request.Method Enum class for ease of usage
|
||||
Request also offers some util functions such as has_exception, is_valid and status to inspect get detailed
|
||||
information about executed request call.
|
||||
|
||||
The basic usage of Request is as follows: create a Request object with inputs(method, url etc.). Then use it
|
||||
with the "await" statement, and then you can use util functions to do some post request checks depending on your use-case.
|
||||
Finally, call the get_validated_data function to get the response data.
|
||||
|
||||
You can see example usages in test_utils.py.
|
||||
"""
|
||||
|
||||
ResponseJson = NewType("ResponseJson", Json)
|
||||
|
||||
class Method(str, Enum):
|
||||
"""
|
||||
Method is an enumeration class that contains possible types of HTTP request methods.
|
||||
"""
|
||||
|
||||
ANY = "*"
|
||||
CONNECT = "CONNECT"
|
||||
HEAD = "HEAD"
|
||||
GET = "GET"
|
||||
DELETE = "DELETE"
|
||||
OPTIONS = "OPTIONS"
|
||||
PATCH = "PATCH"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
TRACE = "TRACE"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: Method,
|
||||
url: str,
|
||||
*,
|
||||
validation_model: Type[BaseModel] = None,
|
||||
validation_function: Callable = None,
|
||||
exception_type: Type[Exception] = Exception,
|
||||
raise_for_status: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Request instance.
|
||||
Args:
|
||||
method(Request.Method) : method of the request
|
||||
url(str): url of the request
|
||||
*
|
||||
validation_model(Type[BaseModel]): a pydantic validation class type to use in validation of the response
|
||||
validation_function(Callable): a callable instance to use in validation of the response
|
||||
exception_class(Type[Exception]): a exception type to throw with its type
|
||||
raise_for_status(bool): a flag that determines to raise httpx.Request.raise_for_status() exceptions.
|
||||
"""
|
||||
self._response = None
|
||||
self._exception = None
|
||||
self._status = None
|
||||
self._raise_for_status = raise_for_status
|
||||
self._validation_model = validation_model
|
||||
self._validation_function = validation_function
|
||||
self._exception_type = exception_type
|
||||
self._validated_data = None
|
||||
# Create request
|
||||
self._request = self._create_request(method, url, **kwargs)
|
||||
|
||||
def __await__(self):
|
||||
"""
|
||||
Wrap Request's __await__ magic function to create request calls which are executed in one line.
|
||||
"""
|
||||
return self.__run().__await__()
|
||||
|
||||
async def __run(self) -> Request:
|
||||
"""
|
||||
Manage the request call lifecycle.
|
||||
Execute the request by sending it through the client, then check its status.
|
||||
Then parse the request into Json format. And then validate it using the provided validation methods.
|
||||
If a problem occurs in this sequential process,
|
||||
an exception will be raised within the corresponding method, and allowed to be examined.
|
||||
Manage the request call lifecycle.
|
||||
|
||||
Returns:
|
||||
Request
|
||||
"""
|
||||
try:
|
||||
# Send the request and get the response.
|
||||
self._response: httpx.Response = await client.send(self._request)
|
||||
# Raise for _status
|
||||
self._status = self._response.status_code
|
||||
if self._raise_for_status:
|
||||
self._response.raise_for_status()
|
||||
# Parse client response data to JSON
|
||||
self._json_response_data = self._response.json()
|
||||
# Validate response data
|
||||
self._validated_data = self._validate_response_data(
|
||||
self._json_response_data
|
||||
)
|
||||
except Exception as exception:
|
||||
# If there is an exception, store it to do further inspections.
|
||||
self._exception = self._exception_type(exception)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _create_request(method: Method, url: str, **kwargs) -> Request:
|
||||
"""
|
||||
Create a request. This is a httpx request wrapper function.
|
||||
Args:
|
||||
method(Request.Method): request method type
|
||||
url(str): target url of the request
|
||||
**kwargs
|
||||
Returns:
|
||||
Request
|
||||
"""
|
||||
request = httpx.Request(method, url, **kwargs)
|
||||
return request
|
||||
|
||||
def _validate_response_data(self, response: ResponseJson) -> ResponseJson:
|
||||
"""
|
||||
Validate response using given validation methods. If there is a validation method and response is not valid,
|
||||
validation functions will raise an exception for them.
|
||||
Args:
|
||||
response(ResponseJson): response object
|
||||
Returns:
|
||||
ResponseJson: Validated Json object.
|
||||
"""
|
||||
|
||||
# We use raw response as a default value if there is no validation method or response is not valid.
|
||||
validated_response = response
|
||||
|
||||
try:
|
||||
# If a validation model is provided, validate response using the validation model.
|
||||
if self._validation_model:
|
||||
validated_response = self._validate_response_by_model(
|
||||
validated_response
|
||||
)
|
||||
# Then, If a validation function is provided, validate response using the validation function.
|
||||
if self._validation_function:
|
||||
validated_response = self._validate_response_by_validation_function(
|
||||
validated_response
|
||||
)
|
||||
except Exception as exception:
|
||||
# If one of the validation methods does not confirm, raised exception will be silently handled.
|
||||
# We assign this exception to classes instance to do further inspections via is_valid function.
|
||||
self._exception = exception
|
||||
|
||||
return validated_response
|
||||
|
||||
def _validate_response_by_model(self, response: ResponseJson) -> ResponseJson:
|
||||
"""
|
||||
Validate response json using the validation model.
|
||||
Args:
|
||||
response(ResponseJson): response object
|
||||
Returns:
|
||||
ResponseJson: Validated Json object.
|
||||
"""
|
||||
validated_data = parse_obj_as(self._validation_model, response)
|
||||
return validated_data
|
||||
|
||||
def _validate_response_by_validation_function(
|
||||
self, response: ResponseJson
|
||||
) -> ResponseJson:
|
||||
"""
|
||||
Validate response json using the validation function.
|
||||
Args:
|
||||
response(ResponseJson): response object
|
||||
Returns:
|
||||
ResponseJson: Validated Json object.
|
||||
"""
|
||||
validated_data = self._validation_function(response)
|
||||
return validated_data
|
||||
|
||||
def is_valid(self, raise_exceptions: bool = False) -> bool:
|
||||
"""
|
||||
Check response object's validity+. Raise exceptions if raise_exceptions flag is True.
|
||||
Args:
|
||||
raise_exceptions(bool) : a flag to raise exceptions in this check
|
||||
Returns:
|
||||
bool: validity of the data
|
||||
"""
|
||||
if self.has_exception:
|
||||
if raise_exceptions:
|
||||
raise self._exception
|
||||
return False
|
||||
else:
|
||||
# If there is no exception, that means there is no validation error.
|
||||
return True
|
||||
|
||||
def get_validated_data(self):
|
||||
return self._validated_data
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
return self._json_response_data
|
||||
|
||||
@property
|
||||
def exception(self):
|
||||
return self._exception
|
||||
|
||||
@property
|
||||
def has_exception(self):
|
||||
return self.exception is not None
|
||||
|
||||
@property
|
||||
def raise_exceptions(self):
|
||||
if self.has_exception:
|
||||
raise self._exception
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._status
|
||||
|
2
setup.py
2
setup.py
@ -44,6 +44,8 @@ setup(
|
||||
"uvicorn",
|
||||
"Jinja2",
|
||||
"fsspec",
|
||||
"httpx",
|
||||
"pydantic",
|
||||
],
|
||||
entry_points={
|
||||
'console_scripts': ['gradio=gradio.reload:run_in_reload_mode']
|
||||
|
@ -18,3 +18,5 @@ black
|
||||
isort
|
||||
flake8
|
||||
torch
|
||||
httpx
|
||||
pydantic
|
@ -1,11 +1,17 @@
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
from typing import Literal
|
||||
|
||||
import pkg_resources
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
from httpx import AsyncClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gradio.test_data.blocks_configs import (
|
||||
XRAY_CONFIG,
|
||||
@ -13,6 +19,7 @@ from gradio.test_data.blocks_configs import (
|
||||
XRAY_CONFIG_WITH_MISTAKE,
|
||||
)
|
||||
from gradio.utils import (
|
||||
Request,
|
||||
assert_configs_are_equivalent_besides_ids,
|
||||
colab_check,
|
||||
delete_none,
|
||||
@ -32,7 +39,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
class TestUtils(unittest.TestCase):
|
||||
@mock.patch("pkg_resources.require")
|
||||
def test_should_fail_with_distribution_not_found(self, mock_require):
|
||||
|
||||
mock_require.side_effect = pkg_resources.DistributionNotFound()
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
@ -45,7 +51,6 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_should_warn_with_unable_to_parse(self, mock_get):
|
||||
|
||||
mock_get.side_effect = json.decoder.JSONDecodeError("Expecting value", "", 0)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
@ -57,7 +62,6 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
@mock.patch("requests.Response.json")
|
||||
def test_should_warn_url_not_having_version(self, mock_json):
|
||||
|
||||
mock_json.return_value = {"foo": "bar"}
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
@ -69,7 +73,6 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
error_analytics("placeholder", "placeholder")
|
||||
mock_post.assert_called()
|
||||
@ -184,5 +187,124 @@ class TestDeleteNone(unittest.TestCase):
|
||||
self.assertEqual(delete_none(input), truth)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function", autouse=True)
|
||||
async def client():
|
||||
"""
|
||||
A fixture to mock the async client object.
|
||||
"""
|
||||
async with AsyncClient() as mock_client:
|
||||
with mock.patch("gradio.utils.client", mock_client):
|
||||
yield
|
||||
|
||||
|
||||
class TestRequest:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get(self):
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.GET,
|
||||
url="http://headers.jsontest.com/",
|
||||
)
|
||||
validated_data = client_response.get_validated_data()
|
||||
assert client_response.is_valid() is True
|
||||
assert validated_data["Host"] == "headers.jsontest.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(self):
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
)
|
||||
validated_data = client_response.get_validated_data()
|
||||
assert client_response.status == 201
|
||||
assert validated_data["job"] == "leader"
|
||||
assert validated_data["name"] == "morpheus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_with_model(self):
|
||||
class TestModel(BaseModel):
|
||||
name: str
|
||||
job: str
|
||||
id: str
|
||||
createdAt: str
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_model=TestModel,
|
||||
)
|
||||
assert isinstance(client_response.get_validated_data(), TestModel)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_fail_with_model(self):
|
||||
class TestModel(BaseModel):
|
||||
name: Literal[str] = "John"
|
||||
job: str
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_model=TestModel,
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
client_response.is_valid(raise_exceptions=True)
|
||||
assert client_response.has_exception is True
|
||||
assert isinstance(client_response.exception, Exception)
|
||||
|
||||
@mock.patch("gradio.utils.Request._validate_response_data")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_type(self, validate_response_data):
|
||||
class ResponseValidationException(Exception):
|
||||
message = "Response object is not valid."
|
||||
|
||||
validate_response_data.side_effect = Exception()
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.GET,
|
||||
url="https://reqres.in/api/users",
|
||||
exception_type=ResponseValidationException,
|
||||
)
|
||||
assert isinstance(client_response.exception, ResponseValidationException)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_with_function(self):
|
||||
def has_name(response):
|
||||
if response["name"] is not None:
|
||||
return response
|
||||
raise Exception
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_function=has_name,
|
||||
)
|
||||
validated_data = client_response.get_validated_data()
|
||||
assert client_response.is_valid() is True
|
||||
assert validated_data["id"] is not None
|
||||
assert client_response.exception is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_fail_with_function(self):
|
||||
def has_name(response):
|
||||
if response["name"] is not None:
|
||||
if response["name"] == "Alex":
|
||||
return response
|
||||
raise Exception
|
||||
|
||||
client_response: Request = await Request(
|
||||
method=Request.Method.POST,
|
||||
url="https://reqres.in/api/users",
|
||||
json={"name": "morpheus", "job": "leader"},
|
||||
validation_function=has_name,
|
||||
)
|
||||
assert client_response.is_valid() is False
|
||||
with pytest.raises(Exception):
|
||||
client_response.is_valid(raise_exceptions=True)
|
||||
assert client_response.exception is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user