Merge branch 'master' into flagging-spaces

This commit is contained in:
Abubakar Abid 2021-11-16 12:08:58 -06:00
commit 38b4a4a66a
4 changed files with 32 additions and 7 deletions

View File

@ -760,7 +760,7 @@ class Image(InputComponent):
return processing_utils.encode_url_or_file_to_base64(x)
elif self.type == "file":
return processing_utils.encode_url_or_file_to_base64(x.name)
elif self.type == "numpy" or "pil":
elif self.type in ("numpy", "pil"):
if self.type == "numpy":
x = PIL.Image.fromarray(np.uint8(x)).convert('RGB')
fmt = x.format

View File

@ -45,18 +45,20 @@ class Series(gradio.Interface):
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
for idx, io in enumerate(interfaces):
# skip preprocessing for first interface since the compound interface will include it
if idx > 0:
data = [input_interface.preprocess(data[i]) for i, input_interface in enumerate(io.input_components)]
# skip preprocessing for first interface since the Series interface will include it
if idx > 0 and not(io.api_mode):
data = [input_component.preprocess(data[i]) for i, input_component in enumerate(io.input_components)]
# run all of predictions sequentially
predictions = []
for predict_fn in io.predict:
prediction = predict_fn(*data)
predictions.append(prediction)
data = predictions
# skip postprocessing for final interface since the compound interface will include it
if idx < len(interfaces) - 1:
data = [output_interface.postprocess(data[i]) for i, output_interface in enumerate(io.output_components)]
# skip postprocessing for final interface since the Series interface will include it
if idx < len(interfaces) - 1 and not(io.api_mode):
data = [output_component.postprocess(data[i]) for i, output_component in enumerate(io.output_components)]
return data[0]
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
@ -65,6 +67,7 @@ class Series(gradio.Interface):
"fn": connected_fn,
"inputs": interfaces[0].input_components,
"outputs": interfaces[-1].output_components,
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces
}
kwargs.update(options)
super().__init__(**kwargs)

View File

@ -315,6 +315,8 @@ class TestImage(unittest.TestCase):
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Image(type="unknown")
wrong_type.preprocess(img)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Image(type="unknown")
wrong_type.serialize("test/test_files/bus.png", False)
img_pil = PIL.Image.open('test/test_files/bus.png')
image_input = gr.inputs.Image(type="numpy")

View File

@ -4,6 +4,11 @@ from gradio import mix
import os
"""
WARNING: Some of these tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
"""
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -15,6 +20,14 @@ class TestSeries(unittest.TestCase):
series = mix.Series(io1, io2)
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
series = mix.Series(io1, io2)
output = series("test/test_data/lion.jpg")
self.assertGreater(output['lion'], 0.5)
class TestParallel(unittest.TestCase):
def test_in_interface(self):
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
@ -24,6 +37,13 @@ class TestParallel(unittest.TestCase):
parallel = mix.Parallel(io1, io2)
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
"Hello World 2!"])
def test_with_external(self):
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
io2 = gr.Interface.load("spaces/abidlabs/english2german")
parallel = mix.Parallel(io1, io2)
hello_es, hello_de = parallel("Hello")
self.assertIn("hola", hello_es.lower())
self.assertIn("hallo", hello_de.lower())
if __name__ == '__main__':