mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
blocks-components-test
- finalize and document components tests
This commit is contained in:
parent
3e72cfdea0
commit
9f7a48604b
@ -1,8 +1,8 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from difflib import SequenceMatcher
|
||||
from test.test_data import media_data
|
||||
|
||||
@ -15,12 +15,25 @@ import gradio as gr
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
"""
|
||||
Tests are divided into two
|
||||
1. test_component_functionalities are unit tests that check essential functions of a component, the functions that are checked are documented in the docstring.
|
||||
2. test_in_interface_... are functional tests that check a component's functionalities inside an Interface. Please do not use Interface.launch() in this file, as it slow downs the tests.
|
||||
"""
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_as_input_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, save_flagged, restore_flagged, tokenize, generate_sample, get_template_context
|
||||
"""
|
||||
text_input = gr.Textbox()
|
||||
self.assertEqual(text_input.preprocess("Hello World!"), "Hello World!")
|
||||
self.assertEqual(text_input.preprocess_example("Hello World!"), "Hello World!")
|
||||
self.assertEqual(text_input.postprocess(None), None)
|
||||
self.assertEqual(text_input.postprocess("Ali"), "Ali")
|
||||
self.assertEqual(text_input.postprocess(2), "2")
|
||||
self.assertEqual(text_input.postprocess(2.14), "2.14")
|
||||
self.assertEqual(text_input.serialize("Hello World!", True), "Hello World!")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = text_input.save_flagged(
|
||||
@ -60,10 +73,23 @@ class TestTextbox(unittest.TestCase):
|
||||
None,
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
text_input.get_template_context(),
|
||||
{
|
||||
"lines": 1,
|
||||
"placeholder": None,
|
||||
"default_value": "",
|
||||
"name": "textbox",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
)
|
||||
self.assertIsInstance(text_input.generate_sample(), str)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process, interpret,
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["olleH"])
|
||||
iface = gr.Interface(
|
||||
@ -108,6 +134,10 @@ class TestTextbox(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process
|
||||
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x[-1], "textbox", gr.Textbox())
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["o"])
|
||||
iface = gr.Interface(lambda x: x / 2, "number", gr.Textbox())
|
||||
@ -115,11 +145,18 @@ class TestTextbox(unittest.TestCase):
|
||||
|
||||
|
||||
class TestNumber(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, set_interpret_parameters, get_interpretation_neighbors, get_template_context
|
||||
|
||||
"""
|
||||
numeric_input = gr.Number()
|
||||
self.assertEqual(numeric_input.preprocess(3), 3.0)
|
||||
self.assertEqual(numeric_input.preprocess(None), None)
|
||||
self.assertEqual(numeric_input.preprocess_example(3), 3)
|
||||
self.assertEqual(numeric_input.postprocess(3), 3.0)
|
||||
self.assertEqual(numeric_input.postprocess(2.14), 2.14)
|
||||
self.assertEqual(numeric_input.postprocess(None), None)
|
||||
self.assertEqual(numeric_input.serialize(3, True), 3)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = numeric_input.save_flagged(tmpdirname, "numeric_input", 3, None)
|
||||
@ -142,11 +179,52 @@ class TestNumber(unittest.TestCase):
|
||||
{"default_value": None, "name": "number", "label": None, "css": {}},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "number", "textbox")
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x ** 2, "number", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4.0"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "number", "number", interpretation="default"
|
||||
lambda x: x ** 2, "number", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(
|
||||
scores,
|
||||
[
|
||||
[
|
||||
(1.94, -0.23640000000000017),
|
||||
(1.96, -0.15840000000000032),
|
||||
(1.98, -0.07960000000000012),
|
||||
[2, None],
|
||||
(2.02, 0.08040000000000003),
|
||||
(2.04, 0.16159999999999997),
|
||||
(2.06, 0.24359999999999982),
|
||||
]
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
[
|
||||
[
|
||||
[3.7636],
|
||||
[3.8415999999999997],
|
||||
[3.9204],
|
||||
[4.0804],
|
||||
[4.1616],
|
||||
[4.2436],
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: int(x) ** 2, "textbox", "number")
|
||||
self.assertEqual(iface.process([2])[0], [4.0])
|
||||
iface = gr.Interface(
|
||||
lambda x: x ** 2, "number", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(
|
||||
@ -179,10 +257,15 @@ class TestNumber(unittest.TestCase):
|
||||
|
||||
|
||||
class TestSlider(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context
|
||||
"""
|
||||
slider_input = gr.Slider()
|
||||
self.assertEqual(slider_input.preprocess(3.0), 3.0)
|
||||
self.assertEqual(slider_input.preprocess_example(3), 3)
|
||||
self.assertEqual(slider_input.postprocess(3), 3)
|
||||
self.assertEqual(slider_input.postprocess(None), None)
|
||||
self.assertEqual(slider_input.serialize(3, True), 3)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = slider_input.save_flagged(tmpdirname, "slider_input", 3, None)
|
||||
@ -208,10 +291,13 @@ class TestSlider(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x**2, "slider", "textbox")
|
||||
""" "
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: x ** 2, "slider", "textbox")
|
||||
self.assertEqual(iface.process([2])[0], ["4"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x**2, "slider", "number", interpretation="default"
|
||||
lambda x: x ** 2, "slider", "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([2])
|
||||
self.assertEqual(
|
||||
@ -247,10 +333,14 @@ class TestSlider(unittest.TestCase):
|
||||
|
||||
|
||||
class TestCheckbox(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, generate_sample, get_template_context
|
||||
"""
|
||||
bool_input = gr.Checkbox()
|
||||
self.assertEqual(bool_input.preprocess(True), True)
|
||||
self.assertEqual(bool_input.preprocess_example(True), True)
|
||||
self.assertEqual(bool_input.postprocess(True), True)
|
||||
self.assertEqual(bool_input.serialize(True, True), True)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = bool_input.save_flagged(tmpdirname, "bool_input", True, None)
|
||||
@ -270,6 +360,9 @@ class TestCheckbox(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
"""
|
||||
iface = gr.Interface(lambda x: 1 if x else 0, "checkbox", "number")
|
||||
self.assertEqual(iface.process([True])[0], [1])
|
||||
iface = gr.Interface(
|
||||
@ -284,7 +377,10 @@ class TestCheckbox(unittest.TestCase):
|
||||
|
||||
|
||||
class TestCheckboxGroup(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, preprocess_example, serialize, save_flagged, restore_flagged, generate_sample, get_template_context
|
||||
"""
|
||||
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
|
||||
self.assertEqual(checkboxes_input.preprocess(["a", "c"]), ["a", "c"])
|
||||
self.assertEqual(checkboxes_input.preprocess_example(["a", "c"]), ["a", "c"])
|
||||
@ -317,15 +413,22 @@ class TestCheckboxGroup(unittest.TestCase):
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: "|".join(x), checkboxes_input, "textbox")
|
||||
self.assertEqual(iface.process([["a", "c"]])[0], ["a|c"])
|
||||
self.assertEqual(iface.process([[]])[0], [""])
|
||||
checkboxes_input = gr.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
_ = gr.CheckboxGroup(["a", "b", "c"], type="index")
|
||||
|
||||
|
||||
class TestRadio(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, preprocess_example, serialize, save_flagged, generate_sample, get_template_context
|
||||
|
||||
"""
|
||||
radio_input = gr.Radio(["a", "b", "c"])
|
||||
self.assertEqual(radio_input.preprocess("c"), "c")
|
||||
self.assertEqual(radio_input.preprocess_example("a"), "a")
|
||||
@ -354,6 +457,9 @@ class TestRadio(unittest.TestCase):
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
"""
|
||||
radio_input = gr.Radio(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, radio_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
@ -367,54 +473,13 @@ class TestRadio(unittest.TestCase):
|
||||
self.assertEqual(alternative_outputs, [[[0], [4]]])
|
||||
|
||||
|
||||
class TestDropdown(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
dropdown_input = gr.Dropdown(["a", "b", "c"])
|
||||
self.assertEqual(dropdown_input.preprocess("c"), "c")
|
||||
self.assertEqual(dropdown_input.preprocess_example("a"), "a")
|
||||
self.assertEqual(dropdown_input.serialize("a", True), "a")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = dropdown_input.save_flagged(
|
||||
tmpdirname, "dropdown_input", "a", None
|
||||
)
|
||||
self.assertEqual(to_save, "a")
|
||||
restored = dropdown_input.restore_flagged(tmpdirname, to_save, None)
|
||||
self.assertEqual(restored, "a")
|
||||
self.assertIsInstance(dropdown_input.generate_sample(), str)
|
||||
dropdown_input = gr.Dropdown(
|
||||
choices=["a", "b", "c"], default="a", label="Drop Your Input"
|
||||
)
|
||||
self.assertEqual(
|
||||
dropdown_input.get_template_context(),
|
||||
{
|
||||
"choices": ["a", "b", "c"],
|
||||
"default_value": "a",
|
||||
"name": "dropdown",
|
||||
"label": "Drop Your Input",
|
||||
"css": {},
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.Dropdown(["a"], type="unknown")
|
||||
wrong_type.preprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
dropdown_input = gr.Dropdown(["a", "b", "c"])
|
||||
iface = gr.Interface(lambda x: 2 * x, dropdown_input, "textbox")
|
||||
self.assertEqual(iface.process(["c"])[0], ["cc"])
|
||||
dropdown = gr.Dropdown(["a", "b", "c"], type="index")
|
||||
iface = gr.Interface(
|
||||
lambda x: 2 * x, dropdown, "number", interpretation="default"
|
||||
)
|
||||
self.assertEqual(iface.process(["c"])[0], [4])
|
||||
scores, alternative_outputs = iface.interpret(["b"])
|
||||
self.assertEqual(scores, [[-2.0, None, 2.0]])
|
||||
self.assertEqual(alternative_outputs, [[[0], [4]]])
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_as_component_as_input(self):
|
||||
img = media_data.BASE64_IMAGE
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context, _segment_by_slic
|
||||
type: pil, file, filepath, numpy
|
||||
"""
|
||||
img = deepcopy(media_data.BASE64_IMAGE)
|
||||
image_input = gr.Image()
|
||||
self.assertEqual(image_input.preprocess(img).shape, (68, 61, 3))
|
||||
image_input = gr.Image(shape=(25, 25), image_mode="L")
|
||||
@ -454,7 +519,7 @@ class TestImage(unittest.TestCase):
|
||||
image_input.preprocess(img)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
file_image = gr.Image(type="file")
|
||||
file_image.preprocess(media_data.BASE64_IMAGE)
|
||||
file_image.preprocess(deepcopy(media_data.BASE64_IMAGE))
|
||||
file_image = gr.Image(type="filepath")
|
||||
self.assertIsInstance(file_image.preprocess(img), str)
|
||||
with self.assertRaises(ValueError):
|
||||
@ -474,47 +539,10 @@ class TestImage(unittest.TestCase):
|
||||
image_input.shape = (30, 10)
|
||||
self.assertIsNotNone(image_input._segment_by_slic(img))
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
img = media_data.BASE64_IMAGE
|
||||
image_input = gr.Image()
|
||||
iface = gr.Interface(
|
||||
lambda x: PIL.Image.open(x).rotate(90, expand=True),
|
||||
gr.Image(shape=(30, 10), type="file"),
|
||||
"image",
|
||||
# Output functionalities
|
||||
y_img = gr.processing_utils.decode_base64_to_image(
|
||||
deepcopy(media_data.BASE64_IMAGE)
|
||||
)
|
||||
output = iface.process([img])[0][0]
|
||||
self.assertEqual(
|
||||
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(scores, media_data.SUM_PIXELS_INTERPRETATION["scores"])
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
media_data.SUM_PIXELS_INTERPRETATION["alternative_outputs"],
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "label", interpretation="shap"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(
|
||||
len(scores[0]),
|
||||
len(media_data.SUM_PIXELS_SHAP_INTERPRETATION["scores"][0]),
|
||||
)
|
||||
self.assertEqual(
|
||||
len(alternative_outputs[0]),
|
||||
len(media_data.SUM_PIXELS_SHAP_INTERPRETATION["alternative_outputs"][0]),
|
||||
)
|
||||
image_input = gr.Image(shape=(30, 10))
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "number", interpretation="default"
|
||||
)
|
||||
self.assertIsNotNone(iface.interpret([img]))
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
y_img = gr.processing_utils.decode_base64_to_image(media_data.BASE64_IMAGE)
|
||||
image_output = gr.Image()
|
||||
self.assertTrue(
|
||||
image_output.postprocess(y_img).startswith(
|
||||
@ -544,15 +572,69 @@ class TestImage(unittest.TestCase):
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = image_output.save_flagged(
|
||||
tmpdirname, "image_output", media_data.BASE64_IMAGE, None
|
||||
tmpdirname, "image_output", deepcopy(media_data.BASE64_IMAGE), None
|
||||
)
|
||||
self.assertEqual("image_output/0.png", to_save)
|
||||
to_save = image_output.save_flagged(
|
||||
tmpdirname, "image_output", media_data.BASE64_IMAGE, None
|
||||
tmpdirname, "image_output", deepcopy(media_data.BASE64_IMAGE), None
|
||||
)
|
||||
self.assertEqual("image_output/1.png", to_save)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process, interpret
|
||||
type: file
|
||||
interpretation: default, shap,
|
||||
"""
|
||||
img = deepcopy(media_data.BASE64_IMAGE)
|
||||
image_input = gr.Image()
|
||||
iface = gr.Interface(
|
||||
lambda x: PIL.Image.open(x).rotate(90, expand=True),
|
||||
gr.Image(shape=(30, 10), type="file"),
|
||||
"image",
|
||||
)
|
||||
output = iface.process([img])[0][0]
|
||||
self.assertEqual(
|
||||
gr.processing_utils.decode_base64_to_image(output).size, (10, 30)
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "number", interpretation="default"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(
|
||||
scores, deepcopy(media_data.SUM_PIXELS_INTERPRETATION)["scores"]
|
||||
)
|
||||
self.assertEqual(
|
||||
alternative_outputs,
|
||||
deepcopy(media_data.SUM_PIXELS_INTERPRETATION)["alternative_outputs"],
|
||||
)
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "label", interpretation="shap"
|
||||
)
|
||||
scores, alternative_outputs = iface.interpret([img])
|
||||
self.assertEqual(
|
||||
len(scores[0]),
|
||||
len(deepcopy(media_data.SUM_PIXELS_SHAP_INTERPRETATION)["scores"][0]),
|
||||
)
|
||||
self.assertEqual(
|
||||
len(alternative_outputs[0]),
|
||||
len(
|
||||
deepcopy(media_data.SUM_PIXELS_SHAP_INTERPRETATION)[
|
||||
"alternative_outputs"
|
||||
][0]
|
||||
),
|
||||
)
|
||||
image_input = gr.Image(shape=(30, 10))
|
||||
iface = gr.Interface(
|
||||
lambda x: np.sum(x), image_input, "number", interpretation="default"
|
||||
)
|
||||
self.assertIsNotNone(iface.interpret([img]))
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
|
||||
def generate_noise(width, height):
|
||||
return np.random.randint(0, 256, (width, height, 3))
|
||||
|
||||
@ -563,8 +645,12 @@ class TestImage(unittest.TestCase):
|
||||
|
||||
|
||||
class TestAudio(unittest.TestCase):
|
||||
def test_as_component_as_input(self):
|
||||
x_wav = copy.deepcopy(media_data.BASE64_AUDIO)
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess serialize, save_flagged, restore_flagged, generate_sample, get_template_context, deserialize
|
||||
type: filepath, numpy, file
|
||||
"""
|
||||
x_wav = deepcopy(media_data.BASE64_AUDIO)
|
||||
audio_input = gr.Audio()
|
||||
output = audio_input.preprocess(x_wav)
|
||||
self.assertEqual(output[0], 8000)
|
||||
@ -613,18 +699,9 @@ class TestAudio(unittest.TestCase):
|
||||
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
|
||||
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
|
||||
|
||||
def test_tokenize(self):
|
||||
x_wav = media_data.BASE64_AUDIO
|
||||
audio_input = gr.Audio()
|
||||
tokens, _, _ = audio_input.tokenize(x_wav)
|
||||
self.assertEquals(len(tokens), audio_input.interpretation_segments)
|
||||
x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0]
|
||||
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
|
||||
self.assertGreater(similarity, 0.9)
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
# Output functionalities
|
||||
y_audio = gr.processing_utils.decode_base64_to_file(
|
||||
media_data.BASE64_AUDIO["data"]
|
||||
deepcopy(media_data.BASE64_AUDIO)["data"]
|
||||
)
|
||||
audio_output = gr.Audio(type="file")
|
||||
self.assertTrue(
|
||||
@ -643,19 +720,38 @@ class TestAudio(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
self.assertTrue(
|
||||
audio_output.deserialize(media_data.BASE64_AUDIO["data"]).endswith(".wav")
|
||||
audio_output.deserialize(
|
||||
deepcopy(media_data.BASE64_AUDIO)["data"]
|
||||
).endswith(".wav")
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = audio_output.save_flagged(
|
||||
tmpdirname, "audio_output", media_data.BASE64_AUDIO, None
|
||||
tmpdirname, "audio_output", deepcopy(media_data.BASE64_AUDIO), None
|
||||
)
|
||||
self.assertEqual("audio_output/0.wav", to_save)
|
||||
to_save = audio_output.save_flagged(
|
||||
tmpdirname, "audio_output", media_data.BASE64_AUDIO, None
|
||||
tmpdirname, "audio_output", deepcopy(media_data.BASE64_AUDIO), None
|
||||
)
|
||||
self.assertEqual("audio_output/1.wav", to_save)
|
||||
|
||||
def test_tokenize(self):
|
||||
"""
|
||||
Tokenize, get_masked_inputs
|
||||
"""
|
||||
x_wav = deepcopy(media_data.BASE64_AUDIO)
|
||||
audio_input = gr.Audio()
|
||||
tokens, _, _ = audio_input.tokenize(x_wav)
|
||||
self.assertEquals(len(tokens), audio_input.interpretation_segments)
|
||||
x_new = audio_input.get_masked_inputs(tokens, [[1] * len(tokens)])[0]
|
||||
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
|
||||
self.assertGreater(similarity, 0.9)
|
||||
|
||||
# TODO: add test_in_interface_as_input
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
|
||||
def generate_noise(duration):
|
||||
return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)
|
||||
|
||||
@ -664,8 +760,11 @@ class TestAudio(unittest.TestCase):
|
||||
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
def test_as_component_as_input(self):
|
||||
x_file = media_data.BASE64_FILE
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context, default_value
|
||||
"""
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
file_input = gr.File()
|
||||
output = file_input.preprocess(x_file)
|
||||
self.assertIsInstance(output, tempfile._TemporaryFileWrapper)
|
||||
@ -698,8 +797,17 @@ class TestFile(unittest.TestCase):
|
||||
x_file["is_example"] = True
|
||||
self.assertIsNotNone(file_input.preprocess(x_file))
|
||||
|
||||
file_input = gr.File("test/test_files/sample_file.pdf")
|
||||
self.assertEqual(
|
||||
file_input.get_template_context(),
|
||||
deepcopy(media_data.FILE_TEMPLATE_CONTEXT),
|
||||
)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
x_file = media_data.BASE64_FILE
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
|
||||
def get_size_of_file(file_obj):
|
||||
return os.path.getsize(file_obj.name)
|
||||
@ -708,6 +816,10 @@ class TestFile(unittest.TestCase):
|
||||
self.assertEqual(iface.process([[x_file]])[0], [10558])
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
"""
|
||||
Interface, process, save_flagged,
|
||||
"""
|
||||
|
||||
def write_file(content):
|
||||
with open("test.txt", "w") as f:
|
||||
f.write(content)
|
||||
@ -725,17 +837,20 @@ class TestFile(unittest.TestCase):
|
||||
file_output = gr.File()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = file_output.save_flagged(
|
||||
tmpdirname, "file_output", [media_data.BASE64_FILE], None
|
||||
tmpdirname, "file_output", [deepcopy(media_data.BASE64_FILE)], None
|
||||
)
|
||||
self.assertEqual("file_output/0", to_save)
|
||||
to_save = file_output.save_flagged(
|
||||
tmpdirname, "file_output", [media_data.BASE64_FILE], None
|
||||
tmpdirname, "file_output", [deepcopy(media_data.BASE64_FILE)], None
|
||||
)
|
||||
self.assertEqual("file_output/1", to_save)
|
||||
|
||||
|
||||
class TestDataframe(unittest.TestCase):
|
||||
def test_as_component_as_input(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, serialize, save_flagged, restore_flagged, generate_sample, get_template_context
|
||||
"""
|
||||
x_data = [["Tim", 12, False], ["Jan", 24, True]]
|
||||
dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"])
|
||||
output = dataframe_input.preprocess(x_data)
|
||||
@ -784,19 +899,7 @@ class TestDataframe(unittest.TestCase):
|
||||
wrong_type = gr.Dataframe(type="unknown")
|
||||
wrong_type.preprocess(x_data)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
x_data = [[1, 2, 3], [4, 5, 6]]
|
||||
iface = gr.Interface(np.max, "numpy", "number")
|
||||
self.assertEqual(iface.process([x_data])[0], [6])
|
||||
x_data = [["Tim"], ["Jon"], ["Sal"]]
|
||||
|
||||
def get_last(my_list):
|
||||
return my_list[-1]
|
||||
|
||||
iface = gr.Interface(get_last, "list", "text")
|
||||
self.assertEqual(iface.process([x_data])[0], ["Sal"])
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
# Output functionalities
|
||||
dataframe_output = gr.Dataframe()
|
||||
output = dataframe_output.postprocess(np.zeros((2, 2)))
|
||||
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]})
|
||||
@ -828,7 +931,6 @@ class TestDataframe(unittest.TestCase):
|
||||
[None, None, None],
|
||||
[None, None, None],
|
||||
],
|
||||
"name": "dataframe",
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
@ -855,7 +957,26 @@ class TestDataframe(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process,
|
||||
"""
|
||||
x_data = [[1, 2, 3], [4, 5, 6]]
|
||||
iface = gr.Interface(np.max, "numpy", "number")
|
||||
self.assertEqual(iface.process([x_data])[0], [6])
|
||||
x_data = [["Tim"], ["Jon"], ["Sal"]]
|
||||
|
||||
def get_last(my_list):
|
||||
return my_list[-1]
|
||||
|
||||
iface = gr.Interface(get_last, "list", "text")
|
||||
self.assertEqual(iface.process([x_data])[0], ["Sal"])
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
|
||||
def check_odd(array):
|
||||
return array % 2 == 0
|
||||
|
||||
@ -866,8 +987,11 @@ class TestDataframe(unittest.TestCase):
|
||||
|
||||
|
||||
class TestVideo(unittest.TestCase):
|
||||
def test_as_component_as_input(self):
|
||||
x_video = media_data.BASE64_VIDEO
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, serialize, deserialize, save_flagged, restore_flagged, generate_sample, get_template_context
|
||||
"""
|
||||
x_video = deepcopy(media_data.BASE64_VIDEO)
|
||||
video_input = gr.Video()
|
||||
output = video_input.preprocess(x_video)
|
||||
self.assertIsInstance(output, str)
|
||||
@ -896,37 +1020,49 @@ class TestVideo(unittest.TestCase):
|
||||
x_video["is_example"] = True
|
||||
self.assertIsNotNone(video_input.preprocess(x_video))
|
||||
video_input = gr.Video(type="avi")
|
||||
# self.assertEqual(video_input.preprocess(x_video)[-3:], "avi")
|
||||
self.assertEqual(video_input.preprocess(x_video)[-3:], "avi")
|
||||
with self.assertRaises(NotImplementedError):
|
||||
video_input.serialize(x_video, True)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
x_video = media_data.BASE64_VIDEO
|
||||
iface = gr.Interface(lambda x: x, "video", "playable_video")
|
||||
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
y_vid = "test/test_files/video_sample.mp4"
|
||||
# Output functionalities
|
||||
y_vid_path = "test/test_files/video_sample.mp4"
|
||||
video_output = gr.Video()
|
||||
self.assertTrue(
|
||||
video_output.postprocess(y_vid)["data"].startswith("data:video/mp4;base64,")
|
||||
video_output.postprocess(y_vid_path)["data"].startswith(
|
||||
"data:video/mp4;base64,"
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
video_output.deserialize(media_data.BASE64_VIDEO["data"]).endswith(".mp4")
|
||||
video_output.deserialize(
|
||||
deepcopy(media_data.BASE64_VIDEO)["data"]
|
||||
).endswith(".mp4")
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = video_output.save_flagged(
|
||||
tmpdirname, "video_output", media_data.BASE64_VIDEO, None
|
||||
tmpdirname, "video_output", deepcopy(media_data.BASE64_VIDEO), None
|
||||
)
|
||||
self.assertEqual("video_output/0.mp4", to_save)
|
||||
to_save = video_output.save_flagged(
|
||||
tmpdirname, "video_output", media_data.BASE64_VIDEO, None
|
||||
tmpdirname, "video_output", deepcopy(media_data.BASE64_VIDEO), None
|
||||
)
|
||||
self.assertEqual("video_output/1.mp4", to_save)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
x_video = deepcopy(media_data.BASE64_VIDEO)
|
||||
iface = gr.Interface(lambda x: x, "video", "playable_video")
|
||||
self.assertEqual(iface.process([x_video])[0][0]["data"], x_video["data"])
|
||||
|
||||
# TODO: test_in_interface_as_output
|
||||
|
||||
|
||||
class TestTimeseries(unittest.TestCase):
|
||||
def test_as_component_as_input(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Preprocess, postprocess, save_flagged, restore_flagged, generate_sample, get_template_context,
|
||||
"""
|
||||
timeseries_input = gr.Timeseries(x="time", y=["retail", "food", "other"])
|
||||
x_timeseries = {
|
||||
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
|
||||
@ -962,24 +1098,8 @@ class TestTimeseries(unittest.TestCase):
|
||||
x_timeseries["range"] = (0, 1)
|
||||
self.assertIsNotNone(timeseries_input.preprocess(x_timeseries))
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
timeseries_input = gr.Timeseries(x="time", y=["retail", "food", "other"])
|
||||
x_timeseries = {
|
||||
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
|
||||
"headers": [timeseries_input.x] + timeseries_input.y,
|
||||
}
|
||||
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
|
||||
self.assertEqual(
|
||||
iface.process([x_timeseries])[0],
|
||||
[
|
||||
{
|
||||
"headers": ["time", "retail", "food", "other"],
|
||||
"data": [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2]],
|
||||
}
|
||||
],
|
||||
)
|
||||
# Output functionalities
|
||||
|
||||
def test_as_component_as_output(self):
|
||||
timeseries_output = gr.Timeseries(label="Disease")
|
||||
|
||||
self.assertEqual(
|
||||
@ -1030,9 +1150,31 @@ class TestTimeseries(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: test_in_interface_as_input
|
||||
|
||||
def test_in_interface_as_output(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
timeseries_input = gr.Timeseries(x="time", y=["retail", "food", "other"])
|
||||
x_timeseries = {
|
||||
"data": [[1] + [2] * len(timeseries_input.y)] * 4,
|
||||
"headers": [timeseries_input.x] + timeseries_input.y,
|
||||
}
|
||||
iface = gr.Interface(lambda x: x, timeseries_input, "dataframe")
|
||||
self.assertEqual(
|
||||
iface.process([x_timeseries])[0],
|
||||
[
|
||||
{
|
||||
"headers": ["time", "retail", "food", "other"],
|
||||
"data": [[1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2]],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestNames(unittest.TestCase):
|
||||
# this ensures that `components.get_component_instance()` works correctly when instantiating from components
|
||||
# This test ensures that `components.get_component_instance()` works correctly when instantiating from components
|
||||
def test_no_duplicate_uncased_names(self):
|
||||
subclasses = gr.components.Component.__subclasses__()
|
||||
unique_subclasses_uncased = set([s.__name__.lower() for s in subclasses])
|
||||
@ -1040,7 +1182,10 @@ class TestNames(unittest.TestCase):
|
||||
|
||||
|
||||
class TestLabel(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Process, postprocess, deserialize, save_flagged, restore_flagged
|
||||
"""
|
||||
y = "happy"
|
||||
label_output = gr.Label()
|
||||
label = label_output.postprocess(y)
|
||||
@ -1093,8 +1238,16 @@ class TestLabel(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
label_output.get_template_context(),
|
||||
{"name": "label", "label": None, "css": {}},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
x_img = media_data.BASE64_IMAGE
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
x_img = deepcopy(media_data.BASE64_IMAGE)
|
||||
|
||||
def rgb_distribution(img):
|
||||
rgb_dist = np.mean(img, axis=(0, 1))
|
||||
@ -1122,7 +1275,10 @@ class TestLabel(unittest.TestCase):
|
||||
|
||||
|
||||
class TestHighlightedText(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
get_template_context, save_flagged, restore_flagged
|
||||
"""
|
||||
ht_output = gr.HighlightedText(color_map={"pos": "green", "neg": "red"})
|
||||
self.assertEqual(
|
||||
ht_output.get_template_context(),
|
||||
@ -1145,6 +1301,10 @@ class TestHighlightedText(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
|
||||
def highlight_vowels(sentence):
|
||||
phrases, cur_phrase = [], ""
|
||||
vowels, mode = "aeiou", None
|
||||
@ -1168,7 +1328,10 @@ class TestHighlightedText(unittest.TestCase):
|
||||
|
||||
|
||||
class TestJSON(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Postprocess, save_flagged, restore_flagged
|
||||
"""
|
||||
js_output = gr.JSON()
|
||||
self.assertTrue(
|
||||
js_output.postprocess('{"a":1, "b": 2}'), '"{\\"a\\":1, \\"b\\": 2}"'
|
||||
@ -1181,8 +1344,16 @@ class TestJSON(unittest.TestCase):
|
||||
js_output.restore_flagged(tmpdirname, to_save, None),
|
||||
{"pos": "Hello ", "neg": "World"},
|
||||
)
|
||||
self.assertEqual(
|
||||
js_output.get_template_context(),
|
||||
{"css": {}, "default_value": '""', "label": None, "name": "json"},
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
|
||||
def get_avg_age_per_gender(data):
|
||||
return {
|
||||
"M": int(data[data["gender"] == "M"].mean()),
|
||||
@ -1206,7 +1377,26 @@ class TestJSON(unittest.TestCase):
|
||||
|
||||
|
||||
class TestHTML(unittest.TestCase):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Get_template_context
|
||||
"""
|
||||
html_component = gr.components.HTML("#Welcome onboard", label="HTML Input")
|
||||
self.assertEqual(
|
||||
{
|
||||
"css": {},
|
||||
"default_value": "#Welcome onboard",
|
||||
"label": "HTML Input",
|
||||
"name": "html",
|
||||
},
|
||||
html_component.get_template_context(),
|
||||
)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
|
||||
def bold_text(text):
|
||||
return "<strong>" + text + "</strong>"
|
||||
|
||||
@ -1215,7 +1405,10 @@ class TestHTML(unittest.TestCase):
|
||||
|
||||
|
||||
class TestCarousel(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
def test_component_functionalities(self):
|
||||
"""
|
||||
Postprocess, get_template_context, save_flagged, restore_flagged
|
||||
"""
|
||||
carousel_output = gr.Carousel(
|
||||
components=[gr.Textbox(), gr.Image()], label="Disease"
|
||||
)
|
||||
@ -1229,8 +1422,8 @@ class TestCarousel(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
["Hello World", media_data.BASE64_IMAGE],
|
||||
["Bye World", media_data.BASE64_IMAGE],
|
||||
["Hello World", deepcopy(media_data.BASE64_IMAGE)],
|
||||
["Bye World", deepcopy(media_data.BASE64_IMAGE)],
|
||||
],
|
||||
)
|
||||
|
||||
@ -1264,8 +1457,13 @@ class TestCarousel(unittest.TestCase):
|
||||
tmpdirname, "carousel_output", output, None
|
||||
)
|
||||
self.assertEqual(to_save, '[["Hello World"], ["Bye World"]]')
|
||||
restored = carousel_output.restore_flagged(tmpdirname, output, None)
|
||||
self.assertEqual(None, restored)
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process
|
||||
"""
|
||||
carousel_output = gr.Carousel(
|
||||
components=[gr.Textbox(), gr.Image()], label="Disease"
|
||||
)
|
||||
@ -1279,7 +1477,7 @@ class TestCarousel(unittest.TestCase):
|
||||
return results
|
||||
|
||||
iface = gr.Interface(report, gr.inputs.Image(type="numpy"), carousel_output)
|
||||
result = iface.process([media_data.BASE64_IMAGE])
|
||||
result = iface.process([deepcopy(media_data.BASE64_IMAGE)])
|
||||
self.assertTrue(result[0][0][0][0] == "Red")
|
||||
self.assertTrue(
|
||||
result[0][0][0][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
@ -1293,5 +1491,6 @@ class TestCarousel(unittest.TestCase):
|
||||
result[0][0][2][1].startswith("data:image/png;base64,iVBORw0KGgoAAA")
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user