gradio/test/test_interpretation.py

45 lines
2.1 KiB
Python
Raw Normal View History

2020-10-05 20:43:42 +08:00
import unittest
import gradio.interpretation
import gradio.test_data
2020-10-22 20:09:17 +08:00
from gradio.processing_utils import decode_base64_to_image, encode_array_to_base64
2020-10-05 20:43:42 +08:00
from gradio import Interface
import numpy as np
class TestDefault(unittest.TestCase):
    """Tests for ``interpretation="default"`` on text and image interfaces."""

    def test_default_text(self):
        """Checks that only the word driving the output gets a non-zero score."""

        def max_word_len(text):
            # Length of the longest space-separated word in ``text``.
            return max(len(word) for word in text.split(" "))

        text_interface = Interface(max_word_len, "textbox", "label", interpretation="default")
        # NOTE(review): ``[0][0]`` presumably selects the interpretation of the
        # first (only) input component -- confirm against Interface.interpret.
        interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
        # "quickest" (8 chars) alone determines max_word_len; "fox" does not.
        # Assuming the default method scores a word by its effect on the
        # output, the first word must score >0 and the last exactly 0.
        self.assertGreater(interpretation[0][1], 0)  # First word has a >0 score.
        self.assertEqual(interpretation[-1][1], 0)  # Last word has a 0 score.

    def test_default_image(self):
        """Checks that the single non-zero pixel receives a positive score."""

        def max_pixel_value(img):
            # Brightest value anywhere in the preprocessed image.
            return img.max()

        img_interface = Interface(max_pixel_value, "image", "number", interpretation="default")
        # 100x100 all-zero image whose only non-zero pixel is the top-left one.
        array = np.zeros((100, 100))
        array[0, 0] = 1
        img = encode_array_to_base64(array)
        interpretation = img_interface.interpret([img])[0][0]
        self.assertGreater(interpretation[0][0], 0)  # Top-left has a >0 score.
class TestCustom(unittest.TestCase):
    """Tests for a user-supplied callable passed as ``interpretation``."""

    def test_custom_text(self):
        """Checks that the custom text interpreter's scores are returned."""

        def max_word_len(text):
            # Length of the longest space-separated word in ``text``.
            return max(len(word) for word in text.split(" "))

        def custom(text):
            # Assign every character a fixed score of 1.
            return [(char, 1) for char in text]

        text_interface = Interface(max_word_len, "textbox", "label", interpretation=custom)
        result = text_interface.interpret(["quickest brown fox"])[0][0]
        self.assertEqual(result[0][1], 1)  # First letter has a score of 1.

    def test_custom_img(self):
        """Checks that the custom image interpreter sees the decoded RGB pixels."""

        def max_pixel_value(img):
            # Brightest value anywhere in the preprocessed image.
            return img.max()

        def custom(img):
            # Echo the preprocessed image back as nested Python lists.
            return img.tolist()

        img_interface = Interface(max_pixel_value, "image", "label", interpretation=custom)
        result = img_interface.interpret([gradio.test_data.BASE64_IMAGE])[0][0]
        # The echoed image must match the same base64 image decoded to RGB.
        expected_result = np.asarray(
            decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert('RGB')
        ).tolist()
        self.assertEqual(result, expected_result)
if __name__ == '__main__':
    # Allow running this module directly as a script; discovers and runs
    # every unittest.TestCase defined above.
    unittest.main()