gradio/test/test_outputs.py

555 lines
26 KiB
Python
Raw Normal View History

import os
import tempfile
2019-02-28 08:54:08 +08:00
import unittest
import matplotlib.pyplot as plt
2020-09-22 02:51:39 +08:00
import numpy as np
import pandas as pd
import gradio as gr
2021-11-10 02:30:59 +08:00
# Set the GRADIO_ANALYTICS_ENABLED environment variable to the string "False"
# at import time, before any test below builds a gr.Interface.
# NOTE(review): presumably gradio reads this flag to skip analytics calls
# during the test run -- confirm against gradio's own handling of it.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
2021-11-05 04:07:08 +08:00
class OutputComponent(unittest.TestCase):
    """Checks for the base ``gr.outputs.OutputComponent`` class."""

    def test_as_component(self):
        """The base component is expected to return its input unchanged."""
        component = gr.outputs.OutputComponent(label="Test Input")
        text, number = "Hello World!", 1
        self.assertEqual(component.postprocess(text), text)
        self.assertEqual(component.deserialize(number), number)
2019-02-28 08:54:08 +08:00
2020-09-22 02:51:39 +08:00
class TestTextbox(unittest.TestCase):
    """Checks for ``gr.outputs.Textbox``, standalone and inside an Interface."""

    def test_as_component(self):
        """An unrecognised ``type`` is expected to surface as a ValueError."""
        with self.assertRaises(ValueError):
            # Construction and use both sit inside the context manager, so
            # the error may come from either call.
            bad_textbox = gr.outputs.Textbox(type="unknown")
            bad_textbox.postprocess(0)

    def test_in_interface(self):
        """Run a string-valued and a number-valued Textbox through Interface."""
        last_char_iface = gr.Interface(
            lambda x: x[-1], "textbox", gr.outputs.Textbox()
        )
        text_outputs = last_char_iface.process(["Hello"])[0]
        self.assertEqual(text_outputs, ["o"])

        halving_iface = gr.Interface(
            lambda x: x / 2, "number", gr.outputs.Textbox(type="number")
        )
        number_outputs = halving_iface.process([10])[0]
        self.assertEqual(number_outputs, [5])
2020-09-22 02:51:39 +08:00
2019-02-28 08:54:08 +08:00
class TestLabel(unittest.TestCase):
    """Checks for ``gr.outputs.Label``, standalone and inside an Interface."""

    def test_as_component(self):
        """Cover string labels, confidence dicts, flagging and invalid input."""
        # A bare string label.
        plain = "happy"
        component = gr.outputs.Label()
        processed = component.postprocess(plain)
        self.assertDictEqual(processed, {"label": "happy"})
        self.assertEqual(component.deserialize(plain), plain)
        self.assertEqual(component.deserialize(processed), plain)
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = component.save_flagged(
                flag_dir, "label_output", processed, None
            )
            self.assertEqual(flagged, plain)

        # A {class: confidence} mapping, with and without a top-k limit.
        scores = {3: 0.7, 1: 0.2, 0: 0.1}
        component = gr.outputs.Label()
        processed = component.postprocess(scores)
        expected_all = {
            "label": 3,
            "confidences": [
                {"label": 3, "confidence": 0.7},
                {"label": 1, "confidence": 0.2},
                {"label": 0, "confidence": 0.1},
            ],
        }
        self.assertDictEqual(processed, expected_all)

        component = gr.outputs.Label(num_top_classes=2)
        processed = component.postprocess(scores)
        expected_top_two = {
            "label": 3,
            "confidences": [
                {"label": 3, "confidence": 0.7},
                {"label": 1, "confidence": 0.2},
            ],
        }
        self.assertDictEqual(processed, expected_top_two)

        with self.assertRaises(ValueError):
            component.postprocess([1, 2, 3])

        # Flag round-trip: the expected values below use string keys even
        # though the input mapping used ints.
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = component.save_flagged(
                flag_dir, "label_output", processed, None
            )
            self.assertEqual(flagged, '{"3": 0.7, "1": 0.2}')
            restored = component.restore_flagged(flag_dir, flagged, None)
            self.assertEqual(
                restored,
                {
                    "label": "3",
                    "confidences": [
                        {"label": "3", "confidence": 0.7},
                        {"label": "1", "confidence": 0.2},
                    ],
                },
            )

        with self.assertRaises(ValueError):
            component = gr.outputs.Label(type="unknown")
            component.deserialize([1, 2, 3])

    def test_in_interface(self):
        """Feed the sample image through an image -> label Interface."""
        sample_image = gr.test_data.BASE64_IMAGE

        def rgb_distribution(img):
            # Mean over the first two axes, normalised to sum to 1 and
            # rounded to two decimals.
            channel_means = np.mean(img, axis=(0, 1))
            shares = np.round(channel_means / np.sum(channel_means), decimals=2)
            red, green, blue = shares[0], shares[1], shares[2]
            return {"red": red, "green": green, "blue": blue}

        iface = gr.Interface(rgb_distribution, "image", "label")
        prediction = iface.process([sample_image])[0][0]
        expected = {
            "label": "red",
            "confidences": [
                {"label": "red", "confidence": 0.44},
                {"label": "green", "confidence": 0.28},
                {"label": "blue", "confidence": 0.28},
            ],
        }
        self.assertDictEqual(prediction, expected)
