mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Deprecation error if multiple functions are passed into fn
parameter in Interface
(#1623)
* stack deprection * formatting * rewrite parallel * formatting'
This commit is contained in:
parent
4628ef0a8b
commit
7a0f6b1dd2
@ -106,7 +106,7 @@ class Interface(Blocks):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable | List[Callable],
|
||||
fn: Callable,
|
||||
inputs: Optional[str | Component | List[str | Component]],
|
||||
outputs: Optional[str | Component | List[str | Component]],
|
||||
examples: Optional[List[Any] | List[List[Any]] | str] = None,
|
||||
@ -172,6 +172,12 @@ class Interface(Blocks):
|
||||
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
else:
|
||||
raise DeprecationWarning(
|
||||
"The `fn` parameter only accepts a single function, support for a list "
|
||||
"of functions has been deprecated. Please use gradio.mix.Parallel "
|
||||
"instead."
|
||||
)
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(outputs, list):
|
||||
|
@ -31,8 +31,18 @@ class Parallel(gradio.Interface):
|
||||
fns.extend(io.predict)
|
||||
outputs.extend(io.output_components)
|
||||
|
||||
def parallel_fn(*args):
|
||||
return_values = []
|
||||
for fn in fns:
|
||||
value = fn(*args)
|
||||
if isinstance(value, tuple):
|
||||
return_values.extend(value)
|
||||
else:
|
||||
return_values.append(value)
|
||||
return return_values
|
||||
|
||||
kwargs = {
|
||||
"fn": fns,
|
||||
"fn": parallel_fn,
|
||||
"inputs": interfaces[0].input_components,
|
||||
"outputs": outputs,
|
||||
"_repeat_outputs_per_model": False,
|
||||
|
@ -15,8 +15,8 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
class TestSeries(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.outputs.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.outputs.Textbox())
|
||||
io1 = gr.Interface(lambda x: x + " World", "textbox", gr.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + "!", "textbox", gr.Textbox())
|
||||
series = mix.Series(io1, io2)
|
||||
self.assertEqual(series.process(["Hello"]), ["Hello World!"])
|
||||
|
||||
@ -33,13 +33,23 @@ class TestSeries(unittest.TestCase):
|
||||
|
||||
class TestParallel(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.outputs.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.outputs.Textbox())
|
||||
io1 = gr.Interface(lambda x: x + " World 1!", "textbox", gr.Textbox())
|
||||
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox())
|
||||
parallel = mix.Parallel(io1, io2)
|
||||
self.assertEqual(
|
||||
parallel.process(["Hello"]), ["Hello World 1!", "Hello World 2!"]
|
||||
)
|
||||
|
||||
def test_multiple_return_in_interface(self):
|
||||
io1 = gr.Interface(
|
||||
lambda x: (x, x + x), "textbox", [gr.Textbox(), gr.Textbox()]
|
||||
)
|
||||
io2 = gr.Interface(lambda x: x + " World 2!", "textbox", gr.Textbox())
|
||||
parallel = mix.Parallel(io1, io2)
|
||||
self.assertEqual(
|
||||
parallel.process(["Hello"]), ["Hello", "HelloHello", "Hello World 2!"]
|
||||
)
|
||||
|
||||
def test_with_external(self):
|
||||
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
|
||||
io2 = gr.Interface.load("spaces/abidlabs/english2german")
|
||||
|
Loading…
x
Reference in New Issue
Block a user