Series working with external

This commit is contained in:
Abubakar Abid 2021-11-15 09:39:06 -05:00
parent f361f36625
commit 44e8769a5a
2 changed files with 29 additions and 6 deletions

View File

@ -45,18 +45,20 @@ class Series(gradio.Interface):
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
for idx, io in enumerate(interfaces):
# skip preprocessing for first interface since the compound interface will include it
if idx > 0:
data = [input_interface.preprocess(data[i]) for i, input_interface in enumerate(io.input_components)]
# skip preprocessing for first interface since the Series interface will include it
if idx > 0 and not(io.api_mode):
data = [input_component.preprocess(data[i]) for i, input_component in enumerate(io.input_components)]
# run all of predictions sequentially
predictions = []
for predict_fn in io.predict:
prediction = predict_fn(*data)
predictions.append(prediction)
data = predictions
# skip postprocessing for final interface since the compound interface will include it
if idx < len(interfaces) - 1:
data = [output_interface.postprocess(data[i]) for i, output_interface in enumerate(io.output_components)]
# skip postprocessing for final interface since the Series interface will include it
if idx < len(interfaces) - 1 and not(io.api_mode):
data = [output_component.postprocess(data[i]) for i, output_component in enumerate(io.output_components)]
return data[0]
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
@ -65,6 +67,7 @@ class Series(gradio.Interface):
"fn": connected_fn,
"inputs": interfaces[0].input_components,
"outputs": interfaces[-1].output_components,
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces
}
kwargs.update(options)
super().__init__(**kwargs)

View File

@ -4,6 +4,11 @@ from gradio import mix
import os
"""
WARNING: Some of these tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
"""
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -15,6 +20,14 @@ class TestSeries(unittest.TestCase):
series = mix.Series(io1, io2)
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
series = mix.Series(io1, io2)
output = series("test/test_data/lion.jpg")
self.assertGreater(output['lion'], 0.5)
class TestParallel(unittest.TestCase):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
@ -24,6 +37,13 @@ class TestParallel(unittest.TestCase):
parallel = mix.Parallel(io1, io2)
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
"Hello World 2!"])
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
io2 = gr.Interface.load("spaces/abidlabs/english2german")
parallel = mix.Parallel(io1, io2)
hello_es, hello_de = parallel("Hello")
self.assertIn("hola", hello_es.lower())
self.assertIn("hallo", hello_de.lower())
if __name__ == '__main__':