2019-02-28 08:54:08 +08:00
2021-11-05 04:07:08 +08:00
2019-03-18 20:38:10 +08:00
class TestImage(unittest.TestCase):
    """Checks for ``gr.outputs.Image``, standalone and inside an Interface."""

    def test_as_component(self):
        """Cover PIL, numpy and plot inputs, invalid input and flag file naming."""
        y_img = gr.processing_utils.decode_base64_to_image(gr.test_data.BASE64_IMAGE)
        image_output = gr.outputs.Image()
        # The prefix here was an empty string, which made startswith()
        # trivially true. Assert the PNG data-URI prefix that the other
        # assertions in this test already expect.
        self.assertTrue(
            image_output.postprocess(y_img).startswith("data:image/png;base64,")
        )
        self.assertTrue(
            image_output.postprocess(np.array(y_img)).startswith(
                "data:image/png;base64,"
            )
        )

        with self.assertWarns(DeprecationWarning):
            plot_output = gr.outputs.Image(plot=True)

        xpoints = np.array([0, 6])
        ypoints = np.array([0, 250])
        fig = plt.figure()
        # Close the figure when the test ends so it does not stay registered
        # with pyplot for the rest of the run.
        self.addCleanup(plt.close, fig)
        plt.plot(xpoints, ypoints)
        self.assertTrue(
            plot_output.postprocess(fig).startswith("data:image/png;base64,")
        )

        with self.assertRaises(ValueError):
            image_output.postprocess([1, 2, 3])

        image_output = gr.outputs.Image(type="numpy")
        self.assertTrue(
            image_output.postprocess(y_img).startswith("data:image/png;base64,")
        )

        # Two consecutive flags are expected to get sequentially numbered
        # file names inside the component's sub-directory.
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = image_output.save_flagged(
                tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None
            )
            self.assertEqual("image_output/0.png", to_save)
            to_save = image_output.save_flagged(
                tmpdirname, "image_output", gr.test_data.BASE64_IMAGE, None
            )
            self.assertEqual("image_output/1.png", to_save)

    def test_in_interface(self):
        """A random RGB array returned from the function is encoded as PNG."""

        def generate_noise(width, height):
            return np.random.randint(0, 256, (width, height, 3))

        iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
        self.assertTrue(
            iface.process([10, 20])[0][0].startswith("data:image/png;base64")
        )
