Fix false positive warning in check_function_inputs_match (#3837)

* fix false positive warning of check_function_inputs_match

* apply linter

* Update CHANGELOG and add comments in test

* Lint

---------

Co-authored-by: Freddy Boulton <alfonsoboulton@gmail.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
January Desk 2023-04-20 03:44:28 +08:00 committed by GitHub
parent a95f615ce8
commit e0eea96766
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 2 deletions

View File

@ -184,7 +184,7 @@ No changes to highlight.
## Bug Fixes:
- Fixes Chatbot issue where new lines were being created every time a message was sent back and forth by [@aliabid94](https://github.com/aliabid94) in [PR 3717](https://github.com/gradio-app/gradio/pull/3717).
- Fixes false postive warning which is due to too strict type checking by [@yiyuezhuo](https://github.com/yiyuezhuo) in [PR 3837](https://github.com/gradio-app/gradio/pull/3837).
## Documentation Changes:

View File

@ -837,6 +837,14 @@ def get_cancel_function(
)
def get_type_hints(fn):
if inspect.isfunction(fn) or inspect.ismethod(fn):
return typing.get_type_hints(fn)
elif callable(fn):
return typing.get_type_hints(fn.__call__)
return {}
def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool):
"""
Checks if the input component set matches the function
@ -854,7 +862,7 @@ def check_function_inputs_match(fn: Callable, inputs: List, inputs_as_dict: bool
return is_request or is_event_data
signature = inspect.signature(fn)
parameter_types = typing.get_type_hints(fn) if inspect.isfunction(fn) else {}
parameter_types = get_type_hints(fn)
min_args = 0
max_args = 0
infinity = -1

View File

@ -14,6 +14,7 @@ from httpx import AsyncClient, Response
from pydantic import BaseModel
from typing_extensions import Literal
from gradio import EventData
from gradio.context import Context
from gradio.test_data.blocks_configs import (
XRAY_CONFIG,
@ -25,11 +26,13 @@ from gradio.utils import (
abspath,
append_unique_suffix,
assert_configs_are_equivalent_besides_ids,
check_function_inputs_match,
colab_check,
delete_none,
error_analytics,
format_ner_list,
get_local_ip_address,
get_type_hints,
ipython_check,
kaggle_check,
launch_analytics,
@ -587,3 +590,51 @@ class TestAbspath:
def test_abspath_symlink(self, mock_islink):
resolved_path = str(abspath("../gradio/gradio/test_data/lion.jpg"))
assert ".." in resolved_path
class TestGetTypeHints:
def test_get_type_hints(self):
class F:
def __call__(self, s: str):
return s
class C:
def f(self, s: str):
return s
def f(s: str):
return s
class GenericObject:
pass
test_objs = [F(), C().f, f]
for x in test_objs:
hints = get_type_hints(x)
assert len(hints) == 1
assert hints["s"] == str
assert len(get_type_hints(GenericObject())) == 0
class TestCheckFunctionInputsMatch:
def test_check_function_inputs_match(self):
class F:
def __call__(self, s: str, evt: EventData):
return s
class C:
def f(self, s: str, evt: EventData):
return s
def f(s: str, evt: EventData):
return s
test_objs = [F(), C().f, f]
with warnings.catch_warnings():
warnings.simplefilter("error") # Ensure there're no warnings raised here.
for x in test_objs:
check_function_inputs_match(x, [None], False)