From aab300cf3ffb5afe238d775ea09a9cf95c88c48d Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 5 Oct 2020 07:43:42 -0500 Subject: [PATCH] added test for interpretation --- gradio/interface.py | 4 ++-- gradio/interpretation.py | 2 +- test/test_interpretation.py | 47 +++++++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 test/test_interpretation.py diff --git a/gradio/interface.py b/gradio/interface.py index a561dfce59..151f4d7f52 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -266,10 +266,10 @@ class Interface: if self.capture_session and self.session is not None: graph, sess = self.session with graph.as_default(), sess.as_default(): - interpretation = interpreter(*processed_input).tolist() + interpretation = interpreter(*processed_input) else: try: - interpretation = interpreter(*processed_input).tolist() + interpretation = interpreter(*processed_input) except ValueError as exception: if str(exception).endswith("is not an element of this graph."): raise ValueError(strings.en["TF1_ERROR"]) diff --git a/gradio/interpretation.py b/gradio/interpretation.py index 03a21f13b8..a8de27c217 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -1,6 +1,5 @@ from gradio.inputs import Image, Textbox from gradio.outputs import Label -from gradio import processing_utils from skimage.segmentation import slic import numpy as np @@ -8,6 +7,7 @@ expected_types = { Image: "numpy", } + def default(separator=" ", n_segments=20): """ Basic "default" interpretation method that uses "leave-one-out" to explain predictions for diff --git a/test/test_interpretation.py b/test/test_interpretation.py new file mode 100644 index 0000000000..d6fc230736 --- /dev/null +++ b/test/test_interpretation.py @@ -0,0 +1,47 @@ +import unittest +import gradio.interpretation +import gradio.test_data +from gradio.processing_utils import decode_base64_to_image +from gradio import Interface +import numpy as np + + +class TestDefault(unittest.TestCase): + def setUp(self): + self.default_method = gradio.interpretation.default() + + def test_default_text(self): + max_word_len = lambda text: max([len(word) for word in text.split(" ")]) + text_interface = Interface(max_word_len, "textbox", "label") + interpretation = self.default_method(text_interface, ["quickest brown fox"])[0] + self.assertGreater(interpretation[0][1], 0) # Checks to see if the first letter has >0 score. + self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last letter has 0 score. + + def test_default_image(self): + max_pixel_value = lambda img: img.max() + img_interface = Interface(max_pixel_value, "image", "label") + array = np.zeros((100,100)) + array[0, 0] = 1 + interpretation = self.default_method(img_interface, [array])[0] + self.assertGreater(interpretation[0][0], 0) # Checks to see if the top-left has >0 score. + + +class TestCustom(unittest.TestCase): + def test_custom_text(self): + max_word_len = lambda text: max([len(word) for word in text.split(" ")]) + custom = lambda text: [(char, 1) for char in text] + text_interface = Interface(max_word_len, "textbox", "label", interpretation=custom) + result = text_interface.interpret(["quickest brown fox"])[0] + self.assertEqual(result[0][1], 1) # Checks to see if the first letter has score of 1. + + def test_custom_img(self): + max_pixel_value = lambda img: img.max() + custom = lambda img: img.tolist() + img_interface = Interface(max_pixel_value, "image", "label", interpretation=custom) + result = img_interface.interpret([gradio.test_data.BASE64_IMAGE])[0] + expected_result = np.asarray(decode_base64_to_image(gradio.test_data.BASE64_IMAGE).convert('RGB')).tolist() + self.assertEqual(result, expected_result) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file