expanded testing for inputs

This commit is contained in:
aliabd 2021-11-03 14:21:51 -07:00
parent 376bf3f138
commit 18d0bd2d13
3 changed files with 214 additions and 6 deletions

View File

@ -115,7 +115,7 @@ class Textbox(InputComponent):
self.test_input = {
"str": "the quick brown fox jumped over the lazy dog",
"number": 786.92,
}[type]
}.get(type)
else:
self.test_input = default
self.interpret_by_tokens = True

File diff suppressed because one or more lines are too long

View File

@ -11,6 +11,18 @@ import json
import shutil
class InputComponent(unittest.TestCase):
def test_as_component(self):
input = gr.inputs.InputComponent(label="Test Input")
self.assertEqual(input.preprocess("Hello World!"), "Hello World!")
self.assertEqual(input.preprocess_example(["1", "2", "3"]), ["1", "2", "3"])
self.assertEqual(input.serialize(1, True), 1)
self.assertEqual(input.set_interpret_parameters(), input)
self.assertIsNone(input.get_interpretation_neighbors("Hi!"))
self.assertIsNone(input.get_interpretation_scores("Hi!", [], []))
self.assertIsNone(input.generate_sample())
class TestTextbox(unittest.TestCase):
def test_as_component(self):
text_input = gr.inputs.Textbox()
@ -21,6 +33,24 @@ class TestTextbox(unittest.TestCase):
self.assertEqual(to_save, "Hello World!")
restored = text_input.restore_flagged(to_save)
self.assertEqual(restored, "Hello World!")
with self.assertWarns(DeprecationWarning):
numeric_text_input = gr.inputs.Textbox(type="number")
self.assertEqual(numeric_text_input.preprocess("2"), 2.0)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Textbox(type="unknown")
wrong_type.preprocess(0)
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
['Hello', 'World!', 'Gradio', 'speaking.'],
['World! Gradio speaking.', 'Hello Gradio speaking.', 'Hello World! speaking.', 'Hello World! Gradio'],
None))
text_input.interpretation_replacement = "unknown"
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
['Hello', 'World!', 'Gradio', 'speaking.'],
['unknown World! Gradio speaking.', 'Hello unknown Gradio speaking.', 'Hello World! unknown speaking.',
'Hello World! Gradio unknown'], None))
self.assertIsInstance(text_input.generate_sample(), str)
def test_in_interface(self):
@ -47,6 +77,10 @@ class TestNumber(unittest.TestCase):
restored = numeric_input.restore_flagged(to_save)
self.assertEqual(restored, 3)
self.assertIsInstance(numeric_input.generate_sample(), float)
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}))
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}))
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "number", "textbox")
@ -71,6 +105,15 @@ class TestSlider(unittest.TestCase):
restored = slider_input.restore_flagged(to_save)
self.assertEqual(restored, 3)
self.assertIsInstance(slider_input.generate_sample(), int)
slider_input = gr.inputs.Slider(minimum=10, maximum=20, step=1, default=15, label="Slide Your Input")
self.assertEqual(slider_input.get_template_context(), {
'minimum': 10,
'maximum': 20,
'step': 1,
'default': 15,
'name': 'slider',
'label': 'Slide Your Input'
})
def test_in_interface(self):
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
@ -95,6 +138,12 @@ class TestCheckbox(unittest.TestCase):
restored = bool_input.restore_flagged(to_save)
self.assertEqual(restored, True)
self.assertIsInstance(bool_input.generate_sample(), bool)
bool_input = gr.inputs.Checkbox(default=True, label="Check Your Input")
self.assertEqual(bool_input.get_template_context(), {
'default': True,
'name': 'checkbox',
'label': 'Check Your Input'
})
def test_in_interface(self):
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "textbox")
@ -103,6 +152,9 @@ class TestCheckbox(unittest.TestCase):
scores, alternative_outputs = iface.interpret([False])
self.assertEqual(scores, [(None, 1.0)])
self.assertEqual(alternative_outputs, [[['1']]])
scores, alternative_outputs = iface.interpret([True])
self.assertEqual(scores, [(-1.0, None)])
self.assertEqual(alternative_outputs, [[['0']]])
class TestCheckboxGroup(unittest.TestCase):
@ -116,6 +168,17 @@ class TestCheckboxGroup(unittest.TestCase):
restored = checkboxes_input.restore_flagged(to_save)
self.assertEqual(restored, ["a", "c"])
self.assertIsInstance(checkboxes_input.generate_sample(), list)
checkboxes_input = gr.inputs.CheckboxGroup(choices=["a", "b", "c"], default=["a", "c"],
label="Check Your Inputs")
self.assertEqual(checkboxes_input.get_template_context(), {
'choices': ['a', 'b', 'c'],
'default': ['a', 'c'],
'name': 'checkboxgroup',
'label': 'Check Your Inputs'
})
with self.assertRaises(ValueError):
wrong_type = gr.inputs.CheckboxGroup(["a"], type="unknown")
wrong_type.preprocess(0)
def test_in_interface(self):
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"])
@ -141,6 +204,17 @@ class TestRadio(unittest.TestCase):
restored = radio_input.restore_flagged(to_save)
self.assertEqual(restored, "a")
self.assertIsInstance(radio_input.generate_sample(), str)
radio_input = gr.inputs.Radio(choices=["a", "b", "c"], default="a",
label="Pick Your One Input")
self.assertEqual(radio_input.get_template_context(), {
'choices': ['a', 'b', 'c'],
'default': 'a',
'name': 'radio',
'label': 'Pick Your One Input'
})
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Radio(["a","b"], type="unknown")
wrong_type.preprocess(0)
def test_in_interface(self):
radio_input = gr.inputs.Radio(["a", "b", "c"])
@ -165,6 +239,17 @@ class TestDropdown(unittest.TestCase):
restored = dropdown_input.restore_flagged(to_save)
self.assertEqual(restored, "a")
self.assertIsInstance(dropdown_input.generate_sample(), str)
dropdown_input = gr.inputs.Dropdown(choices=["a", "b", "c"], default="a",
label="Drop Your Input")
self.assertEqual(dropdown_input.get_template_context(), {
'choices': ['a', 'b', 'c'],
'default': 'a',
'name': 'dropdown',
'label': 'Drop Your Input'
})
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Dropdown(["a"], type="unknown")
wrong_type.preprocess(0)
def test_in_interface(self):
dropdown_input = gr.inputs.Dropdown(["a", "b", "c"])
@ -197,6 +282,39 @@ class TestImage(unittest.TestCase):
self.assertEqual(restored, "image_input/1.png")
shutil.rmtree('flagged')
self.assertIsInstance(image_input.generate_sample(), str)
image_input = gr.inputs.Image(source="upload", tool="editor", type="pil", label="Upload Your Image")
self.assertEqual(image_input.get_template_context(), {
'image_mode': 'RGB',
'shape': None,
'source': 'upload',
'tool': 'editor',
'optional': False,
'name': 'image',
'label': 'Upload Your Image'
})
self.assertIsNone(image_input.preprocess(None))
image_input = gr.inputs.Image(invert_colors=True)
self.assertIsNotNone(image_input.preprocess(img))
image_input.preprocess(img)
with self.assertWarns(DeprecationWarning):
file_image = gr.inputs.Image(type="file")
file_image.preprocess(gr.test_data.BASE64_IMAGE)
file_image = gr.inputs.Image(type="filepath")
self.assertIsInstance(file_image.preprocess(img), str)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Image(type="unknown")
wrong_type.preprocess(img)
wrong_type.serialize("test/test_files/bus.png", False)
img_pil = PIL.Image.open('test/test_files/bus.png')
image_input = gr.inputs.Image(type="numpy")
self.assertIsInstance(image_input.serialize(img_pil, False), str)
image_input = gr.inputs.Image(type="pil")
self.assertIsInstance(image_input.serialize(img_pil, False), str)
image_input = gr.inputs.Image(type="file")
with open("test/test_files/bus.png") as f:
self.assertEqual(image_input.serialize(f, False), img)
image_input.shape = (30, 10)
self.assertIsNotNone(image_input._segment_by_slic(img))
def test_in_interface(self):
img = gr.test_data.BASE64_IMAGE
@ -209,6 +327,14 @@ class TestImage(unittest.TestCase):
scores, alternative_outputs = iface.interpret([img])
self.assertEqual(scores, gr.test_data.SUM_PIXELS_INTERPRETATION["scores"])
self.assertEqual(alternative_outputs, gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"])
iface = gr.Interface(lambda x: np.sum(x), image_input, "label", interpretation="shap")
scores, alternative_outputs = iface.interpret([img])
self.assertEqual(len(scores[0]), len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]))
self.assertEqual(len(alternative_outputs[0]),
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]))
image_input = gr.inputs.Image(shape=(30, 10))
iface = gr.Interface(lambda x: np.sum(x), image_input, "textbox", interpretation="default")
self.assertIsNotNone(iface.interpret([img]))
class TestAudio(unittest.TestCase):
@ -228,6 +354,32 @@ class TestAudio(unittest.TestCase):
self.assertEqual(restored, "audio_input/1.wav")
shutil.rmtree('flagged')
self.assertIsInstance(audio_input.generate_sample(), dict)
audio_input = gr.inputs.Audio(label="Upload Your Audio")
self.assertEqual(audio_input.get_template_context(), {
'source': 'upload',
'optional': False,
'name': 'audio',
'label': 'Upload Your Audio'
})
self.assertIsNone(audio_input.preprocess(None))
x_wav["is_example"] = True
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
self.assertIsNotNone(audio_input.preprocess(x_wav))
with self.assertWarns(DeprecationWarning):
audio_input = gr.inputs.Audio(type="file")
audio_input.preprocess(x_wav)
with open("test/test_files/audio_sample.wav") as f:
audio_input.serialize(f, False)
audio_input = gr.inputs.Audio(type="filepath")
self.assertIsInstance(audio_input.preprocess(x_wav), str)
with self.assertRaises(ValueError):
audio_input = gr.inputs.Audio(type="unknown")
audio_input.preprocess(x_wav)
audio_input.serialize(x_wav, False)
audio_input = gr.inputs.Audio(type="numpy")
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
def test_in_interface(self):
x_wav = gr.test_data.BASE64_AUDIO
@ -240,7 +392,7 @@ class TestAudio(unittest.TestCase):
max_amplitude_from_wav_file,
gr.inputs.Audio(type="file"),
"number", interpretation="default")
self.assertEqual(iface.process([x_wav])[0], [5239])
self.assertEqual(iface.process([x_wav])[0], [576])
# scores, alternative_outputs = iface.interpret([x_wav])
# self.assertEqual(scores, ... )
# self.assertEqual(alternative_outputs, ...)
@ -262,6 +414,16 @@ class TestFile(unittest.TestCase):
self.assertEqual(restored, "file_input/1.pdf")
shutil.rmtree('flagged')
self.assertIsInstance(file_input.generate_sample(), dict)
file_input = gr.inputs.File(label="Upload Your File")
self.assertEqual(file_input.get_template_context(), {
'file_count': 'single',
'optional': False,
'name': 'file',
'label': 'Upload Your File'
})
self.assertIsNone(file_input.preprocess(None))
x_file["is_example"] = True
self.assertIsNotNone(file_input.preprocess(x_file))
def test_in_interface(self):
x_file = gr.test_data.BASE64_FILE
@ -287,6 +449,23 @@ class TestDataframe(unittest.TestCase):
restored = dataframe_input.restore_flagged(to_save)
self.assertEqual(x_data, restored)
self.assertIsInstance(dataframe_input.generate_sample(), list)
dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"], label="Dataframe Input")
self.assertEqual(dataframe_input.get_template_context(), {
'headers': ['Name', 'Age', 'Member'],
'datatype': 'str',
'row_count': 3,
'col_count': 3,
'col_width': None,
'default': [[None, None, None], [None, None, None], [None, None, None]],
'name': 'dataframe',
'label': 'Dataframe Input'
})
dataframe_input = gr.inputs.Dataframe()
output = dataframe_input.preprocess(x_data)
self.assertEqual(output[1][1], 24)
with self.assertRaises(ValueError):
wrong_type = gr.inputs.Dataframe(type="unknown")
wrong_type.preprocess(x_data)
def test_in_interface(self):
x_data = [[1, 2, 3], [4, 5, 6]]
@ -315,6 +494,20 @@ class TestVideo(unittest.TestCase):
self.assertEqual(restored, "video_input/1.mp4")
shutil.rmtree('flagged')
self.assertIsInstance(video_input.generate_sample(), dict)
video_input = gr.inputs.Video(label="Upload Your Video")
self.assertEqual(video_input.get_template_context(), {
'optional': False,
'name': 'video',
'label': 'Upload Your Video'
})
self.assertIsNone(video_input.preprocess(None))
x_video["is_example"] = True
self.assertIsNotNone(video_input.preprocess(x_video))
video_input = gr.inputs.Video(type="avi")
self.assertEqual(video_input.preprocess(x_video)[-3:], "avi")
with self.assertRaises(NotImplementedError):
video_input.serialize(x_video, True)
def test_in_interface(self):
x_video = gr.test_data.BASE64_VIDEO
@ -341,6 +534,20 @@ class TestTimeseries(unittest.TestCase):
restored = timeseries_input.restore_flagged(to_save)
self.assertEqual(x_timeseries, restored)
self.assertIsInstance(timeseries_input.generate_sample(), dict)
timeseries_input = gr.inputs.Timeseries(
x="time",
y="retail", label="Upload Your Timeseries"
)
self.assertEqual(timeseries_input.get_template_context(), {
'x': 'time',
'y': ['retail'],
'optional': False,
'name': 'timeseries',
'label': 'Upload Your Timeseries'
})
self.assertIsNone(timeseries_input.preprocess(None))
x_timeseries["range"] = (0, 1)
self.assertIsNotNone(timeseries_input.preprocess(x_timeseries))
def test_in_interface(self):
timeseries_input = gr.inputs.Timeseries(