mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Improve error messages for when argument lists mismatch (#3519)
* Improve error messages for when argument lists mismatch * Fix comment * Update changelog * Fix typeerror * Fix single output case * Add argument validation tests * Lint * Move changelog to correct section * Fix typo --------- Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
This commit is contained in:
parent
b5e342cad0
commit
303b4dc0b5
@ -3,7 +3,7 @@
|
||||
|
||||
## New Features:
|
||||
|
||||
No changes to highlight.
|
||||
- Improve error messages when number of inputs/outputs to event handlers mismatch, by [@space-nuko](https://github.com/space-nuko) in [PR 3519](https://github.com/gradio-app/gradio/pull/3519)
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
@ -38,7 +38,7 @@ No changes to highlight.
|
||||
|
||||
## New Features:
|
||||
|
||||
No changes to highlight.
|
||||
- No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
|
@ -979,10 +979,53 @@ class Blocks(BlockContext):
|
||||
|
||||
return predictions
|
||||
|
||||
def validate_inputs(self, fn_index: int, inputs: List[Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
dep_inputs = dependency["inputs"]
|
||||
|
||||
# This handles incorrect inputs when args are changed by a JS function
|
||||
# Only check not enough args case, ignore extra arguments (for now)
|
||||
# TODO: make this stricter?
|
||||
if len(inputs) < len(dep_inputs):
|
||||
name = (
|
||||
f" ({block_fn.name})"
|
||||
if block_fn.name and block_fn.name != "<lambda>"
|
||||
else ""
|
||||
)
|
||||
|
||||
wanted_args = []
|
||||
received_args = []
|
||||
for input_id in dep_inputs:
|
||||
block = self.blocks[input_id]
|
||||
wanted_args.append(str(block))
|
||||
for inp in inputs:
|
||||
if isinstance(inp, str):
|
||||
v = f'"{inp}"'
|
||||
else:
|
||||
v = str(inp)
|
||||
received_args.append(v)
|
||||
|
||||
wanted = ", ".join(wanted_args)
|
||||
received = ", ".join(received_args)
|
||||
|
||||
# JS func didn't pass enough arguments
|
||||
raise ValueError(
|
||||
f"""An event handler{name} didn't receive enough input values (needed: {len(dep_inputs)}, got: {len(inputs)}).
|
||||
Check if the event handler calls a Javascript function, and make sure its return value is correct.
|
||||
Wanted inputs:
|
||||
[{wanted}]
|
||||
Received inputs:
|
||||
[{received}]"""
|
||||
)
|
||||
|
||||
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
self.validate_inputs(fn_index, inputs)
|
||||
|
||||
if block_fn.preprocess:
|
||||
processed_input = []
|
||||
for i, input_id in enumerate(dependency["inputs"]):
|
||||
@ -998,6 +1041,45 @@ class Blocks(BlockContext):
|
||||
processed_input = inputs
|
||||
return processed_input
|
||||
|
||||
def validate_outputs(self, fn_index: int, predictions: Any | List[Any]):
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
dep_outputs = dependency["outputs"]
|
||||
|
||||
if type(predictions) is not list and type(predictions) is not tuple:
|
||||
predictions = [predictions]
|
||||
|
||||
if len(predictions) < len(dep_outputs):
|
||||
name = (
|
||||
f" ({block_fn.name})"
|
||||
if block_fn.name and block_fn.name != "<lambda>"
|
||||
else ""
|
||||
)
|
||||
|
||||
wanted_args = []
|
||||
received_args = []
|
||||
for output_id in dep_outputs:
|
||||
block = self.blocks[output_id]
|
||||
wanted_args.append(str(block))
|
||||
for pred in predictions:
|
||||
if isinstance(pred, str):
|
||||
v = f'"{pred}"'
|
||||
else:
|
||||
v = str(pred)
|
||||
received_args.append(v)
|
||||
|
||||
wanted = ", ".join(wanted_args)
|
||||
received = ", ".join(received_args)
|
||||
|
||||
raise ValueError(
|
||||
f"""An event handler{name} didn't receive enough output values (needed: {len(dep_outputs)}, received: {len(predictions)}).
|
||||
Wanted outputs:
|
||||
[{wanted}]
|
||||
Received outputs:
|
||||
[{received}]"""
|
||||
)
|
||||
|
||||
def postprocess_data(
|
||||
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
|
||||
):
|
||||
@ -1015,6 +1097,8 @@ class Blocks(BlockContext):
|
||||
predictions,
|
||||
]
|
||||
|
||||
self.validate_outputs(fn_index, predictions) # type: ignore
|
||||
|
||||
output = []
|
||||
for i, output_id in enumerate(dependency["outputs"]):
|
||||
try:
|
||||
|
@ -561,10 +561,53 @@ class TestBlocksPostprocessing:
|
||||
button.click(lambda x: x, textbox1, [textbox1, textbox2])
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Number of output components does not match number of values returned from from function <lambda>",
|
||||
match=r'An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
|
||||
):
|
||||
demo.postprocess_data(fn_index=0, predictions=["test"], state={})
|
||||
|
||||
def test_error_raised_if_num_outputs_mismatch_with_function_name(self):
|
||||
def infer(x):
|
||||
return x
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
textbox1 = gr.Textbox()
|
||||
textbox2 = gr.Textbox()
|
||||
button = gr.Button()
|
||||
button.click(infer, textbox1, [textbox1, textbox2])
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r'An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
|
||||
):
|
||||
demo.postprocess_data(fn_index=0, predictions=["test"], state={})
|
||||
|
||||
def test_error_raised_if_num_outputs_mismatch_single_output(self):
|
||||
with gr.Blocks() as demo:
|
||||
num1 = gr.Number()
|
||||
num2 = gr.Number()
|
||||
btn = gr.Button(value="1")
|
||||
btn.click(lambda a: a, num1, [num1, num2])
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[number, number\]\nReceived outputs:\n \[1\]",
|
||||
):
|
||||
demo.postprocess_data(fn_index=0, predictions=1, state={})
|
||||
|
||||
def test_error_raised_if_num_outputs_mismatch_tuple_output(self):
|
||||
def infer(a, b):
|
||||
return a, b
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
num1 = gr.Number()
|
||||
num2 = gr.Number()
|
||||
num3 = gr.Number()
|
||||
btn = gr.Button(value="1")
|
||||
btn.click(infer, num1, [num1, num2, num3])
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:\n \[number, number, number\]\nReceived outputs:\n \[1, 2\]",
|
||||
):
|
||||
demo.postprocess_data(fn_index=0, predictions=(1, 2), state={})
|
||||
|
||||
|
||||
class TestCallFunction:
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user