From 6b13238c9d4fb771055f2517c867c12bd30551cd Mon Sep 17 00:00:00 2001 From: Abubakar Abid <a12d@stanford.edu> Date: Tue, 26 Oct 2021 17:36:12 -0500 Subject: [PATCH 1/2] added interpretation --- gradio/interface.py | 2 +- gradio/interpretation.py | 2 +- test/test_interpretation.py | 87 ++++++++++++++++++++++++++++++++----- 3 files changed, 78 insertions(+), 13 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 1881f9c853..d84bc1c9ba 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -424,7 +424,7 @@ class Interface: scores.append( input_component.get_interpretation_scores( raw_input[i], neighbor_values, interface_scores, **interpret_kwargs)) - elif interp == "shap": + elif interp == "shap" or interp == "shapley": try: import shap except (ImportError, ModuleNotFoundError): diff --git a/gradio/interpretation.py b/gradio/interpretation.py index 278aac0135..c9dd2922f4 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -49,7 +49,7 @@ def get_regression_or_classification_value(interface, original_output, perturbed return 0 return perturbed_output[0][original_label] else: - score = diff(perturbed_label, original_label) # Intentionall inverted order of arguments. + score = diff(perturbed_label, original_label) # Intentionally inverted order of arguments. return score else: diff --git a/test/test_interpretation.py b/test/test_interpretation.py index e068b20d5c..0a31f942e2 100644 --- a/test/test_interpretation.py +++ b/test/test_interpretation.py @@ -12,18 +12,15 @@ class TestDefault(unittest.TestCase): text_interface = Interface(max_word_len, "textbox", "label", interpretation="default") interpretation = text_interface.interpret(["quickest brown fox"])[0][0] self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score. - self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score. + self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score. - ## Commented out since skimage is no longer a required dependency, this will fail in CircleCI TODO(abidlabs): have backup default segmentation - # def test_default_image(self): - # max_pixel_value = lambda img: img.max() - # img_interface = Interface(max_pixel_value, "image", "number", interpretation="default") - # array = np.zeros((100,100)) - # array[0, 0] = 1 - # img = encode_array_to_base64(array) - # interpretation = img_interface.interpret([img])[0][0] - # self.assertGreater(interpretation[0][0], 0) # Checks to see if the top-left has >0 score. - +class TestShapley(unittest.TestCase): + def test_shapley_text(self): + max_word_len = lambda text: max([len(word) for word in text.split(" ")]) + text_interface = Interface(max_word_len, "textbox", "label", interpretation="shapley") + interpretation = text_interface.interpret(["quickest brown fox"])[0][0] + self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score. + self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score. class TestCustom(unittest.TestCase): def test_custom_text(self): @@ -42,5 +39,73 @@ class TestCustom(unittest.TestCase): self.assertEqual(result, expected_result) +class TestHelperMethods(unittest.TestCase): + def test_diff(self): + diff = gradio.interpretation.diff(13, "2") + self.assertEquals(diff, 11) + diff = gradio.interpretation.diff("cat", "dog") + self.assertEquals(diff, 1) + diff = gradio.interpretation.diff("cat", "cat") + self.assertEquals(diff, 0) + + def test_quantify_difference_with_textbox(self): + iface = Interface(lambda text: text, ["textbox"], ["textbox"]) + diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test"]) + self.assertEquals(diff, 0) + diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test_diff"]) + self.assertEquals(diff, 1) + + def test_quantify_difference_with_label(self): + iface = Interface(lambda text: len(text), ["textbox"], ["label"]) + diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"]) + self.assertEquals(diff, -7) + diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"]) + self.assertEquals(diff, -100) + + def test_quantify_difference_with_confidences(self): + iface = Interface(lambda text: len(text), ["textbox"], ["label"]) + output_1 = { + "cat": 0.9, + "dog": 0.1 + } + output_2 = { + "cat": 0.6, + "dog": 0.4 + } + output_3 = { + "cat": 0.1, + "dog": 0.6 + } + diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_2]) + self.assertAlmostEquals(diff, 0.3) + diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_3]) + self.assertAlmostEquals(diff, 0.8) + + def test_get_regression_value(self): + iface = Interface(lambda text: text, ["textbox"], ["label"]) + output_1 = { + "cat": 0.9, + "dog": 0.1 + } + output_2 = { + "cat": float("nan"), + "dog": 0.4 + } + output_3 = { + "cat": 0.1, + "dog": 0.6 + } + diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_2]) + self.assertEquals(diff, 0) + diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_3]) + self.assertAlmostEquals(diff, 0.1) + + def test_get_classification_value(self): + iface = Interface(lambda text: text, ["textbox"], ["label"]) + diff = gradio.interpretation.get_regression_or_classification_value(iface, ["cat"], ["test"]) + self.assertEquals(diff, 1) + diff = gradio.interpretation.get_regression_or_classification_value(iface, ["test"], ["test"]) + self.assertEquals(diff, 0) + if __name__ == '__main__': unittest.main() \ No newline at end of file From 966a30ada74315705128c53a080ed6caa6a6b1e2 Mon Sep 17 00:00:00 2001 From: Abubakar Abid <a12d@stanford.edu> Date: Tue, 26 Oct 2021 17:47:49 -0500 Subject: [PATCH 2/2] added shap --- gradio.egg-info/requires.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gradio.egg-info/requires.txt b/gradio.egg-info/requires.txt index fa44b53bdd..d6ea781d61 100644 --- a/gradio.egg-info/requires.txt +++ b/gradio.egg-info/requires.txt @@ -13,4 +13,5 @@ Flask>=1.1.1 Flask-Cors>=3.0.8 flask-cachebuster Flask-Login -IPython \ No newline at end of file +IPython +shap