gradio/test/test_inputs.py

136 lines
5.0 KiB
Python
Raw Normal View History

2019-02-28 08:43:20 +08:00
import unittest
2020-09-15 20:16:14 +08:00
import gradio as gr
2020-09-18 05:38:22 +08:00
import PIL
import numpy as np
2020-09-22 02:51:39 +08:00
import scipy
import os
2019-02-28 08:43:20 +08:00
2020-09-15 20:16:14 +08:00
class TestTextbox(unittest.TestCase):
    """Round-trip checks for the textbox and number input components."""

    def test_in_interface(self):
        # Textbox: the wrapped function reverses the incoming string.
        def reverse_text(text):
            return text[::-1]

        text_iface = gr.Interface(reverse_text, "textbox", "textbox")
        text_result = text_iface.process(["Hello"])[0]
        self.assertEqual(text_result, ["olleH"])

        # Number: the raw input "5" must be squared to 25 by the time it is
        # returned, so the component is expected to convert it to a number.
        def square(value):
            return value * value

        number_iface = gr.Interface(square, "number", "number")
        number_result = number_iface.process(["5"])[0]
        self.assertEqual(number_result, [25])
2020-09-15 20:16:14 +08:00
class TestSlider(unittest.TestCase):
    """The slider's value should be passed straight to the wrapped function."""

    def test_in_interface(self):
        def count_cats(count):
            return str(count) + " cats"

        iface = gr.Interface(count_cats, "slider", "textbox")
        result = iface.process([4])[0]
        self.assertEqual(result, ["4 cats"])
2020-09-15 20:16:14 +08:00
class TestCheckbox(unittest.TestCase):
    """An unchecked checkbox should reach the function as a falsy value."""

    def test_in_interface(self):
        def yes_or_no(checked):
            if checked:
                return "yes"
            return "no"

        iface = gr.Interface(yes_or_no, "checkbox", "textbox")
        result = iface.process([False])[0]
        self.assertEqual(result, ["no"])
2020-09-15 20:16:14 +08:00
class TestCheckboxGroup(unittest.TestCase):
    """Checkbox groups pass the selected labels, or their indices."""

    def test_in_interface(self):
        # Default type: the selected labels themselves are passed through.
        by_value = gr.inputs.CheckboxGroup(["a", "b", "c"])
        value_iface = gr.Interface(
            lambda picked: "|".join(picked), by_value, "textbox")
        self.assertEqual(value_iface.process([["a", "c"]])[0], ["a|c"])
        # An empty selection joins to the empty string.
        self.assertEqual(value_iface.process([[]])[0], [""])

        # type="index": positions of the selected labels are passed instead.
        by_index = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
        index_iface = gr.Interface(
            lambda picked: "|".join(str(i) for i in picked), by_index, "textbox")
        self.assertEqual(index_iface.process([["a", "c"]])[0], ["0|2"])
2020-09-15 20:16:14 +08:00
class TestRadio(unittest.TestCase):
    """Radio buttons should pass either the chosen label or its index."""

    def test_in_interface(self):
        # Each case: (extra Radio kwargs, output component, expected result
        # when "c" is selected and the function doubles its input).
        cases = [
            ({}, "textbox", ["cc"]),
            ({"type": "index"}, "number", [4]),
        ]
        for radio_kwargs, output_kind, expected in cases:
            radio = gr.inputs.Radio(["a", "b", "c"], **radio_kwargs)
            iface = gr.Interface(lambda choice: 2 * choice, radio, output_kind)
            self.assertEqual(iface.process(["c"])[0], expected)
2020-09-15 20:16:14 +08:00
class TestDropdown(unittest.TestCase):
    """Dropdowns should pass either the chosen label or its index."""

    def test_in_interface(self):
        def double(choice):
            return 2 * choice

        # Each case: (extra Dropdown kwargs, output component, expected result
        # when "c" is selected).
        cases = (
            ({}, "textbox", ["cc"]),
            ({"type": "index"}, "number", [4]),
        )
        for dropdown_kwargs, output_kind, expected in cases:
            dropdown = gr.inputs.Dropdown(["a", "b", "c"], **dropdown_kwargs)
            iface = gr.Interface(double, dropdown, output_kind)
            result = iface.process(["c"])[0]
            self.assertEqual(result, expected)
