From 51c8c34486bfddca5948e46e498de44e21ab6496 Mon Sep 17 00:00:00 2001 From: Halil Ibrahim Bestil Date: Thu, 23 Jun 2022 20:44:04 +0300 Subject: [PATCH] Async Request Class (#1595) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement Request class and its tests. * Add new requirements * Reformat codebase. * Fix formatting. * Add library level requirements. * Convert validated_data property to get_validated_data function. * Fix the client fixture. * Update test/test_utils.py * Update test/test_utils.py * Fix the client fixture. * Add missing initialization for Request._validated_data * Fix async test problem with test_tunneling.py * Update gradio/utils.py * Update gradio/utils.py * Fix formatting. Co-authored-by: Ömer Faruk Özdemir --- gradio.egg-info/requires.txt | 2 + gradio/utils.py | 224 +++++++++++++++++- setup.py | 2 + test/requirements.in | 2 + test/test_utils.py | 130 +++++++++- ...{test_tunneling.py => test_x_tunneling.py} | 0 6 files changed, 355 insertions(+), 5 deletions(-) rename test/{test_tunneling.py => test_x_tunneling.py} (100%) diff --git a/gradio.egg-info/requires.txt b/gradio.egg-info/requires.txt index 20089b7116..294f278307 100644 --- a/gradio.egg-info/requires.txt +++ b/gradio.egg-info/requires.txt @@ -16,3 +16,5 @@ pydub requests uvicorn Jinja2 +httpx +pydantic diff --git a/gradio/utils.py b/gradio/utils.py index 21ca8e6710..89e325fdbb 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -12,13 +12,16 @@ import sys import warnings from copy import deepcopy from distutils.version import StrictVersion -from typing import TYPE_CHECKING, Any, Callable, Dict, List +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, List, NewType, Type import aiohttp import analytics import fsspec.asyn +import httpx import pkg_resources import requests +from pydantic import BaseModel, Json, parse_obj_as import gradio @@ -346,3 +349,222 @@ def run_coro_in_background(func: Callable, *args, **kwargs): """ event_loop = asyncio.get_event_loop() _ = event_loop.create_task(func(*args, **kwargs)) + + +client = httpx.AsyncClient() + + +class Request: + """ + The Request class is a low-level API that allow you to create asynchronous HTTP requests without a context manager. + Compared to making calls by using httpx directly, Request offers more flexibility and control over: + (1) Includes response validation functionality both using validation models and functions. + (2) Since we're still using httpx.Request class by wrapping it, we have all it's functionalities. + (3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one + individually in the case of multiple asynchronous request calls and some of them failing. + (4) Provides HTTP request types with Request.Method Enum class for ease of usage + Request also offers some util functions such as has_exception, is_valid and status to inspect get detailed + information about executed request call. + + The basic usage of Request is as follows: create a Request object with inputs(method, url etc.). Then use it + with the "await" statement, and then you can use util functions to do some post request checks depending on your use-case. + Finally, call the get_validated_data function to get the response data. + + You can see example usages in test_utils.py. + """ + + ResponseJson = NewType("ResponseJson", Json) + + class Method(str, Enum): + """ + Method is an enumeration class that contains possible types of HTTP request methods. + """ + + ANY = "*" + CONNECT = "CONNECT" + HEAD = "HEAD" + GET = "GET" + DELETE = "DELETE" + OPTIONS = "OPTIONS" + PATCH = "PATCH" + POST = "POST" + PUT = "PUT" + TRACE = "TRACE" + + def __init__( + self, + method: Method, + url: str, + *, + validation_model: Type[BaseModel] = None, + validation_function: Callable = None, + exception_type: Type[Exception] = Exception, + raise_for_status: bool = False, + **kwargs, + ): + """ + Initialize the Request instance. + Args: + method(Request.Method) : method of the request + url(str): url of the request + * + validation_model(Type[BaseModel]): a pydantic validation class type to use in validation of the response + validation_function(Callable): a callable instance to use in validation of the response + exception_class(Type[Exception]): a exception type to throw with its type + raise_for_status(bool): a flag that determines to raise httpx.Request.raise_for_status() exceptions. + """ + self._response = None + self._exception = None + self._status = None + self._raise_for_status = raise_for_status + self._validation_model = validation_model + self._validation_function = validation_function + self._exception_type = exception_type + self._validated_data = None + # Create request + self._request = self._create_request(method, url, **kwargs) + + def __await__(self): + """ + Wrap Request's __await__ magic function to create request calls which are executed in one line. + """ + return self.__run().__await__() + + async def __run(self) -> Request: + """ + Manage the request call lifecycle. + Execute the request by sending it through the client, then check its status. + Then parse the request into Json format. And then validate it using the provided validation methods. + If a problem occurs in this sequential process, + an exception will be raised within the corresponding method, and allowed to be examined. + Manage the request call lifecycle. + + Returns: + Request + """ + try: + # Send the request and get the response. + self._response: httpx.Response = await client.send(self._request) + # Raise for _status + self._status = self._response.status_code + if self._raise_for_status: + self._response.raise_for_status() + # Parse client response data to JSON + self._json_response_data = self._response.json() + # Validate response data + self._validated_data = self._validate_response_data( + self._json_response_data + ) + except Exception as exception: + # If there is an exception, store it to do further inspections. + self._exception = self._exception_type(exception) + return self + + @staticmethod + def _create_request(method: Method, url: str, **kwargs) -> Request: + """ + Create a request. This is a httpx request wrapper function. + Args: + method(Request.Method): request method type + url(str): target url of the request + **kwargs + Returns: + Request + """ + request = httpx.Request(method, url, **kwargs) + return request + + def _validate_response_data(self, response: ResponseJson) -> ResponseJson: + """ + Validate response using given validation methods. If there is a validation method and response is not valid, + validation functions will raise an exception for them. + Args: + response(ResponseJson): response object + Returns: + ResponseJson: Validated Json object. + """ + + # We use raw response as a default value if there is no validation method or response is not valid. + validated_response = response + + try: + # If a validation model is provided, validate response using the validation model. + if self._validation_model: + validated_response = self._validate_response_by_model( + validated_response + ) + # Then, If a validation function is provided, validate response using the validation function. + if self._validation_function: + validated_response = self._validate_response_by_validation_function( + validated_response + ) + except Exception as exception: + # If one of the validation methods does not confirm, raised exception will be silently handled. + # We assign this exception to classes instance to do further inspections via is_valid function. + self._exception = exception + + return validated_response + + def _validate_response_by_model(self, response: ResponseJson) -> ResponseJson: + """ + Validate response json using the validation model. + Args: + response(ResponseJson): response object + Returns: + ResponseJson: Validated Json object. + """ + validated_data = parse_obj_as(self._validation_model, response) + return validated_data + + def _validate_response_by_validation_function( + self, response: ResponseJson + ) -> ResponseJson: + """ + Validate response json using the validation function. + Args: + response(ResponseJson): response object + Returns: + ResponseJson: Validated Json object. + """ + validated_data = self._validation_function(response) + return validated_data + + def is_valid(self, raise_exceptions: bool = False) -> bool: + """ + Check response object's validity+. Raise exceptions if raise_exceptions flag is True. + Args: + raise_exceptions(bool) : a flag to raise exceptions in this check + Returns: + bool: validity of the data + """ + if self.has_exception: + if raise_exceptions: + raise self._exception + return False + else: + # If there is no exception, that means there is no validation error. + return True + + def get_validated_data(self): + return self._validated_data + + @property + def json(self): + return self._json_response_data + + @property + def exception(self): + return self._exception + + @property + def has_exception(self): + return self.exception is not None + + @property + def raise_exceptions(self): + if self.has_exception: + raise self._exception + + @property + def status(self): + return self._status diff --git a/setup.py b/setup.py index 8dcb5f180e..bc98eb5c60 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,8 @@ setup( "uvicorn", "Jinja2", "fsspec", + "httpx", + "pydantic", ], entry_points={ 'console_scripts': ['gradio=gradio.reload:run_in_reload_mode'] diff --git a/test/requirements.in b/test/requirements.in index 572eb7096a..d17f4b0bd3 100644 --- a/test/requirements.in +++ b/test/requirements.in @@ -18,3 +18,5 @@ black isort flake8 torch +httpx +pydantic \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index eabe2d6255..c4b01980cf 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,11 +1,17 @@ +import asyncio import ipaddress import os import unittest import unittest.mock as mock import warnings +from typing import Literal import pkg_resources +import pytest +import pytest_asyncio import requests +from httpx import AsyncClient +from pydantic import BaseModel from gradio.test_data.blocks_configs import ( XRAY_CONFIG, @@ -13,6 +19,7 @@ from gradio.test_data.blocks_configs import ( XRAY_CONFIG_WITH_MISTAKE, ) from gradio.utils import ( + Request, assert_configs_are_equivalent_besides_ids, colab_check, delete_none, @@ -32,7 +39,6 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" class TestUtils(unittest.TestCase): @mock.patch("pkg_resources.require") def test_should_fail_with_distribution_not_found(self, mock_require): - mock_require.side_effect = pkg_resources.DistributionNotFound() with warnings.catch_warnings(record=True) as w: @@ -45,7 +51,6 @@ class TestUtils(unittest.TestCase): @mock.patch("requests.get") def test_should_warn_with_unable_to_parse(self, mock_get): - mock_get.side_effect = json.decoder.JSONDecodeError("Expecting value", "", 0) with warnings.catch_warnings(record=True) as w: @@ -57,7 +62,6 @@ class TestUtils(unittest.TestCase): @mock.patch("requests.Response.json") def test_should_warn_url_not_having_version(self, mock_json): - mock_json.return_value = {"foo": "bar"} with warnings.catch_warnings(record=True) as w: @@ -69,7 +73,6 @@ class TestUtils(unittest.TestCase): @mock.patch("requests.post") def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post): - mock_post.side_effect = requests.ConnectionError() error_analytics("placeholder", "placeholder") mock_post.assert_called() @@ -184,5 +187,124 @@ class TestDeleteNone(unittest.TestCase): self.assertEqual(delete_none(input), truth) +@pytest_asyncio.fixture(scope="function", autouse=True) +async def client(): + """ + A fixture to mock the async client object. + """ + async with AsyncClient() as mock_client: + with mock.patch("gradio.utils.client", mock_client): + yield + + +class TestRequest: + @pytest.mark.asyncio + async def test_get(self): + client_response: Request = await Request( + method=Request.Method.GET, + url="http://headers.jsontest.com/", + ) + validated_data = client_response.get_validated_data() + assert client_response.is_valid() is True + assert validated_data["Host"] == "headers.jsontest.com" + + @pytest.mark.asyncio + async def test_post(self): + client_response: Request = await Request( + method=Request.Method.POST, + url="https://reqres.in/api/users", + json={"name": "morpheus", "job": "leader"}, + ) + validated_data = client_response.get_validated_data() + assert client_response.status == 201 + assert validated_data["job"] == "leader" + assert validated_data["name"] == "morpheus" + + @pytest.mark.asyncio + async def test_validate_with_model(self): + class TestModel(BaseModel): + name: str + job: str + id: str + createdAt: str + + client_response: Request = await Request( + method=Request.Method.POST, + url="https://reqres.in/api/users", + json={"name": "morpheus", "job": "leader"}, + validation_model=TestModel, + ) + assert isinstance(client_response.get_validated_data(), TestModel) + + @pytest.mark.asyncio + async def test_validate_and_fail_with_model(self): + class TestModel(BaseModel): + name: Literal[str] = "John" + job: str + + client_response: Request = await Request( + method=Request.Method.POST, + url="https://reqres.in/api/users", + json={"name": "morpheus", "job": "leader"}, + validation_model=TestModel, + ) + with pytest.raises(Exception): + client_response.is_valid(raise_exceptions=True) + assert client_response.has_exception is True + assert isinstance(client_response.exception, Exception) + + @mock.patch("gradio.utils.Request._validate_response_data") + @pytest.mark.asyncio + async def test_exception_type(self, validate_response_data): + class ResponseValidationException(Exception): + message = "Response object is not valid." + + validate_response_data.side_effect = Exception() + + client_response: Request = await Request( + method=Request.Method.GET, + url="https://reqres.in/api/users", + exception_type=ResponseValidationException, + ) + assert isinstance(client_response.exception, ResponseValidationException) + + @pytest.mark.asyncio + async def test_validate_with_function(self): + def has_name(response): + if response["name"] is not None: + return response + raise Exception + + client_response: Request = await Request( + method=Request.Method.POST, + url="https://reqres.in/api/users", + json={"name": "morpheus", "job": "leader"}, + validation_function=has_name, + ) + validated_data = client_response.get_validated_data() + assert client_response.is_valid() is True + assert validated_data["id"] is not None + assert client_response.exception is None + + @pytest.mark.asyncio + async def test_validate_and_fail_with_function(self): + def has_name(response): + if response["name"] is not None: + if response["name"] == "Alex": + return response + raise Exception + + client_response: Request = await Request( + method=Request.Method.POST, + url="https://reqres.in/api/users", + json={"name": "morpheus", "job": "leader"}, + validation_function=has_name, + ) + assert client_response.is_valid() is False + with pytest.raises(Exception): + client_response.is_valid(raise_exceptions=True) + assert client_response.exception is not None + + if __name__ == "__main__": unittest.main() diff --git a/test/test_tunneling.py b/test/test_x_tunneling.py similarity index 100% rename from test/test_tunneling.py rename to test/test_x_tunneling.py