diff --git a/CHANGELOG.md b/CHANGELOG.md index c59e29b3a7..17d5845dc2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,7 +184,7 @@ No changes to highlight. ## Bug Fixes: - Fixes Chatbot issue where new lines were being created every time a message was sent back and forth by [@aliabid94](https://github.com/aliabid94) in [PR 3717](https://github.com/gradio-app/gradio/pull/3717). - +- Fixes false postive warning which is due to too strict type checking by [@yiyuezhuo](https://github.com/yiyuezhuo) in [PR 3837](https://github.com/gradio-app/gradio/pull/3837). ## Documentation Changes: diff --git a/gradio/utils.py b/gradio/utils.py index 55257e4cf7..c795f11d82 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -837,6 +837,14 @@ def get_cancel_function( ) +def get_type_hints(fn): + if inspect.isfunction(fn) or inspect.ismethod(fn): + return typing.get_type_hints(fn) + elif callable(fn): + return typing.get_type_hints(fn.__call__) + return {} + + def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool): """ Checks if the input component set matches the function @@ -854,7 +862,7 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool return is_request or is_event_data signature = inspect.signature(fn) - parameter_types = typing.get_type_hints(fn) if inspect.isfunction(fn) else {} + parameter_types = get_type_hints(fn) min_args = 0 max_args = 0 infinity = -1 diff --git a/test/test_utils.py b/test/test_utils.py index 7aef377cd5..90c89cd884 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,6 +14,7 @@ from httpx import AsyncClient, Response from pydantic import BaseModel from typing_extensions import Literal +from gradio import EventData from gradio.context import Context from gradio.test_data.blocks_configs import ( XRAY_CONFIG, @@ -25,11 +26,13 @@ from gradio.utils import ( abspath, append_unique_suffix, assert_configs_are_equivalent_besides_ids, + check_function_inputs_match, colab_check, delete_none, error_analytics, format_ner_list, get_local_ip_address, + get_type_hints, ipython_check, kaggle_check, launch_analytics, @@ -587,3 +590,51 @@ class TestAbspath: def test_abspath_symlink(self, mock_islink): resolved_path = str(abspath("../gradio/gradio/test_data/lion.jpg")) assert ".." in resolved_path + + +class TestGetTypeHints: + def test_get_type_hints(self): + class F: + def __call__(self, s: str): + return s + + class C: + def f(self, s: str): + return s + + def f(s: str): + return s + + class GenericObject: + pass + + test_objs = [F(), C().f, f] + + for x in test_objs: + hints = get_type_hints(x) + assert len(hints) == 1 + assert hints["s"] == str + + assert len(get_type_hints(GenericObject())) == 0 + + +class TestCheckFunctionInputsMatch: + def test_check_function_inputs_match(self): + class F: + def __call__(self, s: str, evt: EventData): + return s + + class C: + def f(self, s: str, evt: EventData): + return s + + def f(s: str, evt: EventData): + return s + + test_objs = [F(), C().f, f] + + with warnings.catch_warnings(): + warnings.simplefilter("error") # Ensure there're no warnings raised here. + + for x in test_objs: + check_function_inputs_match(x, [None], False)