2020-09-22 02:51:39 +08:00
2021-11-05 04:07:08 +08:00
class TestVideo(unittest.TestCase):
    """Checks for the ``gr.outputs.Video`` component."""

    def test_as_component(self):
        """Cover encoding a file, decoding a payload and flag file naming."""
        sample_path = "test/test_files/video_sample.mp4"
        component = gr.outputs.Video()

        encoded = component.postprocess(sample_path)["data"]
        self.assertTrue(encoded.startswith("data:video/mp4;base64,"))

        decoded_path = component.deserialize(gr.test_data.BASE64_VIDEO["data"])
        self.assertTrue(decoded_path.endswith(".mp4"))

        # Two consecutive flags are expected to be numbered 0 and 1.
        with tempfile.TemporaryDirectory() as flag_dir:
            first = component.save_flagged(
                flag_dir, "video_output", gr.test_data.BASE64_VIDEO, None
            )
            self.assertEqual("video_output/0.mp4", first)
            second = component.save_flagged(
                flag_dir, "video_output", gr.test_data.BASE64_VIDEO, None
            )
            self.assertEqual("video_output/1.mp4", second)
2021-11-05 04:07:08 +08:00
2020-09-22 02:51:39 +08:00
class TestKeyValues(unittest.TestCase):
    """Checks for ``gr.outputs.KeyValues``, standalone and inside an Interface."""

    def test_as_component(self):
        """Dict and pair-list inputs are expected to normalise to the same list."""
        component = gr.outputs.KeyValues()
        as_pairs = [("a", 1), ("b", 2)]
        as_mapping = {"a": 1, "b": 2}
        self.assertEqual(component.postprocess(as_mapping), as_pairs)
        self.assertEqual(component.postprocess(as_pairs), as_pairs)

        with self.assertRaises(ValueError):
            component.postprocess(0)

        # Flag round-trip: the expected restored value holds lists, not tuples.
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = component.save_flagged(flag_dir, "kv_output", as_pairs, None)
            self.assertEqual(flagged, '[["a", 1], ["b", 2]]')
            restored = component.restore_flagged(flag_dir, flagged, None)
            self.assertEqual(restored, [["a", 1], ["b", 2]])

    def test_in_interface(self):
        """Count letters of a word through a text -> key_values Interface."""

        def letter_distribution(word):
            # Tally occurrences; keys keep first-seen order.
            counts = {}
            for ch in word:
                counts.setdefault(ch, 0)
                counts[ch] += 1
            return counts

        iface = gr.Interface(letter_distribution, "text", "key_values")
        tally = iface.process(["alpaca"])[0][0]
        self.assertListEqual(tally, [("a", 3), ("l", 1), ("p", 1), ("c", 1)])
2020-09-22 02:51:39 +08:00
2021-11-05 04:07:08 +08:00
2020-09-22 02:51:39 +08:00
class TestHighlightedText(unittest.TestCase):
    """Checks for ``gr.outputs.HighlightedText``, standalone and in an Interface."""

    def test_as_component(self):
        """Check the template context and the flag save/restore round-trip."""
        palette = {"pos": "green", "neg": "red"}
        component = gr.outputs.HighlightedText(color_map=palette)
        expected_context = {
            "color_map": {"pos": "green", "neg": "red"},
            "name": "highlightedtext",
            "label": None,
            "show_legend": False,
        }
        self.assertEqual(component.get_template_context(), expected_context)

        sample = {"pos": "Hello ", "neg": "World"}
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = component.save_flagged(flag_dir, "ht_output", sample, None)
            self.assertEqual(flagged, '{"pos": "Hello ", "neg": "World"}')
            restored = component.restore_flagged(flag_dir, flagged, None)
            self.assertEqual(restored, {"pos": "Hello ", "neg": "World"})

    def test_in_interface(self):
        """Split a sentence into vowel / non-vowel runs through an Interface."""

        def highlight_vowels(sentence):
            # Group consecutive characters of the same kind into
            # (text, kind) tuples; only lower-case vowels count as "vowel".
            spans = []
            run, run_kind = "", None
            for ch in sentence:
                kind = "vowel" if ch in "aeiou" else "non"
                if run_kind is not None and kind != run_kind:
                    # Kind changed: emit the finished run and start a new one.
                    spans.append((run, run_kind))
                    run = ""
                run_kind = kind
                run += ch
            spans.append((run, run_kind))
            return spans

        iface = gr.Interface(highlight_vowels, "text", "highlight")
        highlighted = iface.process(["Helloooo"])[0][0]
        self.assertListEqual(
            highlighted,
            [("H", "non"), ("e", "vowel"), ("ll", "non"), ("oooo", "vowel")],
        )
