mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
Merge branch 'master' into flagging-spaces
This commit is contained in:
commit
38b4a4a66a
@ -760,7 +760,7 @@ class Image(InputComponent):
|
|||||||
return processing_utils.encode_url_or_file_to_base64(x)
|
return processing_utils.encode_url_or_file_to_base64(x)
|
||||||
elif self.type == "file":
|
elif self.type == "file":
|
||||||
return processing_utils.encode_url_or_file_to_base64(x.name)
|
return processing_utils.encode_url_or_file_to_base64(x.name)
|
||||||
elif self.type == "numpy" or "pil":
|
elif self.type in ("numpy", "pil"):
|
||||||
if self.type == "numpy":
|
if self.type == "numpy":
|
||||||
x = PIL.Image.fromarray(np.uint8(x)).convert('RGB')
|
x = PIL.Image.fromarray(np.uint8(x)).convert('RGB')
|
||||||
fmt = x.format
|
fmt = x.format
|
||||||
|
@ -45,18 +45,20 @@ class Series(gradio.Interface):
|
|||||||
|
|
||||||
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
|
def connected_fn(*data): # Run each function with the appropriate preprocessing and postprocessing
|
||||||
for idx, io in enumerate(interfaces):
|
for idx, io in enumerate(interfaces):
|
||||||
# skip preprocessing for first interface since the compound interface will include it
|
# skip preprocessing for first interface since the Series interface will include it
|
||||||
if idx > 0:
|
if idx > 0 and not(io.api_mode):
|
||||||
data = [input_interface.preprocess(data[i]) for i, input_interface in enumerate(io.input_components)]
|
data = [input_component.preprocess(data[i]) for i, input_component in enumerate(io.input_components)]
|
||||||
|
|
||||||
# run all of predictions sequentially
|
# run all of predictions sequentially
|
||||||
predictions = []
|
predictions = []
|
||||||
for predict_fn in io.predict:
|
for predict_fn in io.predict:
|
||||||
prediction = predict_fn(*data)
|
prediction = predict_fn(*data)
|
||||||
predictions.append(prediction)
|
predictions.append(prediction)
|
||||||
data = predictions
|
data = predictions
|
||||||
# skip postprocessing for final interface since the compound interface will include it
|
# skip postprocessing for final interface since the Series interface will include it
|
||||||
if idx < len(interfaces) - 1:
|
if idx < len(interfaces) - 1 and not(io.api_mode):
|
||||||
data = [output_interface.postprocess(data[i]) for i, output_interface in enumerate(io.output_components)]
|
data = [output_component.postprocess(data[i]) for i, output_component in enumerate(io.output_components)]
|
||||||
|
|
||||||
return data[0]
|
return data[0]
|
||||||
|
|
||||||
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
|
connected_fn.__name__ = " => ".join([f[0].__name__ for f in fns])
|
||||||
@ -65,6 +67,7 @@ class Series(gradio.Interface):
|
|||||||
"fn": connected_fn,
|
"fn": connected_fn,
|
||||||
"inputs": interfaces[0].input_components,
|
"inputs": interfaces[0].input_components,
|
||||||
"outputs": interfaces[-1].output_components,
|
"outputs": interfaces[-1].output_components,
|
||||||
|
"api_mode": interfaces[0].api_mode, # TODO(abidlabs): allow mixing api_mode and non-api_mode interfaces
|
||||||
}
|
}
|
||||||
kwargs.update(options)
|
kwargs.update(options)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
@ -315,6 +315,8 @@ class TestImage(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
wrong_type = gr.inputs.Image(type="unknown")
|
wrong_type = gr.inputs.Image(type="unknown")
|
||||||
wrong_type.preprocess(img)
|
wrong_type.preprocess(img)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
wrong_type = gr.inputs.Image(type="unknown")
|
||||||
wrong_type.serialize("test/test_files/bus.png", False)
|
wrong_type.serialize("test/test_files/bus.png", False)
|
||||||
img_pil = PIL.Image.open('test/test_files/bus.png')
|
img_pil = PIL.Image.open('test/test_files/bus.png')
|
||||||
image_input = gr.inputs.Image(type="numpy")
|
image_input = gr.inputs.Image(type="numpy")
|
||||||
|
@ -4,6 +4,11 @@ from gradio import mix
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
WARNING: Some of these tests have an external dependency: namely that Hugging Face's Hub and Space APIs do not change, and they keep their most famous models up. So if, e.g. Spaces is down, then these test will not pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +20,14 @@ class TestSeries(unittest.TestCase):
|
|||||||
series = mix.Series(io1, io2)
|
series = mix.Series(io1, io2)
|
||||||
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
|
self.assertEqual(series.process(["Hello"])[0], ["Hello World!"])
|
||||||
|
|
||||||
|
def test_with_external(self):
|
||||||
|
io1 = gr.Interface.load("spaces/abidlabs/image-identity")
|
||||||
|
io2 = gr.Interface.load("spaces/abidlabs/image-classifier")
|
||||||
|
series = mix.Series(io1, io2)
|
||||||
|
output = series("test/test_data/lion.jpg")
|
||||||
|
self.assertGreater(output['lion'], 0.5)
|
||||||
|
|
||||||
|
|
||||||
class TestParallel(unittest.TestCase):
|
class TestParallel(unittest.TestCase):
|
||||||
def test_in_interface(self):
|
def test_in_interface(self):
|
||||||
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
|
io1 = gr.Interface(lambda x: x + " World 1!", "textbox",
|
||||||
@ -24,6 +37,13 @@ class TestParallel(unittest.TestCase):
|
|||||||
parallel = mix.Parallel(io1, io2)
|
parallel = mix.Parallel(io1, io2)
|
||||||
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
|
self.assertEqual(parallel.process(["Hello"])[0], ["Hello World 1!",
|
||||||
"Hello World 2!"])
|
"Hello World 2!"])
|
||||||
|
def test_with_external(self):
|
||||||
|
io1 = gr.Interface.load("spaces/abidlabs/english_to_spanish")
|
||||||
|
io2 = gr.Interface.load("spaces/abidlabs/english2german")
|
||||||
|
parallel = mix.Parallel(io1, io2)
|
||||||
|
hello_es, hello_de = parallel("Hello")
|
||||||
|
self.assertIn("hola", hello_es.lower())
|
||||||
|
self.assertIn("hallo", hello_de.lower())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user