2020-08-06 01:42:52 +08:00
2019-02-28 08:43:20 +08:00
2020-08-11 04:52:43 +08:00
class TestImage(unittest.TestCase):
    """Checks for the image input component, standalone and in an Interface."""

    def test_as_component(self):
        x_img = gr.test_data.BASE64_IMAGE

        # Default settings: expected 3-channel shape for the bundled sample.
        image_input = gr.inputs.Image()
        self.assertEqual(image_input.preprocess(x_img).shape, (68, 61, 3))

        # Grayscale ("L") mode yields a 2-D result; shape sets the output size.
        image_input = gr.inputs.Image(image_mode="L", shape=(25, 25))
        self.assertEqual(image_input.preprocess(x_img).shape, (25, 25))

        # type="pil" returns a PIL image; PIL reports .size as (width, height).
        image_input = gr.inputs.Image(shape=(30, 10), type="pil")
        self.assertEqual(image_input.preprocess(x_img).size, (30, 10))

    def test_in_interface(self):
        x_img = gr.test_data.BASE64_IMAGE

        def open_and_rotate(img_file):
            # `import PIL` alone does not load the PIL.Image submodule, so
            # import it explicitly rather than relying on another module
            # having imported it first.
            from PIL import Image
            # rotate() returns a new image, so the source image can be closed
            # as soon as the rotated copy exists.
            with Image.open(img_file) as img:
                return img.rotate(90, expand=True)

        iface = gr.Interface(
            open_and_rotate,
            gr.inputs.Image(shape=(30, 10), type="file"),
            "image")
        output = iface.process([x_img])[0][0]
        # A 90-degree rotation with expand=True swaps width and height.
        self.assertEqual(
            gr.processing_utils.decode_base64_to_image(output).size, (10, 30))
2020-09-18 05:38:22 +08:00
2019-02-28 08:43:20 +08:00
2020-09-15 20:16:14 +08:00
class TestAudio(unittest.TestCase):
    """Checks for the audio input component, standalone and in an Interface."""

    def test_as_component(self):
        x_wav = gr.test_data.BASE64_AUDIO
        audio_input = gr.inputs.Audio()
        output = audio_input.preprocess(x_wav)
        # NOTE(review): output[0] looks like the sample rate and output[1] the
        # 1-D sample array for the bundled clip — confirm against Audio.
        self.assertEqual(output[0], 8000)
        self.assertEqual(output[1].shape, (8046,))

    def test_in_interface(self):
        x_wav = gr.test_data.BASE64_AUDIO

        def max_amplitude_from_wav_file(wav_file):
            # `import scipy` does not reliably load the scipy.io subpackage,
            # so import the wavfile module explicitly before using it.
            import scipy.io.wavfile
            _, data = scipy.io.wavfile.read(wav_file.name)
            return np.max(data)

        iface = gr.Interface(
            max_amplitude_from_wav_file,
            gr.inputs.Audio(type="file"),
            "number")
        self.assertEqual(iface.process([x_wav])[0], [5239])
2020-09-15 20:16:14 +08:00
class TestFile(unittest.TestCase):
    """The generic file input should hand the function an on-disk file."""

    def test_in_interface(self):
        # The bundled base64 WAV sample doubles as a generic file payload.
        x_file = gr.test_data.BASE64_AUDIO
        # The function receives an object whose .name is a path on disk.
        iface = gr.Interface(
            lambda file_obj: os.path.getsize(file_obj.name), "file", "number")
        expected_size = 16362
        self.assertEqual(iface.process([x_file])[0], [expected_size])
2020-09-15 20:16:14 +08:00
2020-09-22 02:51:39 +08:00
class TestDataframe(unittest.TestCase):
    """Checks for the dataframe input and its "numpy"/"list" shortcuts."""

    def test_as_component(self):
        rows = [["Tim", 12, False], ["Jan", 24, True]]
        dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"])
        frame = dataframe_input.preprocess(rows)
        # Columns are looked up by header name, then rows by position.
        self.assertEqual(frame["Age"][1], 24)
        self.assertEqual(frame["Member"][0], False)

    def test_in_interface(self):
        # "numpy" input: np.max reduces the whole table to its largest value.
        numeric_rows = [[1, 2, 3], [4, 5, 6]]
        numpy_iface = gr.Interface(np.max, "numpy", "number")
        self.assertEqual(numpy_iface.process([numeric_rows])[0], [6])

        # "list" input: the function returns the final element it receives.
        def get_last(rows):
            return rows[-1]

        text_rows = [["Tim"], ["Jon"], ["Sal"]]
        list_iface = gr.Interface(get_last, "list", "text")
        self.assertEqual(list_iface.process([text_rows])[0], ["Sal"])
2019-02-28 08:43:20 +08:00
if __name__ == "__main__":
    # Run the suite when this module is executed directly.
    unittest.main()