mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
Improve error messages for when argument lists mismatch (#3519)
* Improve error messages for when argument lists mismatch * Fix comment * Update changelog * Fix typeerror * Fix single output case * Add argument validation tests * Lint * Move changelog to correct section * Fix typo --------- Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
This commit is contained in:
parent
b5e342cad0
commit
303b4dc0b5
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
## New Features:
|
## New Features:
|
||||||
|
|
||||||
No changes to highlight.
|
- Improve error messages when number of inputs/outputs to event handlers mismatch, by [@space-nuko](https://github.com/space-nuko) in [PR 3519](https://github.com/gradio-app/gradio/pull/3519)
|
||||||
|
|
||||||
## Bug Fixes:
|
## Bug Fixes:
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ No changes to highlight.
|
|||||||
|
|
||||||
## New Features:
|
## New Features:
|
||||||
|
|
||||||
No changes to highlight.
|
- No changes to highlight.
|
||||||
|
|
||||||
## Bug Fixes:
|
## Bug Fixes:
|
||||||
|
|
||||||
|
@ -979,10 +979,53 @@ class Blocks(BlockContext):
|
|||||||
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
def validate_inputs(self, fn_index: int, inputs: List[Any]):
|
||||||
|
block_fn = self.fns[fn_index]
|
||||||
|
dependency = self.dependencies[fn_index]
|
||||||
|
|
||||||
|
dep_inputs = dependency["inputs"]
|
||||||
|
|
||||||
|
# This handles incorrect inputs when args are changed by a JS function
|
||||||
|
# Only check not enough args case, ignore extra arguments (for now)
|
||||||
|
# TODO: make this stricter?
|
||||||
|
if len(inputs) < len(dep_inputs):
|
||||||
|
name = (
|
||||||
|
f" ({block_fn.name})"
|
||||||
|
if block_fn.name and block_fn.name != "<lambda>"
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
wanted_args = []
|
||||||
|
received_args = []
|
||||||
|
for input_id in dep_inputs:
|
||||||
|
block = self.blocks[input_id]
|
||||||
|
wanted_args.append(str(block))
|
||||||
|
for inp in inputs:
|
||||||
|
if isinstance(inp, str):
|
||||||
|
v = f'"{inp}"'
|
||||||
|
else:
|
||||||
|
v = str(inp)
|
||||||
|
received_args.append(v)
|
||||||
|
|
||||||
|
wanted = ", ".join(wanted_args)
|
||||||
|
received = ", ".join(received_args)
|
||||||
|
|
||||||
|
# JS func didn't pass enough arguments
|
||||||
|
raise ValueError(
|
||||||
|
f"""An event handler{name} didn't receive enough input values (needed: {len(dep_inputs)}, got: {len(inputs)}).
|
||||||
|
Check if the event handler calls a Javascript function, and make sure its return value is correct.
|
||||||
|
Wanted inputs:
|
||||||
|
[{wanted}]
|
||||||
|
Received inputs:
|
||||||
|
[{received}]"""
|
||||||
|
)
|
||||||
|
|
||||||
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
|
def preprocess_data(self, fn_index: int, inputs: List[Any], state: Dict[int, Any]):
|
||||||
block_fn = self.fns[fn_index]
|
block_fn = self.fns[fn_index]
|
||||||
dependency = self.dependencies[fn_index]
|
dependency = self.dependencies[fn_index]
|
||||||
|
|
||||||
|
self.validate_inputs(fn_index, inputs)
|
||||||
|
|
||||||
if block_fn.preprocess:
|
if block_fn.preprocess:
|
||||||
processed_input = []
|
processed_input = []
|
||||||
for i, input_id in enumerate(dependency["inputs"]):
|
for i, input_id in enumerate(dependency["inputs"]):
|
||||||
@ -998,6 +1041,45 @@ class Blocks(BlockContext):
|
|||||||
processed_input = inputs
|
processed_input = inputs
|
||||||
return processed_input
|
return processed_input
|
||||||
|
|
||||||
|
def validate_outputs(self, fn_index: int, predictions: Any | List[Any]):
|
||||||
|
block_fn = self.fns[fn_index]
|
||||||
|
dependency = self.dependencies[fn_index]
|
||||||
|
|
||||||
|
dep_outputs = dependency["outputs"]
|
||||||
|
|
||||||
|
if type(predictions) is not list and type(predictions) is not tuple:
|
||||||
|
predictions = [predictions]
|
||||||
|
|
||||||
|
if len(predictions) < len(dep_outputs):
|
||||||
|
name = (
|
||||||
|
f" ({block_fn.name})"
|
||||||
|
if block_fn.name and block_fn.name != "<lambda>"
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
wanted_args = []
|
||||||
|
received_args = []
|
||||||
|
for output_id in dep_outputs:
|
||||||
|
block = self.blocks[output_id]
|
||||||
|
wanted_args.append(str(block))
|
||||||
|
for pred in predictions:
|
||||||
|
if isinstance(pred, str):
|
||||||
|
v = f'"{pred}"'
|
||||||
|
else:
|
||||||
|
v = str(pred)
|
||||||
|
received_args.append(v)
|
||||||
|
|
||||||
|
wanted = ", ".join(wanted_args)
|
||||||
|
received = ", ".join(received_args)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"""An event handler{name} didn't receive enough output values (needed: {len(dep_outputs)}, received: {len(predictions)}).
|
||||||
|
Wanted outputs:
|
||||||
|
[{wanted}]
|
||||||
|
Received outputs:
|
||||||
|
[{received}]"""
|
||||||
|
)
|
||||||
|
|
||||||
def postprocess_data(
|
def postprocess_data(
|
||||||
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
|
self, fn_index: int, predictions: List | Dict, state: Dict[int, Any]
|
||||||
):
|
):
|
||||||
@ -1015,6 +1097,8 @@ class Blocks(BlockContext):
|
|||||||
predictions,
|
predictions,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.validate_outputs(fn_index, predictions) # type: ignore
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
for i, output_id in enumerate(dependency["outputs"]):
|
for i, output_id in enumerate(dependency["outputs"]):
|
||||||
try:
|
try:
|
||||||
|
@ -561,10 +561,53 @@ class TestBlocksPostprocessing:
|
|||||||
button.click(lambda x: x, textbox1, [textbox1, textbox2])
|
button.click(lambda x: x, textbox1, [textbox1, textbox2])
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="Number of output components does not match number of values returned from from function <lambda>",
|
match=r'An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
|
||||||
):
|
):
|
||||||
demo.postprocess_data(fn_index=0, predictions=["test"], state={})
|
demo.postprocess_data(fn_index=0, predictions=["test"], state={})
|
||||||
|
|
||||||
|
def test_error_raised_if_num_outputs_mismatch_with_function_name(self):
|
||||||
|
def infer(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
textbox1 = gr.Textbox()
|
||||||
|
textbox2 = gr.Textbox()
|
||||||
|
button = gr.Button()
|
||||||
|
button.click(infer, textbox1, [textbox1, textbox2])
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r'An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[textbox, textbox\]\nReceived outputs:\n \["test"\]',
|
||||||
|
):
|
||||||
|
demo.postprocess_data(fn_index=0, predictions=["test"], state={})
|
||||||
|
|
||||||
|
def test_error_raised_if_num_outputs_mismatch_single_output(self):
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
num1 = gr.Number()
|
||||||
|
num2 = gr.Number()
|
||||||
|
btn = gr.Button(value="1")
|
||||||
|
btn.click(lambda a: a, num1, [num1, num2])
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:\n \[number, number\]\nReceived outputs:\n \[1\]",
|
||||||
|
):
|
||||||
|
demo.postprocess_data(fn_index=0, predictions=1, state={})
|
||||||
|
|
||||||
|
def test_error_raised_if_num_outputs_mismatch_tuple_output(self):
|
||||||
|
def infer(a, b):
|
||||||
|
return a, b
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
num1 = gr.Number()
|
||||||
|
num2 = gr.Number()
|
||||||
|
num3 = gr.Number()
|
||||||
|
btn = gr.Button(value="1")
|
||||||
|
btn.click(infer, num1, [num1, num2, num3])
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r"An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:\n \[number, number, number\]\nReceived outputs:\n \[1, 2\]",
|
||||||
|
):
|
||||||
|
demo.postprocess_data(fn_index=0, predictions=(1, 2), state={})
|
||||||
|
|
||||||
|
|
||||||
class TestCallFunction:
|
class TestCallFunction:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user