mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
expanded testing for outputs
This commit is contained in:
parent
0a2f05d7f3
commit
7f6b7033e1
@ -141,7 +141,7 @@ class Label(OutputComponent):
|
||||
else:
|
||||
return y['label']
|
||||
elif self.type == "confidences" or self.type == "auto":
|
||||
if 'confidences' in y.keys() and isinstance(y['confidences'], list):
|
||||
if ('confidences' in y.keys()) and isinstance(y['confidences'], list):
|
||||
return {k['label']:k['confidence'] for k in y['confidences']}
|
||||
else:
|
||||
return y
|
||||
|
@ -2,8 +2,23 @@ import unittest
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import shutil
|
||||
|
||||
|
||||
class OutputComponent(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
output = gr.outputs.OutputComponent(label="Test Input")
|
||||
self.assertEqual(output.postprocess("Hello World!"), "Hello World!")
|
||||
self.assertEqual(output.deserialize(1), 1)
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Textbox(type="unknown")
|
||||
wrong_type.postprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["o"])
|
||||
@ -17,7 +32,10 @@ class TestLabel(unittest.TestCase):
|
||||
label_output = gr.outputs.Label()
|
||||
label = label_output.postprocess(y)
|
||||
self.assertDictEqual(label, {"label": "happy"})
|
||||
|
||||
self.assertEqual(label_output.deserialize(y), y)
|
||||
self.assertEqual(label_output.deserialize(label), y)
|
||||
to_save = label_output.save_flagged("flagged", "label_output", label, None)
|
||||
self.assertEqual(to_save, y)
|
||||
y = {
|
||||
3: 0.7,
|
||||
1: 0.2,
|
||||
@ -33,6 +51,24 @@ class TestLabel(unittest.TestCase):
|
||||
{"label": 0, "confidence": 0.1},
|
||||
]
|
||||
})
|
||||
label_output = gr.outputs.Label(num_top_classes=2)
|
||||
label = label_output.postprocess(y)
|
||||
self.assertDictEqual(label, {
|
||||
"label": 3,
|
||||
"confidences": [
|
||||
{"label": 3, "confidence": 0.7},
|
||||
{"label": 1, "confidence": 0.2},
|
||||
]
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
label_output.postprocess([1, 2, 3])
|
||||
|
||||
to_save = label_output.save_flagged("flagged", "label_output", label, None)
|
||||
self.assertEqual(to_save, '{"3": 0.7, "1": 0.2}')
|
||||
self.assertEqual(label_output.restore_flagged(to_save), {"3": 0.7, "1": 0.2})
|
||||
with self.assertRaises(ValueError):
|
||||
label_output = gr.outputs.Label(type="unknown")
|
||||
label_output.deserialize([1, 2, 3])
|
||||
|
||||
def test_in_interface(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
@ -58,12 +94,30 @@ class TestLabel(unittest.TestCase):
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
|
||||
image_output = gr.outputs.Image()
|
||||
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,iVBORw0KGgoAAA"))
|
||||
self.assertTrue(image_output.postprocess(np.array(y_img)).startswith("data:image/png;base64,iVBORw0KGgoAAA"))
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
plot_output = gr.outputs.Image(plot=True)
|
||||
|
||||
xpoints = np.array([0, 6])
|
||||
ypoints = np.array([0, 250])
|
||||
fig = plt.figure()
|
||||
p = plt.plot(xpoints, ypoints)
|
||||
self.assertTrue(plot_output.postprocess(fig).startswith("data:image/png;base64,"))
|
||||
with self.assertRaises(ValueError):
|
||||
image_output.postprocess([1, 2, 3])
|
||||
image_output = gr.outputs.Image(type="numpy")
|
||||
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,"))
|
||||
to_save = image_output.save_flagged("flagged", "image_output", gr.test_data.BASE64_IMAGE, None)
|
||||
self.assertEqual("image_output/0.png", to_save)
|
||||
to_save = image_output.save_flagged("flagged", "image_output", gr.test_data.BASE64_IMAGE, None)
|
||||
self.assertEqual("image_output/1.png", to_save)
|
||||
shutil.rmtree('flagged')
|
||||
|
||||
def test_in_interface(self):
|
||||
def generate_noise(width, height):
|
||||
@ -72,7 +126,33 @@ class TestImage(unittest.TestCase):
|
||||
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
|
||||
self.assertTrue(iface.process([10, 20])[0][0].startswith("data:image/png;base64"))
|
||||
|
||||
|
||||
class TestVideo(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
y_vid = "test/test_files/video_sample.mp4"
|
||||
video_output = gr.outputs.Video()
|
||||
self.assertTrue(video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,"))
|
||||
self.assertTrue(video_output.deserialize(gr.test_data.BASE64_VIDEO["data"]).endswith(".mp4"))
|
||||
to_save = video_output.save_flagged("flagged", "video_output", gr.test_data.BASE64_VIDEO, None)
|
||||
self.assertEqual("video_output/0.mp4", to_save)
|
||||
to_save = video_output.save_flagged("flagged", "video_output", gr.test_data.BASE64_VIDEO, None)
|
||||
self.assertEqual("video_output/1.mp4", to_save)
|
||||
shutil.rmtree('flagged')
|
||||
|
||||
|
||||
class TestKeyValues(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
kv_output = gr.outputs.KeyValues()
|
||||
kv_dict = {"a": 1, "b": 2}
|
||||
kv_list = [("a", 1), ("b", 2)]
|
||||
self.assertEqual(kv_output.postprocess(kv_dict), kv_list)
|
||||
self.assertEqual(kv_output.postprocess(kv_list), kv_list)
|
||||
with self.assertRaises(ValueError):
|
||||
kv_output.postprocess(0)
|
||||
to_save = kv_output.save_flagged("flagged", "kv_output", kv_list, None)
|
||||
self.assertEqual(to_save, '[["a", 1], ["b", 2]]')
|
||||
self.assertEqual(kv_output.restore_flagged(to_save), [["a", 1], ["b", 2]])
|
||||
|
||||
def test_in_interface(self):
|
||||
def letter_distribution(word):
|
||||
dist = {}
|
||||
@ -84,7 +164,23 @@ class TestKeyValues(unittest.TestCase):
|
||||
self.assertListEqual(iface.process(["alpaca"])[0][0], [
|
||||
("a", 3), ("l", 1), ("p", 1), ("c", 1)])
|
||||
|
||||
|
||||
class TestHighlightedText(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
ht_output = gr.outputs.HighlightedText(color_map={"pos": "green", "neg": "red"})
|
||||
self.assertEqual(ht_output.get_template_context(), {
|
||||
'color_map': {'pos': 'green', 'neg': 'red'},
|
||||
'name': 'highlightedtext',
|
||||
'label': None
|
||||
})
|
||||
ht = {
|
||||
"pos": "Hello ",
|
||||
"neg": "World"
|
||||
}
|
||||
to_save = ht_output.save_flagged("flagged", "ht_output", ht, None)
|
||||
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
|
||||
self.assertEqual(ht_output.restore_flagged(to_save), {"pos": "Hello ", "neg": "World"})
|
||||
|
||||
def test_in_interface(self):
|
||||
def highlight_vowels(sentence):
|
||||
phrases, cur_phrase = [], ""
|
||||
@ -111,6 +207,19 @@ class TestAudio(unittest.TestCase):
|
||||
y_audio = gr.processing_utils.decode_base64_to_file(gr.test_data.BASE64_AUDIO["data"])
|
||||
audio_output = gr.outputs.Audio(type="file")
|
||||
self.assertTrue(audio_output.postprocess(y_audio.name).startswith("data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"))
|
||||
self.assertEqual(audio_output.get_template_context(), {
|
||||
'name': 'audio',
|
||||
'label': None
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Audio(type="unknown")
|
||||
wrong_type.postprocess(y_audio.name)
|
||||
self.assertTrue(audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav"))
|
||||
to_save = audio_output.save_flagged("flagged", "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
|
||||
self.assertEqual("audio_output/0.wav", to_save)
|
||||
to_save = audio_output.save_flagged("flagged", "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
|
||||
self.assertEqual("audio_output/1.wav", to_save)
|
||||
shutil.rmtree('flagged')
|
||||
|
||||
def test_in_interface(self):
|
||||
def generate_noise(duration):
|
||||
@ -121,6 +230,17 @@ class TestAudio(unittest.TestCase):
|
||||
|
||||
|
||||
class TestJSON(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
js_output = gr.outputs.JSON()
|
||||
self.assertTrue(js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"')
|
||||
js = {
|
||||
"pos": "Hello ",
|
||||
"neg": "World"
|
||||
}
|
||||
to_save = js_output.save_flagged("flagged", "js_output", js, None)
|
||||
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
|
||||
self.assertEqual(js_output.restore_flagged(to_save), {"pos": "Hello ", "neg": "World"})
|
||||
|
||||
def test_in_interface(self):
|
||||
def get_avg_age_per_gender(data):
|
||||
return {
|
||||
@ -165,6 +285,12 @@ class TestFile(unittest.TestCase):
|
||||
self.assertDictEqual(iface.process(["hello world"])[0][0], {
|
||||
'name': 'test.txt', 'size': 11, 'data': 'aGVsbG8gd29ybGQ='
|
||||
})
|
||||
file_output = gr.outputs.File()
|
||||
to_save = file_output.save_flagged("flagged", "file_output", gr.test_data.BASE64_FILE, None)
|
||||
self.assertEqual("file_output/0.pdf", to_save)
|
||||
to_save = file_output.save_flagged("flagged", "file_output", gr.test_data.BASE64_FILE, None)
|
||||
self.assertEqual("file_output/1.pdf", to_save)
|
||||
shutil.rmtree('flagged')
|
||||
|
||||
|
||||
class TestDataframe(unittest.TestCase):
|
||||
@ -178,7 +304,20 @@ class TestDataframe(unittest.TestCase):
|
||||
[[2, True], [3, True], [4, False]], columns=["num", "prime"]))
|
||||
self.assertDictEqual(output,
|
||||
{"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]})
|
||||
|
||||
self.assertEqual(dataframe_output.get_template_context(), {
|
||||
'headers': None,
|
||||
'max_rows': 20,
|
||||
'max_cols': None,
|
||||
'overflow_row_behaviour': 'paginate',
|
||||
'name': 'dataframe',
|
||||
'label': None
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Dataframe(type="unknown")
|
||||
wrong_type.postprocess(0)
|
||||
to_save = dataframe_output.save_flagged("flagged", "dataframe_output", output, None)
|
||||
self.assertEqual(to_save, '[[2, true], [3, true], [4, false]]')
|
||||
self.assertEqual(dataframe_output.restore_flagged(to_save), [[2, True], [3, True], [4, False]])
|
||||
|
||||
def test_in_interface(self):
|
||||
def check_odd(array):
|
||||
@ -198,6 +337,21 @@ class TestCarousel(unittest.TestCase):
|
||||
self.assertEqual(output, [['Hello World', gr.test_data.BASE64_IMAGE],
|
||||
['Bye World', gr.test_data.BASE64_IMAGE]])
|
||||
|
||||
carousel_output = gr.outputs.Carousel("text", label="Disease")
|
||||
output = carousel_output.postprocess([["Hello World"], ["Bye World"]])
|
||||
self.assertEqual(output, [['Hello World'], ['Bye World']])
|
||||
self.assertEqual(carousel_output.get_template_context(), {
|
||||
'components': [{'name': 'textbox', 'label': None}],
|
||||
'name': 'carousel',
|
||||
'label': 'Disease'
|
||||
})
|
||||
output = carousel_output.postprocess(["Hello World", "Bye World"])
|
||||
self.assertEqual(output, [['Hello World'], ['Bye World']])
|
||||
with self.assertRaises(ValueError):
|
||||
carousel_output.postprocess('Hello World!')
|
||||
to_save = carousel_output.save_flagged("flagged", "carousel_output", output, None)
|
||||
self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')
|
||||
|
||||
def test_in_interface(self):
|
||||
carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
|
||||
def report(img):
|
||||
@ -217,6 +371,32 @@ class TestCarousel(unittest.TestCase):
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFGklEQVR4nNXaMWwbVRzH8U8RUjMU1RIDaZeeGKADKJbaIWJozRYJWrKB6NBUKoIxS8VWUlaQCFurIpEOgMoCIgUhpIqwkA61SBUGmslZAkhUSlVEG5YyHOee7bvnu/MlgZ+s5Pm987uvn/+//7t79/bw0P9Qj23/KZa5wp16O92zneO9xkKK+AVe5slaut4m7jUWWctqqoe+du47XOPHYYeNSl8j932us1jmI9Xp6+K+zjX+qvTZKvSjc9/iah3pohz9KNwB81VWUfpq3AXNV1nD6ctyVzBfZYXoS3Ff43pV81VWNn1B7mUWa5+ry6iffij3dpivsh7RB7jv8DkrO4hVRFuc5+nHs9rus8j1nUYartu0OZPJvSvmG6oNltMGS3Pvuvky9QfL/NpXG3Ovs8DGjjOFdY92XkqIuTvM8QxHeGLnwHK1xc+s8nfeEek4WWPtP0B/m+UAcaxBX+4i/QZL/JnVtI9TnOJY/D4zD9px+mzzJXqTU30YedyxdoA+ZD7wDG8N1vZxH6fDem/lNtHH5mvntB7hJO9xNLM5zT3BElhgbpvpV2nnmO8A53gRfJV3rvS6TyMpzNDhYw4NHL/GZyxxrypxh0/zM8Y+ribQIQXWq2bqpt9gke9yMsbr7OPZgj9m2JeYYZp55rnb2xRHzhGeY2+wk7D5TvAWB7ldhDjWUG40mGM2h77NKs/n0IfNh8t5zgur+Lpmgzk6vMP+3qa/afMZbbZS9atJ5aAOJH9LQT8Ky7LrsY1i9LfzzbePC3zDCQ6WOfUG5ytzx2oMo/8hx3wn+IaTKAx9k3u8x0tJmma09e9GPn2ezpXM/Ru0OcanfQ1p7hU2y3QaqzGM/giXaaEk9Cf5Vyw93HeZrYQuRf9KqvIAl/mIozxbprebnOX9wBF9cXKFiLkR6OdTb98tn+PavMobwdRJVnzf5cII9FGqXAp6I2EttFaT58sR6UvpHhd5tdRnwvkkTV+74sk/Jr6UkzdzVSQPdukXquDl6ntwKZA0Aiqev9c5UxP9BmcL4zb5kpm+2rLzzoj033Oel4ami0RNWszTSGaAR3qYnj/L6BAf83Dg1dVPqdeJSqeYTpVnk8ISD0eZ54uP/VeVHlE0ewe0kxQa8b/K451Weuy7+prLySVrBR0Gp/kAzNFipXuWWrhjHWKipq4wzv7UWMylo7HI/U5xrQ+sAlTWGGimajbTzTuwj6OaxvktKa+ADvPd5x8x93EWei8tdl0R6LCUzNYRm3zJt3qf79zq/V12SxFTPMWl1JBHnKbBdPyV+tardlfjTKWWa6IU9xTT6WFNc2/SSnJLk4ilmi4GGzRSCTjzgNbAwLX4BczxZuLUf9WNkyWW2GKFsVQXt0ambxLxAHFo9mqMSSZzVo6aTPUR93E/4AY3khP0qTJ9g/Fk2CZZ6e0/xsokjphOLVn2q++5a+30hxNojDGeREuMlXkfHUd5FO4383lxXfRd0OOscDcJ2amstVJJlDcL9Bx6zj06fUSH0ywwy4fM5oxxN8ozQjlTgXl+jBaTOfQTHA5+sa5mkERqJnQzz3wBDb0+CdDv5Xj+F9OLsplFFoXNF1CpfTOByNnKaoro8AUtIg6kbtqjpLKiKuyvKkI/Tiu5nNhPg3WmmShlvoAq72cLuzbW71xMyg3eZnLwNrGaRtw/OJT+Ch3GknWScuYLqJb9muHIafBaTsKurhr3xw7SRyOaL6Da9yM/4Fs6tZgvoG3dt76N+gfaDbBaHMV3YgAAAABJRU5ErkJggg==']]])
|
||||
|
||||
|
||||
class TestTimeseries(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
timeseries_output = gr.outputs.Timeseries(label="Disease")
|
||||
self.assertEqual(timeseries_output.get_template_context(), {
|
||||
'x': None, 'y': None, 'name': 'timeseries', 'label': 'Disease'
|
||||
})
|
||||
data = {'Name': ['Tom', 'nick', 'krish', 'jack'], 'Age': [20, 21, 19, 18]}
|
||||
df = pd.DataFrame(data)
|
||||
self.assertEqual(timeseries_output.postprocess(df),{'headers': ['Name', 'Age'],
|
||||
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
|
||||
['jack', 18]]})
|
||||
|
||||
timeseries_output = gr.outputs.Timeseries(y="Age", label="Disease")
|
||||
output = timeseries_output.postprocess(df)
|
||||
self.assertEqual(output, {'headers': ['Name', 'Age'],
|
||||
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
|
||||
['jack', 18]]})
|
||||
|
||||
to_save = timeseries_output.save_flagged("flagged", "timeseries_output", output, None)
|
||||
self.assertEqual(to_save, '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
|
||||
'["jack", 18]]}')
|
||||
self.assertEqual(timeseries_output.restore_flagged(to_save), {"headers": ["Name", "Age"],
|
||||
"data": [["Tom", 20], ["nick", 21],
|
||||
["krish", 19], ["jack", 18]]})
|
||||
|
||||
|
||||
class TestNames(unittest.TestCase):
|
||||
def test_no_duplicate_uncased_names(self): # this ensures that get_input_instance() works correctly when instantiating from components
|
||||
subclasses = gr.outputs.OutputComponent.__subclasses__()
|
||||
|
Loading…
Reference in New Issue
Block a user