2020-09-22 02:51:39 +08:00
class TestAudio(unittest.TestCase):
    """Checks for ``gr.outputs.Audio``, standalone and inside an Interface."""

    def test_as_component(self):
        """Cover file encoding, template context, bad types and flag naming."""
        payload = gr.test_data.BASE64_AUDIO["data"]
        decoded_file = gr.processing_utils.decode_base64_to_file(payload)
        component = gr.outputs.Audio(type="file")

        encoded = component.postprocess(decoded_file.name)
        self.assertTrue(
            encoded.startswith("data:audio/wav;base64,UklGRuI/AABXQVZFZm10IBAAA")
        )
        self.assertEqual(
            component.get_template_context(), {"name": "audio", "label": None}
        )

        with self.assertRaises(ValueError):
            bad_component = gr.outputs.Audio(type="unknown")
            bad_component.postprocess(decoded_file.name)

        restored_path = component.deserialize(gr.test_data.BASE64_AUDIO["data"])
        self.assertTrue(restored_path.endswith(".wav"))

        # Two consecutive flags are expected to be numbered 0 and 1.
        with tempfile.TemporaryDirectory() as flag_dir:
            first = component.save_flagged(
                flag_dir, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
            )
            self.assertEqual("audio_output/0.wav", first)
            second = component.save_flagged(
                flag_dir, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
            )
            self.assertEqual("audio_output/1.wav", second)

    def test_in_interface(self):
        """A (sample_rate, int16 array) tuple is expected to encode as WAV."""

        def generate_noise(duration):
            samples = np.random.randint(-256, 256, (duration, 3))
            return 48000, samples.astype(np.int16)

        iface = gr.Interface(generate_noise, "slider", "audio")
        encoded = iface.process([100])[0][0]
        self.assertTrue(encoded.startswith("data:audio/wav;base64"))
class TestJSON(unittest.TestCase):
    """Checks for ``gr.outputs.JSON``, standalone and inside an Interface."""

    def test_as_component(self):
        """Check string post-processing and the flag save/restore round-trip."""
        js_output = gr.outputs.JSON()
        # This was assertTrue(value, expected): unittest treats the second
        # argument as the failure message, so the expected string was never
        # compared. assertEqual performs the intended check.
        self.assertEqual(
            js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"'
        )

        js = {"pos": "Hello ", "neg": "World"}
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = js_output.save_flagged(tmpdirname, "js_output", js, None)
            self.assertEqual(to_save, '{"pos": "Hello ", "neg": "World"}')
            self.assertEqual(
                js_output.restore_flagged(tmpdirname, to_save, None),
                {"pos": "Hello ", "neg": "World"},
            )

    def test_in_interface(self):
        """Average the age column per gender through a dataframe -> json Interface."""

        def get_avg_age_per_gender(data):
            # Each entry filters the rows for one gender code, takes the
            # mean and truncates it to an int.
            return {
                "M": int(data[data["gender"] == "M"].mean()),
                "F": int(data[data["gender"] == "F"].mean()),
                "O": int(data[data["gender"] == "O"].mean()),
            }

        iface = gr.Interface(
            get_avg_age_per_gender,
            gr.inputs.Dataframe(headers=["gender", "age"]),
            "json",
        )
        y_data = [
            ["M", 30],
            ["F", 20],
            ["M", 40],
            ["O", 20],
            ["F", 30],
        ]
        self.assertDictEqual(iface.process([y_data])[0][0], {"M": 35, "F": 25, "O": 20})
2020-09-22 02:51:39 +08:00
class TestHTML(unittest.TestCase):
    """Checks for the ``"html"`` output inside an Interface."""

    def test_in_interface(self):
        """Markup returned by the function is expected to pass through as-is."""

        def bold_text(text):
            return "".join(("<strong>", text, "</strong>"))

        iface = gr.Interface(bold_text, "text", "html")
        rendered = iface.process(["test"])[0][0]
        self.assertEqual(rendered, "<strong>test</strong>")
