Deprecation error if multiple functions are passed into fn parameter in Interface (#1623)

* stack deprection

* formatting

* rewrite parallel

* formatting'
This commit is contained in:
Abubakar Abid 2022-06-22 23:48:42 -07:00 committed by GitHub
parent 4628ef0a8b
commit 7a0f6b1dd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 6 deletions

View File

@ -106,7 +106,7 @@ class Interface(Blocks):
def __init__(
self,
fn: Callable | List[Callable],
fn: Callable,
inputs: Optional[str | Component | List[str | Component]],
outputs: Optional[str | Component | List[str | Component]],
examples: Optional[List[Any] | List[List[Any]] | str] = None,
@ -172,6 +172,12 @@ class Interface(Blocks):
if not isinstance(fn, list):
fn = [fn]
else:
raise DeprecationWarning(
"The `fn` parameter only accepts a single function, support for a list "
"of functions has been deprecated. Please use gradio.mix.Parallel "
"instead."
)
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):

View File

@ -31,8 +31,18 @@ class Parallel(gradio.Interface):
fns.extend(io.predict)
outputs.extend(io.output_components)
def parallel_fn(*args):
return_values = []
for fn in fns:
value = fn(*args)
if isinstance(value, tuple):
return_values.extend(value)
else:
return_values.append(value)
return return_values
kwargs = {
"fn": fns,
"fn": parallel_fn,
"inputs": interfaces[0].input_components,
"outputs": outputs,
"_repeat_outputs_per_model": False,

View File

@ -15,8 +15,8 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestSeries(unittest.TestCase):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.Textbox())
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.Textbox())
series = mix.Series(io1, io2)
self.assertEqual(series.process(["Hello"]), ["Hello World!"])
@ -33,13 +33,23 @@ class TestSeries(unittest.TestCase):
class TestParallel(unittest.TestCase):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.outputs.Textbox())
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox())
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.Textbox())
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox())
parallel = mix.Parallel(io1, io2)
self.assertEqual(
parallel.process(["Hello"]), ["Hello World 1!", "Hello World 2!"]
)
def test_multiple_return_in_interface(self):
io1 = gr.Interface(
lambda x: (x, x + x), "textbox", [gr.Textbox(), gr.Textbox()]
)
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox())
parallel = mix.Parallel(io1, io2)
self.assertEqual(
parallel.process(["Hello"]), ["Hello", "HelloHello", "Hello World 2!"]
)
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
io2 = gr.Interface.load("spaces/abidlabs/english2german")