From 7a0f6b1dd218fc39cdfc5ab565fea1f253c5db9c Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 22 Jun 2022 23:48:42 -0700 Subject: [PATCH] Deprecation error if multiple functions are passed into `fn` parameter in `Interface` (#1623) * stack deprection * formatting * rewrite parallel * formatting' --- gradio/interface.py | 8 +++++++- gradio/mix.py | 12 +++++++++++- test/test_mix.py | 18 ++++++++++++++---- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index b709309757..451538b08f 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -106,7 +106,7 @@ class Interface(Blocks): def __init__( self, - fn: Callable | List[Callable], + fn: Callable, inputs: Optional[str | Component | List[str | Component]], outputs: Optional[str | Component | List[str | Component]], examples: Optional[List[Any] | List[List[Any]] | str] = None, @@ -172,6 +172,12 @@ class Interface(Blocks): if not isinstance(fn, list): fn = [fn] + else: + raise DeprecationWarning( + "The `fn` parameter only accepts a single function, support for a list " + "of functions has been deprecated. Please use gradio.mix.Parallel " + "instead." + ) if not isinstance(inputs, list): inputs = [inputs] if not isinstance(outputs, list): diff --git a/gradio/mix.py b/gradio/mix.py index 9eb6e9bf73..9df8580fbe 100644 --- a/gradio/mix.py +++ b/gradio/mix.py @@ -31,8 +31,18 @@ class Parallel(gradio.Interface): fns.extend(io.predict) outputs.extend(io.output_components) + def parallel_fn(*args): + return_values = [] + for fn in fns: + value = fn(*args) + if isinstance(value, tuple): + return_values.extend(value) + else: + return_values.append(value) + return return_values + kwargs = { - "fn": fns, + "fn": parallel_fn, "inputs": interfaces[0].input_components, "outputs": outputs, "_repeat_outputs_per_model": False, diff --git a/test/test_mix.py b/test/test_mix.py index 5857d610b0..87d5924fc2 100644 --- a/test/test_mix.py +++ b/test/test_mix.py @@ -15,8 +15,8 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" class TestSeries(unittest.TestCase): def test_in_interface(self): - io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox()) - io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox()) + io1 = gr.Interface(lambda x: x + " World", "textbox", gr.Textbox()) + io2 = gr.Interface(lambda x: x + "!", "textbox", gr.Textbox()) series = mix.Series(io1, io2) self.assertEqual(series.process(["Hello"]), ["Hello World!"]) @@ -33,13 +33,23 @@ class TestSeries(unittest.TestCase): class TestParallel(unittest.TestCase): def test_in_interface(self): - io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.outputs.Textbox()) - io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox()) + io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.Textbox()) + io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox()) parallel = mix.Parallel(io1, io2) self.assertEqual( parallel.process(["Hello"]), ["Hello World 1!", "Hello World 2!"] ) + def test_multiple_return_in_interface(self): + io1 = gr.Interface( + lambda x: (x, x + x), "textbox", [gr.Textbox(), gr.Textbox()] + ) + io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox()) + parallel = mix.Parallel(io1, io2) + self.assertEqual( + parallel.process(["Hello"]), ["Hello", "HelloHello", "Hello World 2!"] + ) + def test_with_external(self): io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish") io2 = gr.Interface.load("spaces/abidlabs/english2german")