expanded testing for outputs

This commit is contained in:
aliabd 2021-11-04 13:07:08 -07:00
parent 0a2f05d7f3
commit 7f6b7033e1
2 changed files with 183 additions and 3 deletions

View File

@ -141,7 +141,7 @@ class Label(OutputComponent):
else:
return y['label']
elif self.type == "confidences" or self.type == "auto":
if 'confidences' in y.keys() and isinstance(y['confidences'], list):
if ('confidences' in y.keys()) and isinstance(y['confidences'], list):
return {k['label']:k['confidence'] for k in y['confidences']}
else:
return y

View File

@ -2,8 +2,23 @@ import unittest
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shutil
class OutputComponent(unittest.TestCase):
def test_as_component(self):
output = gr.outputs.OutputComponent(label="Test Input")
self.assertEqual(output.postprocess("Hello World!"), "Hello World!")
self.assertEqual(output.deserialize(1), 1)
class TestTextbox(unittest.TestCase):
def test_as_component(self):
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Textbox(type="unknown")
wrong_type.postprocess(0)
def test_in_interface(self):
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
self.assertEqual(iface.process(["Hello"])[0], ["o"])
@ -17,7 +32,10 @@ class TestLabel(unittest.TestCase):
label_output = gr.outputs.Label()
label = label_output.postprocess(y)
self.assertDictEqual(label, {"label": "happy"})
self.assertEqual(label_output.deserialize(y), y)
self.assertEqual(label_output.deserialize(label), y)
to_save = label_output.save_flagged("flagged", "label_output", label, None)
self.assertEqual(to_save, y)
y = {
3: 0.7,
1: 0.2,
@ -33,6 +51,24 @@ class TestLabel(unittest.TestCase):
{"label": 0, "confidence": 0.1},
]
})
label_output = gr.outputs.Label(num_top_classes=2)
label = label_output.postprocess(y)
self.assertDictEqual(label, {
"label": 3,
"confidences": [
{"label": 3, "confidence": 0.7},
{"label": 1, "confidence": 0.2},
]
})
with self.assertRaises(ValueError):
label_output.postprocess([1, 2, 3])
to_save = label_output.save_flagged("flagged", "label_output", label, None)
self.assertEqual(to_save, '{"3": 0.7, "1": 0.2}')
self.assertEqual(label_output.restore_flagged(to_save), {"3": 0.7, "1": 0.2})
with self.assertRaises(ValueError):
label_output = gr.outputs.Label(type="unknown")
label_output.deserialize([1, 2, 3])
def test_in_interface(self):
x_img = gr.test_data.BASE64_IMAGE
@ -58,12 +94,30 @@ class TestLabel(unittest.TestCase):
]
})
class TestImage(unittest.TestCase):
def test_as_component(self):
y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
image_output = gr.outputs.Image()
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,iVBORw0KGgoAAA"))
self.assertTrue(image_output.postprocess(np.array(y_img)).startswith("data:image/png;base64,iVBORw0KGgoAAA"))
with self.assertWarns(DeprecationWarning):
plot_output = gr.outputs.Image(plot=True)
xpoints = np.array([0, 6])
ypoints = np.array([0, 250])
fig = plt.figure()
p = plt.plot(xpoints, ypoints)
self.assertTrue(plot_output.postprocess(fig).startswith("data:image/png;base64,"))
with self.assertRaises(ValueError):
image_output.postprocess([1, 2, 3])
image_output = gr.outputs.Image(type="numpy")
self.assertTrue(image_output.postprocess(y_img).startswith("data:image/png;base64,"))
to_save = image_output.save_flagged("flagged", "image_output", gr.test_data.BASE64_IMAGE, None)
self.assertEqual("image_output/0.png", to_save)
to_save = image_output.save_flagged("flagged", "image_output", gr.test_data.BASE64_IMAGE, None)
self.assertEqual("image_output/1.png", to_save)
shutil.rmtree('flagged')
def test_in_interface(self):
def generate_noise(width, height):
@ -72,7 +126,33 @@ class TestImage(unittest.TestCase):
iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
self.assertTrue(iface.process([10, 20])[0][0].startswith("data:image/png;base64"))
class TestVideo(unittest.TestCase):
def test_as_component(self):
y_vid = "test/test_files/video_sample.mp4"
video_output = gr.outputs.Video()
self.assertTrue(video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,"))
self.assertTrue(video_output.deserialize(gr.test_data.BASE64_VIDEO["data"]).endswith(".mp4"))
to_save = video_output.save_flagged("flagged", "video_output", gr.test_data.BASE64_VIDEO, None)
self.assertEqual("video_output/0.mp4", to_save)
to_save = video_output.save_flagged("flagged", "video_output", gr.test_data.BASE64_VIDEO, None)
self.assertEqual("video_output/1.mp4", to_save)
shutil.rmtree('flagged')
class TestKeyValues(unittest.TestCase):
def test_as_component(self):
kv_output = gr.outputs.KeyValues()
kv_dict = {"a": 1, "b": 2}
kv_list = [("a", 1), ("b", 2)]
self.assertEqual(kv_output.postprocess(kv_dict), kv_list)
self.assertEqual(kv_output.postprocess(kv_list), kv_list)
with self.assertRaises(ValueError):
kv_output.postprocess(0)
to_save = kv_output.save_flagged("flagged", "kv_output", kv_list, None)
self.assertEqual(to_save, '[["a", 1], ["b", 2]]')
self.assertEqual(kv_output.restore_flagged(to_save), [["a", 1], ["b", 2]])
def test_in_interface(self):
def letter_distribution(word):
dist = {}
@ -84,7 +164,23 @@ class TestKeyValues(unittest.TestCase):
self.assertListEqual(iface.process(["alpaca"])[0][0], [
("a", 3), ("l", 1), ("p", 1), ("c", 1)])
class TestHighlightedText(unittest.TestCase):
def test_as_component(self):
ht_output = gr.outputs.HighlightedText(color_map={"pos": "green", "neg": "red"})
self.assertEqual(ht_output.get_template_context(), {
'color_map': {'pos': 'green', 'neg': 'red'},
'name': 'highlightedtext',
'label': None
})
ht = {
"pos": "Hello ",
"neg": "World"
}
to_save = ht_output.save_flagged("flagged", "ht_output", ht, None)
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
self.assertEqual(ht_output.restore_flagged(to_save), {"pos": "Hello ", "neg": "World"})
def test_in_interface(self):
def highlight_vowels(sentence):
phrases, cur_phrase = [], ""
@ -111,6 +207,19 @@ class TestAudio(unittest.TestCase):
y_audio = gr.processing_utils.decode_base64_to_file(gr.test_data.BASE64_AUDIO["data"])
audio_output = gr.outputs.Audio(type="file")
self.assertTrue(audio_output.postprocess(y_audio.name).startswith("data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA"))
self.assertEqual(audio_output.get_template_context(), {
'name': 'audio',
'label': None
})
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Audio(type="unknown")
wrong_type.postprocess(y_audio.name)
self.assertTrue(audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav"))
to_save = audio_output.save_flagged("flagged", "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
self.assertEqual("audio_output/0.wav", to_save)
to_save = audio_output.save_flagged("flagged", "audio_output", gr.test_data.BASE64_AUDIO["data"], None)
self.assertEqual("audio_output/1.wav", to_save)
shutil.rmtree('flagged')
def test_in_interface(self):
def generate_noise(duration):
@ -121,6 +230,17 @@ class TestAudio(unittest.TestCase):
class TestJSON(unittest.TestCase):
def test_as_component(self):
js_output = gr.outputs.JSON()
self.assertTrue(js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"')
js = {
"pos": "Hello ",
"neg": "World"
}
to_save = js_output.save_flagged("flagged", "js_output", js, None)
self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
self.assertEqual(js_output.restore_flagged(to_save), {"pos": "Hello ", "neg": "World"})
def test_in_interface(self):
def get_avg_age_per_gender(data):
return {
@ -165,6 +285,12 @@ class TestFile(unittest.TestCase):
self.assertDictEqual(iface.process(["hello world"])[0][0], {
'name': 'test.txt', 'size': 11, 'data': 'aGVsbG8gd29ybGQ='
})
file_output = gr.outputs.File()
to_save = file_output.save_flagged("flagged", "file_output", gr.test_data.BASE64_FILE, None)
self.assertEqual("file_output/0.pdf", to_save)
to_save = file_output.save_flagged("flagged", "file_output", gr.test_data.BASE64_FILE, None)
self.assertEqual("file_output/1.pdf", to_save)
shutil.rmtree('flagged')
class TestDataframe(unittest.TestCase):
@ -178,7 +304,20 @@ class TestDataframe(unittest.TestCase):
[[2, True], [3, True], [4, False]], columns=["num", "prime"]))
self.assertDictEqual(output,
{"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]})
self.assertEqual(dataframe_output.get_template_context(), {
'headers': None,
'max_rows': 20,
'max_cols': None,
'overflow_row_behaviour': 'paginate',
'name': 'dataframe',
'label': None
})
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Dataframe(type="unknown")
wrong_type.postprocess(0)
to_save = dataframe_output.save_flagged("flagged", "dataframe_output", output, None)
self.assertEqual(to_save, '[[2, true], [3, true], [4, false]]')
self.assertEqual(dataframe_output.restore_flagged(to_save), [[2, True], [3, True], [4, False]])
def test_in_interface(self):
def check_odd(array):
@ -198,6 +337,21 @@ class TestCarousel(unittest.TestCase):
self.assertEqual(output, [['Hello World', gr.test_data.BASE64_IMAGE],
['Bye World', gr.test_data.BASE64_IMAGE]])
carousel_output = gr.outputs.Carousel("text", label="Disease")
output = carousel_output.postprocess([["Hello World"], ["Bye World"]])
self.assertEqual(output, [['Hello World'], ['Bye World']])
self.assertEqual(carousel_output.get_template_context(), {
'components': [{'name': 'textbox', 'label': None}],
'name': 'carousel',
'label': 'Disease'
})
output = carousel_output.postprocess(["Hello World", "Bye World"])
self.assertEqual(output, [['Hello World'], ['Bye World']])
with self.assertRaises(ValueError):
carousel_output.postprocess('Hello World!')
to_save = carousel_output.save_flagged("flagged", "carousel_output", output, None)
self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')
def test_in_interface(self):
carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
def report(img):
@ -217,6 +371,32 @@ class TestCarousel(unittest.TestCase):
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAD0AAABECAIAAAC9Laq3AAAFGklEQVR4nNXaMWwbVRzH8U8RUjMU1RIDaZeeGKADKJbaIWJozRYJWrKB6NBUKoIxS8VWUlaQCFurIpEOgMoCIgUhpIqwkA61SBUGmslZAkhUSlVEG5YyHOee7bvnu/MlgZ+s5Pm987uvn/+//7t79/bw0P9Qj23/KZa5wp16O92zneO9xkKK+AVe5slaut4m7jUWWctqqoe+du47XOPHYYeNSl8j932us1jmI9Xp6+K+zjX+qvTZKvSjc9/iah3pohz9KNwB81VWUfpq3AXNV1nD6ctyVzBfZYXoS3Ff43pV81VWNn1B7mUWa5+ry6iffij3dpivsh7RB7jv8DkrO4hVRFuc5+nHs9rus8j1nUYartu0OZPJvSvmG6oNltMGS3Pvuvky9QfL/NpXG3Ovs8DGjjOFdY92XkqIuTvM8QxHeGLnwHK1xc+s8nfeEek4WWPtP0B/m+UAcaxBX+4i/QZL/JnVtI9TnOJY/D4zD9px+mzzJXqTU30YedyxdoA+ZD7wDG8N1vZxH6fDem/lNtHH5mvntB7hJO9xNLM5zT3BElhgbpvpV2nnmO8A53gRfJV3rvS6TyMpzNDhYw4NHL/GZyxxrypxh0/zM8Y+ribQIQXWq2bqpt9gke9yMsbr7OPZgj9m2JeYYZp55rnb2xRHzhGeY2+wk7D5TvAWB7ldhDjWUG40mGM2h77NKs/n0IfNh8t5zgur+Lpmgzk6vMP+3qa/afMZbbZS9atJ5aAOJH9LQT8Ky7LrsY1i9LfzzbePC3zDCQ6WOfUG5ytzx2oMo/8hx3wn+IaTKAx9k3u8x0tJmma09e9GPn2ezpXM/Ru0OcanfQ1p7hU2y3QaqzGM/giXaaEk9Cf5Vyw93HeZrYQuRf9KqvIAl/mIozxbprebnOX9wBF9cXKFiLkR6OdTb98tn+PavMobwdRJVnzf5cII9FGqXAp6I2EttFaT58sR6UvpHhd5tdRnwvkkTV+74sk/Jr6UkzdzVSQPdukXquDl6ntwKZA0Aiqev9c5UxP9BmcL4zb5kpm+2rLzzoj033Oel4ami0RNWszTSGaAR3qYnj/L6BAf83Dg1dVPqdeJSqeYTpVnk8ISD0eZ54uP/VeVHlE0ewe0kxQa8b/K451Weuy7+prLySVrBR0Gp/kAzNFipXuWWrhjHWKipq4wzv7UWMylo7HI/U5xrQ+sAlTWGGimajbTzTuwj6OaxvktKa+ADvPd5x8x93EWei8tdl0R6LCUzNYRm3zJt3qf79zq/V12SxFTPMWl1JBHnKbBdPyV+tardlfjTKWWa6IU9xTT6WFNc2/SSnJLk4ilmi4GGzRSCTjzgNbAwLX4BczxZuLUf9WNkyWW2GKFsVQXt0ambxLxAHFo9mqMSSZzVo6aTPUR93E/4AY3khP0qTJ9g/Fk2CZZ6e0/xsokjphOLVn2q++5a+30hxNojDGeREuMlXkfHUd5FO4383lxXfRd0OOscDcJ2amstVJJlDcL9Bx6zj06fUSH0ywwy4fM5oxxN8ozQjlTgXl+jBaTOfQTHA5+sa5mkERqJnQzz3wBDb0+CdDv5Xj+F9OLsplFFoXNF1CpfTOByNnKaoro8AUtIg6kbtqjpLKiKuyvKkI/Tiu5nNhPg3WmmShlvoAq72cLuzbW71xMyg3eZnLwNrGaRtw/OJT+Ch3GknWScuYLqJb9muHIafBaTsKurhr3xw7SRyOaL6Da9yM/4Fs6tZgvoG3dt76N+gfaDbBaHMV3YgAAAABJRU5ErkJggg==']]])
class TestTimeseries(unittest.TestCase):
def test_as_component(self):
timeseries_output = gr.outputs.Timeseries(label="Disease")
self.assertEqual(timeseries_output.get_template_context(), {
'x': None, 'y': None, 'name': 'timeseries', 'label': 'Disease'
})
data = {'Name': ['Tom', 'nick', 'krish', 'jack'], 'Age': [20, 21, 19, 18]}
df = pd.DataFrame(data)
self.assertEqual(timeseries_output.postprocess(df),{'headers': ['Name', 'Age'],
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
['jack', 18]]})
timeseries_output = gr.outputs.Timeseries(y="Age", label="Disease")
output = timeseries_output.postprocess(df)
self.assertEqual(output, {'headers': ['Name', 'Age'],
'data': [['Tom', 20], ['nick', 21], ['krish', 19],
['jack', 18]]})
to_save = timeseries_output.save_flagged("flagged", "timeseries_output", output, None)
self.assertEqual(to_save, '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
'["jack", 18]]}')
self.assertEqual(timeseries_output.restore_flagged(to_save), {"headers": ["Name", "Age"],
"data": [["Tom", 20], ["nick", 21],
["krish", 19], ["jack", 18]]})
class TestNames(unittest.TestCase):
def test_no_duplicate_uncased_names(self): # this ensures that get_input_instance() works correctly when instantiating from components
subclasses = gr.outputs.OutputComponent.__subclasses__()