mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
expanded testing for inputs
This commit is contained in:
parent
376bf3f138
commit
18d0bd2d13
@ -115,7 +115,7 @@ class Textbox(InputComponent):
|
||||
self.test_input = {
|
||||
"str": "the quick brown fox jumped over the lazy dog",
|
||||
"number": 786.92,
|
||||
}[type]
|
||||
}.get(type)
|
||||
else:
|
||||
self.test_input = default
|
||||
self.interpret_by_tokens = True
|
||||
|
File diff suppressed because one or more lines are too long
@ -11,6 +11,18 @@ import json
|
||||
import shutil
|
||||
|
||||
|
||||
class InputComponent(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
input = gr.inputs.InputComponent(label="Test Input")
|
||||
self.assertEqual(input.preprocess("Hello World!"), "Hello World!")
|
||||
self.assertEqual(input.preprocess_example(["1", "2", "3"]), ["1", "2", "3"])
|
||||
self.assertEqual(input.serialize(1, True), 1)
|
||||
self.assertEqual(input.set_interpret_parameters(), input)
|
||||
self.assertIsNone(input.get_interpretation_neighbors("Hi!"))
|
||||
self.assertIsNone(input.get_interpretation_scores("Hi!", [], []))
|
||||
self.assertIsNone(input.generate_sample())
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
text_input = gr.inputs.Textbox()
|
||||
@ -21,6 +33,24 @@ class TestTextbox(unittest.TestCase):
|
||||
self.assertEqual(to_save, "Hello World!")
|
||||
restored = text_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, "Hello World!")
|
||||
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
numeric_text_input = gr.inputs.Textbox(type="number")
|
||||
self.assertEqual(numeric_text_input.preprocess("2"), 2.0)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Textbox(type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
|
||||
['Hello', 'World!', 'Gradio', 'speaking.'],
|
||||
['World! Gradio speaking.', 'Hello Gradio speaking.', 'Hello World! speaking.', 'Hello World! Gradio'],
|
||||
None))
|
||||
text_input.interpretation_replacement = "unknown"
|
||||
self.assertEqual(text_input.tokenize("Hello World! Gradio speaking."), (
|
||||
['Hello', 'World!', 'Gradio', 'speaking.'],
|
||||
['unknown World! Gradio speaking.', 'Hello unknown Gradio speaking.', 'Hello World! unknown speaking.',
|
||||
'Hello World! Gradio unknown'], None))
|
||||
|
||||
self.assertIsInstance(text_input.generate_sample(), str)
|
||||
|
||||
def test_in_interface(self):
|
||||
@ -47,6 +77,10 @@ class TestNumber(unittest.TestCase):
|
||||
restored = numeric_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, 3)
|
||||
self.assertIsInstance(numeric_input.generate_sample(), float)
|
||||
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="absolute")
|
||||
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([-2.0, -1.0, 0.0, 2.0, 3.0, 4.0], {}))
|
||||
numeric_input.set_interpret_parameters(steps=3, delta=1, delta_type="percent")
|
||||
self.assertEqual(numeric_input.get_interpretation_neighbors(1), ([0.97, 0.98, 0.99, 1.01, 1.02, 1.03], {}))
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
@ -71,6 +105,15 @@ class TestSlider(unittest.TestCase):
|
||||
restored = slider_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, 3)
|
||||
self.assertIsInstance(slider_input.generate_sample(), int)
|
||||
slider_input = gr.inputs.Slider(minimum=10, maximum=20, step=1, default=15, label="Slide Your Input")
|
||||
self.assertEqual(slider_input.get_template_context(), {
|
||||
'minimum': 10,
|
||||
'maximum': 20,
|
||||
'step': 1,
|
||||
'default': 15,
|
||||
'name': 'slider',
|
||||
'label': 'Slide Your Input'
|
||||
})
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
@ -95,6 +138,12 @@ class TestCheckbox(unittest.TestCase):
|
||||
restored = bool_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, True)
|
||||
self.assertIsInstance(bool_input.generate_sample(), bool)
|
||||
bool_input = gr.inputs.Checkbox(default=True, label="Check Your Input")
|
||||
self.assertEqual(bool_input.get_template_context(), {
|
||||
'default': True,
|
||||
'name': 'checkbox',
|
||||
'label': 'Check Your Input'
|
||||
})
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "textbox")
|
||||
@ -103,6 +152,9 @@ class TestCheckbox(unittest.TestCase):
|
||||
scores, alternative_outputs = iface.interpret([False])
|
||||
self.assertEqual(scores, [(None, 1.0)])
|
||||
self.assertEqual(alternative_outputs, [[['1']]])
|
||||
scores, alternative_outputs = iface.interpret([True])
|
||||
self.assertEqual(scores, [(-1.0, None)])
|
||||
self.assertEqual(alternative_outputs, [[['0']]])
|
||||
|
||||
|
||||
class TestCheckboxGroup(unittest.TestCase):
|
||||
@ -116,6 +168,17 @@ class TestCheckboxGroup(unittest.TestCase):
|
||||
restored = checkboxes_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, ["a", "c"])
|
||||
self.assertIsInstance(checkboxes_input.generate_sample(), list)
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(choices=["a", "b", "c"], default=["a", "c"],
|
||||
label="Check Your Inputs")
|
||||
self.assertEqual(checkboxes_input.get_template_context(), {
|
||||
'choices': ['a', 'b', 'c'],
|
||||
'default': ['a', 'c'],
|
||||
'name': 'checkboxgroup',
|
||||
'label': 'Check Your Inputs'
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.CheckboxGroup(["a"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
checkboxes_input = gr.inputs.CheckboxGroup(["a", "b", "c"])
|
||||
@ -141,6 +204,17 @@ class TestRadio(unittest.TestCase):
|
||||
restored = radio_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, "a")
|
||||
self.assertIsInstance(radio_input.generate_sample(), str)
|
||||
radio_input = gr.inputs.Radio(choices=["a", "b", "c"], default="a",
|
||||
label="Pick Your One Input")
|
||||
self.assertEqual(radio_input.get_template_context(), {
|
||||
'choices': ['a', 'b', 'c'],
|
||||
'default': 'a',
|
||||
'name': 'radio',
|
||||
'label': 'Pick Your One Input'
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Radio(["a","b"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
radio_input = gr.inputs.Radio(["a", "b", "c"])
|
||||
@ -165,6 +239,17 @@ class TestDropdown(unittest.TestCase):
|
||||
restored = dropdown_input.restore_flagged(to_save)
|
||||
self.assertEqual(restored, "a")
|
||||
self.assertIsInstance(dropdown_input.generate_sample(), str)
|
||||
dropdown_input = gr.inputs.Dropdown(choices=["a", "b", "c"], default="a",
|
||||
label="Drop Your Input")
|
||||
self.assertEqual(dropdown_input.get_template_context(), {
|
||||
'choices': ['a', 'b', 'c'],
|
||||
'default': 'a',
|
||||
'name': 'dropdown',
|
||||
'label': 'Drop Your Input'
|
||||
})
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Dropdown(["a"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
dropdown_input = gr.inputs.Dropdown(["a", "b", "c"])
|
||||
@ -197,6 +282,39 @@ class TestImage(unittest.TestCase):
|
||||
self.assertEqual(restored, "image_input/1.png")
|
||||
shutil.rmtree('flagged')
|
||||
self.assertIsInstance(image_input.generate_sample(), str)
|
||||
image_input = gr.inputs.Image(source="upload", tool="editor", type="pil", label="Upload Your Image")
|
||||
self.assertEqual(image_input.get_template_context(), {
|
||||
'image_mode': 'RGB',
|
||||
'shape': None,
|
||||
'source': 'upload',
|
||||
'tool': 'editor',
|
||||
'optional': False,
|
||||
'name': 'image',
|
||||
'label': 'Upload Your Image'
|
||||
})
|
||||
self.assertIsNone(image_input.preprocess(None))
|
||||
image_input = gr.inputs.Image(invert_colors=True)
|
||||
self.assertIsNotNone(image_input.preprocess(img))
|
||||
image_input.preprocess(img)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
file_image = gr.inputs.Image(type="file")
|
||||
file_image.preprocess(gr.test_data.BASE64_IMAGE)
|
||||
file_image = gr.inputs.Image(type="filepath")
|
||||
self.assertIsInstance(file_image.preprocess(img), str)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Image(type="unknown")
|
||||
wrong_type.preprocess(img)
|
||||
wrong_type.serialize("test/test_files/bus.png", False)
|
||||
img_pil = PIL.Image.open('test/test_files/bus.png')
|
||||
image_input = gr.inputs.Image(type="numpy")
|
||||
self.assertIsInstance(image_input.serialize(img_pil, False), str)
|
||||
image_input = gr.inputs.Image(type="pil")
|
||||
self.assertIsInstance(image_input.serialize(img_pil, False), str)
|
||||
image_input = gr.inputs.Image(type="file")
|
||||
with open("test/test_files/bus.png") as f:
|
||||
self.assertEqual(image_input.serialize(f, False), img)
|
||||
image_input.shape = (30, 10)
|
||||
self.assertIsNotNone(image_input._segment_by_slic(img))
|
||||
|
||||
def test_in_interface(self):
|
||||
img = gr.test_data.BASE64_IMAGE
|
||||
@ -209,6 +327,14 @@ class TestImage(unittest.TestCase):
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(scores, gr.test_data.SUM_PIXELS_INTERPRETATION["scores"])
|
||||
self.assertEqual(alternative_outputs, gr.test_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"])
|
||||
iface = gr.Interface(lambda x: np.sum(x), image_input, "label", interpretation="shap")
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(len(scores[0]), len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]))
|
||||
self.assertEqual(len(alternative_outputs[0]),
|
||||
len(gr.test_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]))
|
||||
image_input = gr.inputs.Image(shape=(30, 10))
|
||||
iface = gr.Interface(lambda x: np.sum(x), image_input, "textbox", interpretation="default")
|
||||
self.assertIsNotNone(iface.interpret([img]))
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
@ -228,6 +354,32 @@ class TestAudio(unittest.TestCase):
|
||||
self.assertEqual(restored, "audio_input/1.wav")
|
||||
shutil.rmtree('flagged')
|
||||
self.assertIsInstance(audio_input.generate_sample(), dict)
|
||||
audio_input = gr.inputs.Audio(label="Upload Your Audio")
|
||||
self.assertEqual(audio_input.get_template_context(), {
|
||||
'source': 'upload',
|
||||
'optional': False,
|
||||
'name': 'audio',
|
||||
'label': 'Upload Your Audio'
|
||||
})
|
||||
self.assertIsNone(audio_input.preprocess(None))
|
||||
x_wav["is_example"] = True
|
||||
x_wav["crop_min"], x_wav["crop_max"] = 1, 4
|
||||
self.assertIsNotNone(audio_input.preprocess(x_wav))
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
audio_input = gr.inputs.Audio(type="file")
|
||||
audio_input.preprocess(x_wav)
|
||||
with open("test/test_files/audio_sample.wav") as f:
|
||||
audio_input.serialize(f, False)
|
||||
audio_input = gr.inputs.Audio(type="filepath")
|
||||
self.assertIsInstance(audio_input.preprocess(x_wav), str)
|
||||
with self.assertRaises(ValueError):
|
||||
audio_input = gr.inputs.Audio(type="unknown")
|
||||
audio_input.preprocess(x_wav)
|
||||
audio_input.serialize(x_wav, False)
|
||||
audio_input = gr.inputs.Audio(type="numpy")
|
||||
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
|
||||
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
|
||||
|
||||
|
||||
def test_in_interface(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
@ -240,7 +392,7 @@ class TestAudio(unittest.TestCase):
|
||||
max_amplitude_from_wav_file,
|
||||
gr.inputs.Audio(type="file"),
|
||||
"number", interpretation="default")
|
||||
self.assertEqual(iface.process([x_wav])[0], [5239])
|
||||
self.assertEqual(iface.process([x_wav])[0], [576])
|
||||
# scores, alternative_outputs = iface.interpret([x_wav])
|
||||
# self.assertEqual(scores, ... )
|
||||
# self.assertEqual(alternative_outputs, ...)
|
||||
@ -262,6 +414,16 @@ class TestFile(unittest.TestCase):
|
||||
self.assertEqual(restored, "file_input/1.pdf")
|
||||
shutil.rmtree('flagged')
|
||||
self.assertIsInstance(file_input.generate_sample(), dict)
|
||||
file_input = gr.inputs.File(label="Upload Your File")
|
||||
self.assertEqual(file_input.get_template_context(), {
|
||||
'file_count': 'single',
|
||||
'optional': False,
|
||||
'name': 'file',
|
||||
'label': 'Upload Your File'
|
||||
})
|
||||
self.assertIsNone(file_input.preprocess(None))
|
||||
x_file["is_example"] = True
|
||||
self.assertIsNotNone(file_input.preprocess(x_file))
|
||||
|
||||
def test_in_interface(self):
|
||||
x_file = gr.test_data.BASE64_FILE
|
||||
@ -287,6 +449,23 @@ class TestDataframe(unittest.TestCase):
|
||||
restored = dataframe_input.restore_flagged(to_save)
|
||||
self.assertEqual(x_data, restored)
|
||||
self.assertIsInstance(dataframe_input.generate_sample(), list)
|
||||
dataframe_input = gr.inputs.Dataframe(headers=["Name", "Age", "Member"], label="Dataframe Input")
|
||||
self.assertEqual(dataframe_input.get_template_context(), {
|
||||
'headers': ['Name', 'Age', 'Member'],
|
||||
'datatype': 'str',
|
||||
'row_count': 3,
|
||||
'col_count': 3,
|
||||
'col_width': None,
|
||||
'default': [[None, None, None], [None, None, None], [None, None, None]],
|
||||
'name': 'dataframe',
|
||||
'label': 'Dataframe Input'
|
||||
})
|
||||
dataframe_input = gr.inputs.Dataframe()
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
self.assertEqual(output[1][1], 24)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.inputs.Dataframe(type="unknown")
|
||||
wrong_type.preprocess(x_data)
|
||||
|
||||
def test_in_interface(self):
|
||||
x_data = [[1, 2, 3], [4, 5, 6]]
|
||||
@ -315,6 +494,20 @@ class TestVideo(unittest.TestCase):
|
||||
self.assertEqual(restored, "video_input/1.mp4")
|
||||
shutil.rmtree('flagged')
|
||||
self.assertIsInstance(video_input.generate_sample(), dict)
|
||||
video_input = gr.inputs.Video(label="Upload Your Video")
|
||||
self.assertEqual(video_input.get_template_context(), {
|
||||
'optional': False,
|
||||
'name': 'video',
|
||||
'label': 'Upload Your Video'
|
||||
})
|
||||
self.assertIsNone(video_input.preprocess(None))
|
||||
x_video["is_example"] = True
|
||||
self.assertIsNotNone(video_input.preprocess(x_video))
|
||||
video_input = gr.inputs.Video(type="avi")
|
||||
self.assertEqual(video_input.preprocess(x_video)[-3:], "avi")
|
||||
with self.assertRaises(NotImplementedError):
|
||||
video_input.serialize(x_video, True)
|
||||
|
||||
|
||||
def test_in_interface(self):
|
||||
x_video = gr.test_data.BASE64_VIDEO
|
||||
@ -341,6 +534,20 @@ class TestTimeseries(unittest.TestCase):
|
||||
restored = timeseries_input.restore_flagged(to_save)
|
||||
self.assertEqual(x_timeseries, restored)
|
||||
self.assertIsInstance(timeseries_input.generate_sample(), dict)
|
||||
timeseries_input = gr.inputs.Timeseries(
|
||||
x="time",
|
||||
y="retail", label="Upload Your Timeseries"
|
||||
)
|
||||
self.assertEqual(timeseries_input.get_template_context(), {
|
||||
'x': 'time',
|
||||
'y': ['retail'],
|
||||
'optional': False,
|
||||
'name': 'timeseries',
|
||||
'label': 'Upload Your Timeseries'
|
||||
})
|
||||
self.assertIsNone(timeseries_input.preprocess(None))
|
||||
x_timeseries["range"] = (0, 1)
|
||||
self.assertIsNotNone(timeseries_input.preprocess(x_timeseries))
|
||||
|
||||
def test_in_interface(self):
|
||||
timeseries_input = gr.inputs.Timeseries(
|
||||
|
Loading…
Reference in New Issue
Block a user