class TestFile(unittest.TestCase):
    """Checks for ``gr.outputs.File``, inside an Interface and when flagging."""

    def test_as_component(self):
        """Check the encoded file payload and the naming of flagged files."""

        def write_file(content):
            with open("test.txt", "w") as f:
                f.write(content)
            return "test.txt"

        def remove_written_file():
            # write_file() writes into the current working directory.
            # Without this the file stays behind after the test run.
            if os.path.exists("test.txt"):
                os.remove("test.txt")

        self.addCleanup(remove_written_file)

        iface = gr.Interface(write_file, "text", "file")
        self.assertDictEqual(
            iface.process(["hello world"])[0][0],
            {
                "name": "test.txt",
                "size": 11,
                "data": "data:text/plain;base64,aGVsbG8gd29ybGQ=",
            },
        )

        # Two consecutive flags are expected to be numbered 0 and 1, with
        # no file extension.
        file_output = gr.outputs.File()
        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = file_output.save_flagged(
                tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
            )
            self.assertEqual("file_output/0", to_save)
            to_save = file_output.save_flagged(
                tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
            )
            self.assertEqual("file_output/1", to_save)
2020-09-22 02:51:39 +08:00
class TestDataframe(unittest.TestCase):
    """Checks for ``gr.outputs.Dataframe``, standalone and inside an Interface."""

    def test_as_component(self):
        """Cover numpy, list and pandas inputs, template context and flagging."""
        component = gr.outputs.Dataframe()

        from_numpy = component.postprocess(np.zeros((2, 2)))
        self.assertDictEqual(from_numpy, {"data": [[0, 0], [0, 0]]})

        from_list = component.postprocess([[1, 3, 5]])
        self.assertDictEqual(from_list, {"data": [[1, 3, 5]]})

        frame = pd.DataFrame(
            [[2, True], [3, True], [4, False]], columns=["num", "prime"]
        )
        from_pandas = component.postprocess(frame)
        self.assertDictEqual(
            from_pandas,
            {"headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]]},
        )

        expected_context = {
            "headers": None,
            "max_rows": 20,
            "max_cols": None,
            "overflow_row_behaviour": "paginate",
            "name": "dataframe",
            "label": None,
        }
        self.assertEqual(component.get_template_context(), expected_context)

        with self.assertRaises(ValueError):
            bad_component = gr.outputs.Dataframe(type="unknown")
            bad_component.postprocess(0)

        # The pandas-derived payload is the one flagged; the expected
        # restored value carries only the "data" key.
        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = component.save_flagged(
                flag_dir, "dataframe_output", from_pandas, None
            )
            self.assertEqual(flagged, "[[2, true], [3, true], [4, false]]")
            restored = component.restore_flagged(flag_dir, flagged, None)
            self.assertEqual(
                restored, {"data": [[2, True], [3, True], [4, False]]}
            )

    def test_in_interface(self):
        """Run an element-wise predicate through a numpy -> numpy Interface."""

        def check_odd(array):
            # Despite the name, this marks the even entries as True.
            remainders = array % 2
            return remainders == 0

        iface = gr.Interface(check_odd, "numpy", "numpy")
        flags = iface.process([[2, 3, 4]])[0][0]
        self.assertEqual(flags, {"data": [[True, False, True]]})
2019-03-18 20:38:10 +08:00
2021-11-03 05:33:18 +08:00
class TestCarousel(unittest.TestCase):
    """Checks for ``gr.outputs.Carousel``, standalone and inside an Interface."""

    def test_as_component(self):
        """Cover multi- and single-component carousels, bad input and flagging."""
        carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")
        output = carousel_output.postprocess(
            [
                ["Hello World", "test/test_files/bus.png"],
                ["Bye World", "test/test_files/bus.png"],
            ]
        )
        self.assertEqual(
            output,
            [
                ["Hello World", gr.test_data.BASE64_IMAGE],
                ["Bye World", gr.test_data.BASE64_IMAGE],
            ],
        )

        carousel_output = gr.outputs.Carousel("text", label="Disease")
        output = carousel_output.postprocess([["Hello World"], ["Bye World"]])
        self.assertEqual(output, [["Hello World"], ["Bye World"]])
        self.assertEqual(
            carousel_output.get_template_context(),
            {
                "components": [{"name": "textbox", "label": None}],
                "name": "carousel",
                "label": "Disease",
            },
        )

        # A flat list is expected to be wrapped into one-element slides.
        output = carousel_output.postprocess(["Hello World", "Bye World"])
        self.assertEqual(output, [["Hello World"], ["Bye World"]])

        with self.assertRaises(ValueError):
            carousel_output.postprocess("Hello World!")

        with tempfile.TemporaryDirectory() as tmpdirname:
            to_save = carousel_output.save_flagged(
                tmpdirname, "carousel_output", output, None
            )
            self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')

    def test_in_interface(self):
        """Each colour-filtered image is expected back as a labelled PNG slide."""
        carousel_output = gr.outputs.Carousel(["text", "image"], label="Disease")

        def report(img):
            results = []
            for i, mode in enumerate(["Red", "Green", "Blue"]):
                # One-hot mask over the last axis keeps a single channel.
                color_filter = np.array([0, 0, 0])
                color_filter[i] = 1
                results.append([mode, img * color_filter])
            return results

        iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
        outputs = iface.process([gr.test_data.BASE64_IMAGE])[0]

        # The expected image payloads were empty strings, which a real
        # encoded image can never equal. The exact base64 data is not
        # reproduced here, so assert the slide labels, the slide shape and
        # the PNG data-URI prefix instead.
        self.assertEqual(len(outputs), 1)
        slides = outputs[0]
        self.assertEqual([slide[0] for slide in slides], ["Red", "Green", "Blue"])
        for slide in slides:
            self.assertEqual(len(slide), 2)
            self.assertTrue(slide[1].startswith("data:image/png;base64,"))
2021-11-03 05:33:18 +08:00
2021-11-05 04:07:08 +08:00
class TestTimeseries(unittest.TestCase):
    """Checks for the ``gr.outputs.Timeseries`` component."""

    def test_as_component(self):
        """Check template context, DataFrame post-processing and flagging."""
        expected_payload = {
            "headers": ["Name", "Age"],
            "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
        }

        component = gr.outputs.Timeseries(label="Disease")
        self.assertEqual(
            component.get_template_context(),
            {"x": None, "y": None, "name": "timeseries", "label": "Disease"},
        )

        frame = pd.DataFrame(
            {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]}
        )
        self.assertEqual(component.postprocess(frame), expected_payload)

        # Naming a y column is expected to leave the payload unchanged.
        component = gr.outputs.Timeseries(y="Age", label="Disease")
        processed = component.postprocess(frame)
        self.assertEqual(processed, expected_payload)

        with tempfile.TemporaryDirectory() as flag_dir:
            flagged = component.save_flagged(
                flag_dir, "timeseries_output", processed, None
            )
            self.assertEqual(
                flagged,
                '{"headers": ["Name", "Age"], "data": [["Tom", 20], ["nick", 21], ["krish", 19], '
                '["jack", 18]]}',
            )
            restored = component.restore_flagged(flag_dir, flagged, None)
            self.assertEqual(
                restored,
                {
                    "headers": ["Name", "Age"],
                    "data": [["Tom", 20], ["nick", 21], ["krish", 19], ["jack", 18]],
                },
            )
2021-11-05 04:07:08 +08:00
2021-10-14 11:21:34 +08:00
class TestNames(unittest.TestCase):
    """Guards against output component classes whose names clash by case."""

    def test_no_duplicate_uncased_names(self):
        """Direct OutputComponent subclasses must have unique lower-cased names.

        The original note on this test says it ensures get_input_instance()
        works correctly when instantiating from components.
        """
        component_classes = gr.outputs.OutputComponent.__subclasses__()
        lowered_names = {cls.__name__.lower() for cls in component_classes}
        # A smaller set means two class names collapsed to the same key.
        self.assertEqual(len(component_classes), len(lowered_names))
2019-03-18 20:38:10 +08:00
if __name__ == "__main__":
    # When executed directly, hand control to unittest's command-line runner.
    unittest.main()