mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
Merge branch 'master' into flagging-spaces
This commit is contained in:
commit
38b4a4a66a
@ -760,7 +760,7 @@ class Image(InputComponent):
|
||||
return processing_utils.encode_url_or_file_to_base64(x)
|
||||
elif self.type == "file":
|
||||
return processing_utils.encode_url_or_file_to_base64(x.name)
|
||||
elif self.type == "numpy" or "pil":
|
||||
elif self.type in ("numpy", "pil"):
|
||||
if self.type == "numpy":
|
||||
x = PIL.Image.fromarray(np.uint8(x)).convert('RGB')
|
||||
fmt = x.format
|
||||
|
@ -45,18 +45,20 @@ class Series(gradio.Interface):
|
||||
|
||||
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
|
||||
for idx, io in enumerate(interfaces):
|
||||
# skip preprocessing for first interface since the compound interface will include it
|
||||
if idx > 0:
|
||||
data = [input_interface.preprocess(data[i]) for i, input_interface in enumerate(io.input_components)]
|
||||
# skip preprocessing for first interface since the Series interface will include it
|
||||
if idx > 0 and not(io.api_mode):
|
||||
data = [input_component.preprocess(data[i]) for i, input_component in enumerate(io.input_components)]
|
||||
|
||||
# run all of predictions sequentially
|
||||
predictions = []
|
||||
for predict_fn in io.predict:
|
||||
prediction = predict_fn(*data)
|
||||
predictions.append(prediction)
|
||||
data = predictions
|
||||
# skip postprocessing for final interface since the compound interface will include it
|
||||
if idx < len(interfaces) - 1:
|
||||
data = [output_interface.postprocess(data[i]) for i, output_interface in enumerate(io.output_components)]
|
||||
# skip postprocessing for final interface since the Series interface will include it
|
||||
if idx < len(interfaces) - 1 and not(io.api_mode):
|
||||
data = [output_component.postprocess(data[i]) for i, output_component in enumerate(io.output_components)]
|
||||
|
||||
return data[0]
|
||||
|
||||
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
|
||||
@ -65,6 +67,7 @@ class Series(gradio.Interface):
|
||||
"fn": connected_fn,
|
||||
"inputs": interfaces[0].input_components,
|
||||
"outputs": interfaces[-1].output_components,
|
||||
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces
|
||||
}
|
||||
kwargs.update(options)
|
||||
super().__init__(**kwargs)
|
||||
|
@ -315,6 +315,8 @@ class TestImage(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Image(type="unknown")
|
||||
wrong_type.preprocess(img)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Image(type="unknown")
|
||||
wrong_type.serialize("test/test_files/bus.png", False)
|
||||
img_pil = PIL.Image.open('test/test_files/bus.png')
|
||||
image_input = gr.inputs.Image(type="numpy")
|
||||
|
@ -4,6 +4,11 @@ from gradio import mix
|
||||
import os
|
||||
|
||||
|
||||
"""
|
||||
WARNING: Some of these tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
|
||||
"""
|
||||
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
@ -15,6 +20,14 @@ class TestSeries(unittest.TestCase):
|
||||
series = mix.Series(io1, io2)
|
||||
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
|
||||
|
||||
def test_with_external(self):
|
||||
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
|
||||
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
|
||||
series = mix.Series(io1, io2)
|
||||
output = series("test/test_data/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
|
||||
|
||||
class TestParallel(unittest.TestCase):
|
||||
def test_in_interface(self):
|
||||
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
|
||||
@ -24,6 +37,13 @@ class TestParallel(unittest.TestCase):
|
||||
parallel = mix.Parallel(io1, io2)
|
||||
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
|
||||
"Hello World 2!"])
|
||||
def test_with_external(self):
|
||||
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
|
||||
io2 = gr.Interface.load("spaces/abidlabs/english2german")
|
||||
parallel = mix.Parallel(io1, io2)
|
||||
hello_es, hello_de = parallel("Hello")
|
||||
self.assertIn("hola", hello_es.lower())
|
||||
self.assertIn("hallo", hello_de.lower())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user