Improve error messages for when argument lists mismatch (#3519)

* Improve error messages for when argument lists mismatch

* Fix comment

* Update changelog

* Fix typeerror

* Fix single output case

* Add argument validation tests

* Lint

* Move changelog to correct section

* Fix typo

---------

Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
This commit is contained in:
space-nuko 2023-04-07 13:22:42 -04:00 committed by GitHub
parent b5e342cad0
commit 303b4dc0b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 130 additions and 3 deletions

View File

@ -3,7 +3,7 @@
## New Features: ## New Features:
No changes to highlight. - Improve error messages when number of inputs/outputs to event handlers mismatch, by [@space-nuko](https://github.com/space-nuko) in [PR 3519](https://github.com/gradio-app/gradio/pull/3519)
## Bug Fixes: ## Bug Fixes:
@ -38,7 +38,7 @@ No changes to highlight.
## New Features: ## New Features:
No changes to highlight. - No changes to highlight.
## Bug Fixes: ## Bug Fixes:

View File

@ -979,10 +979,53 @@ class Blocks(BlockContext):
return predictions return predictions
def validate_inputs(self, fn_index: int, inputs: List[Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
dep_inputs = dependency["inputs"]
# This handles incorrect inputs when args are changed by a JS function
# Only check not enough args case, ignore extra arguments (for now)
# TODO: make this stricter?
if len(inputs) < len(dep_inputs):
name = (
f" ({block_fn.name})"
if block_fn.name and block_fn.name != "<lambda>"
else ""
)
wanted_args = []
received_args = []
for input_id in dep_inputs:
block = self.blocks[input_id]
wanted_args.append(str(block))
for inp in inputs:
if isinstance(inp, str):
v = f'"{inp}"'
else:
v = str(inp)
received_args.append(v)
wanted = ", ".join(wanted_args)
received = ", ".join(received_args)
# JS func didn't pass enough arguments
raise ValueError(
f"""An event handler{name} didn't receive enough input values (needed: {len(dep_inputs)}, got: {len(inputs)}).
Check if the event handler calls a Javascript function, and make sure its return value is correct.
Wanted inputs:
[{wanted}]
Received inputs:
[{received}]"""
)
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]): def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
block_fn = self.fns[fn_index] block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index] dependency = self.dependencies[fn_index]
self.validate_inputs(fn_index, inputs)
if block_fn.preprocess: if block_fn.preprocess:
processed_input = [] processed_input = []
for i, input_id in enumerate(dependency["inputs"]): for i, input_id in enumerate(dependency["inputs"]):
@ -998,6 +1041,45 @@ class Blocks(BlockContext):
processed_input = inputs processed_input = inputs
return processed_input return processed_input
def validate_outputs(self, fn_index: int, predictions: Any | List[Any]):
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
dep_outputs = dependency["outputs"]
if type(predictions) is not list and type(predictions) is not tuple:
predictions = [predictions]
if len(predictions) < len(dep_outputs):
name = (
f" ({block_fn.name})"
if block_fn.name and block_fn.name != "<lambda>"
else ""
)
wanted_args = []
received_args = []
for output_id in dep_outputs:
block = self.blocks[output_id]
wanted_args.append(str(block))
for pred in predictions:
if isinstance(pred, str):
v = f'"{pred}"'
else:
v = str(pred)
received_args.append(v)
wanted = ", ".join(wanted_args)
received = ", ".join(received_args)
raise ValueError(
f"""An event handler{name} didn't receive enough output values (needed: {len(dep_outputs)}, received: {len(predictions)}).
Wanted outputs:
[{wanted}]
Received outputs:
[{received}]"""
)
def postprocess_data( def postprocess_data(
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any] self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
): ):
@ -1015,6 +1097,8 @@ class Blocks(BlockContext):
predictions, predictions,
] ]
self.validate_outputs(fn_index, predictions) # type: ignore
output = [] output = []
for i, output_id in enumerate(dependency["outputs"]): for i, output_id in enumerate(dependency["outputs"]):
try: try:

View File

@ -561,10 +561,53 @@ class TestBlocksPostprocessing:
button.click(lambda x: x, textbox1, [textbox1, textbox2]) button.click(lambda x: x, textbox1, [textbox1, textbox2])
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="Number of output components does not match number of values returned from from function <lambda>", match=r'An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
): ):
demo.postprocess_data(fn_index=0, predictions=["test"], state={}) demo.postprocess_data(fn_index=0, predictions=["test"], state={})
def test_error_raised_if_num_outputs_mismatch_with_function_name(self):
def infer(x):
return x
with gr.Blocks() as demo:
textbox1 = gr.Textbox()
textbox2 = gr.Textbox()
button = gr.Button()
button.click(infer, textbox1, [textbox1, textbox2])
with pytest.raises(
ValueError,
match=r'An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
):
demo.postprocess_data(fn_index=0, predictions=["test"], state={})
def test_error_raised_if_num_outputs_mismatch_single_output(self):
with gr.Blocks() as demo:
num1 = gr.Number()
num2 = gr.Number()
btn = gr.Button(value="1")
btn.click(lambda a: a, num1, [num1, num2])
with pytest.raises(
ValueError,
match=r"An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[number, number\]\nReceived outputs:\n \[1\]",
):
demo.postprocess_data(fn_index=0, predictions=1, state={})
def test_error_raised_if_num_outputs_mismatch_tuple_output(self):
def infer(a, b):
return a, b
with gr.Blocks() as demo:
num1 = gr.Number()
num2 = gr.Number()
num3 = gr.Number()
btn = gr.Button(value="1")
btn.click(infer, num1, [num1, num2, num3])
with pytest.raises(
ValueError,
match=r"An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:\n \[number, number, number\]\nReceived outputs:\n \[1, 2\]",
):
demo.postprocess_data(fn_index=0, predictions=(1, 2), state={})
class TestCallFunction: class TestCallFunction:
@pytest.mark.asyncio @pytest.mark.asyncio