mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
111 lines
4.9 KiB
Python
111 lines
4.9 KiB
Python
import unittest
|
|
import gradio.interpretation
|
|
import gradio.test_data
|
|
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
|
|
from gradio import Interface
|
|
import numpy as np
|
|
|
|
|
|
class TestDefault(unittest.TestCase):
    """Exercise the built-in "default" interpretation on a text input."""

    def test_default_text(self):
        def max_word_len(text):
            # Length of the longest space-separated token.
            return max(map(len, text.split(" ")))

        iface = Interface(
            max_word_len, "textbox", "label", interpretation="default"
        )
        # Entries are indexed as (word, score) pairs below.
        word_scores = iface.interpret(["quickest brown fox"])[0][0]
        # "quickest" is the longest word, so it should get a positive score.
        self.assertGreater(word_scores[0][1], 0)
        # "fox" is never the longest word, so its score should be exactly 0.
        self.assertEqual(word_scores[-1][1], 0)
|
|
|
|
class TestShapley(unittest.TestCase):
    """Exercise the "shapley" interpretation on a text input."""

    def test_shapley_text(self):
        def max_word_len(text):
            # Length of the longest space-separated token.
            lengths = [len(token) for token in text.split(" ")]
            return max(lengths)

        shapley_iface = Interface(
            max_word_len, "textbox", "label", interpretation="shapley"
        )
        token_scores = shapley_iface.interpret(["quickest brown fox"])[0][0]
        # The first word, "quickest", sets the maximum: expect a score > 0.
        self.assertGreater(token_scores[0][1], 0)
        # The last word, "fox", never sets the maximum: expect a score of 0.
        self.assertEqual(token_scores[-1][1], 0)
|
|
|
|
class TestCustom(unittest.TestCase):
    """Check that a user-supplied callable can serve as the interpreter."""

    def test_custom_text(self):
        def max_word_len(text):
            return max(len(token) for token in text.split(" "))

        def one_per_char(text):
            # Assign every character the same constant score of 1.
            return [(ch, 1) for ch in text]

        iface = Interface(
            max_word_len, "textbox", "label", interpretation=one_per_char
        )
        char_scores = iface.interpret(["quickest brown fox"])[0][0]
        # The first character should carry the score produced by one_per_char.
        self.assertEqual(char_scores[0][1], 1)

    def test_custom_img(self):
        def max_pixel_value(img):
            return img.max()

        def echo_pixels(img):
            # Hand the image back as nested lists (img presumably an ndarray).
            return img.tolist()

        iface = Interface(
            max_pixel_value, "image", "label", interpretation=echo_pixels
        )
        sample = gradio.test_data.BASE64_IMAGE
        actual = iface.interpret([sample])[0][0]
        # The custom interpreter echoes its input, so the result should match
        # the decoded RGB pixels of the sample image.
        rgb_image = decode_base64_to_image(sample).convert('RGB')
        expected = np.asarray(rgb_image).tolist()
        self.assertEqual(actual, expected)
|
|
|
|
|
|
class TestHelperMethods(unittest.TestCase):
    """Unit tests for the helper functions in ``gradio.interpretation``.

    Uses ``assertEqual`` / ``assertAlmostEqual``; the ``assertEquals`` /
    ``assertAlmostEquals`` aliases are deprecated and were removed in
    Python 3.12.
    """

    def test_diff(self):
        """``diff`` compares numeric-looking inputs numerically, strings by equality."""
        # 13 and "2" are expected to be compared as numbers: 13 - 2 == 11.
        diff = gradio.interpretation.diff(13, "2")
        self.assertEqual(diff, 11)
        # Differing non-numeric strings are expected to yield 1 ...
        diff = gradio.interpretation.diff("cat", "dog")
        self.assertEqual(diff, 1)
        # ... and identical ones 0.
        diff = gradio.interpretation.diff("cat", "cat")
        self.assertEqual(diff, 0)

    def test_quantify_difference_with_textbox(self):
        """Textbox outputs: identical -> 0, different -> 1."""
        iface = Interface(lambda text: text, ["textbox"], ["textbox"])
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, ["test"], ["test"])
        self.assertEqual(diff, 0)
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, ["test"], ["test_diff"])
        self.assertEqual(diff, 1)

    def test_quantify_difference_with_label(self):
        """Numeric label outputs: expected to be a signed first-minus-second difference."""
        iface = Interface(lambda text: len(text), ["textbox"], ["label"])
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, ["3"], ["10"])
        self.assertEqual(diff, -7)
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, ["0"], ["100"])
        self.assertEqual(diff, -100)

    def test_quantify_difference_with_confidences(self):
        """Confidence-dict label outputs compared via class confidences."""
        iface = Interface(lambda text: len(text), ["textbox"], ["label"])
        output_1 = {
            "cat": 0.9,
            "dog": 0.1
        }
        output_2 = {
            "cat": 0.6,
            "dog": 0.4
        }
        output_3 = {
            "cat": 0.1,
            "dog": 0.6
        }
        # Float subtraction (e.g. 0.9 - 0.6) is inexact, hence assertAlmostEqual.
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, [output_1], [output_2])
        self.assertAlmostEqual(diff, 0.3)
        diff = gradio.interpretation.quantify_difference_in_label(
            iface, [output_1], [output_3])
        self.assertAlmostEqual(diff, 0.8)

    def test_get_regression_value(self):
        """Regression-style comparison of confidence dicts, including a NaN entry."""
        iface = Interface(lambda text: text, ["textbox"], ["label"])
        output_1 = {
            "cat": 0.9,
            "dog": 0.1
        }
        output_2 = {
            "cat": float("nan"),
            "dog": 0.4
        }
        output_3 = {
            "cat": 0.1,
            "dog": 0.6
        }
        # A NaN confidence in the second output is expected to yield 0.
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, [output_1], [output_2])
        self.assertEqual(diff, 0)
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, [output_1], [output_3])
        self.assertAlmostEqual(diff, 0.1)

    def test_get_classification_value(self):
        """Plain-string label outputs: different -> 1, identical -> 0."""
        iface = Interface(lambda text: text, ["textbox"], ["label"])
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, ["cat"], ["test"])
        self.assertEqual(diff, 1)
        diff = gradio.interpretation.get_regression_or_classification_value(
            iface, ["test"], ["test"])
        self.assertEqual(diff, 0)
|
|
|
|
if __name__ == "__main__":
    # Load and run every TestCase defined in this module when run as a script.
    unittest.main()