From 6b13238c9d4fb771055f2517c867c12bd30551cd Mon Sep 17 00:00:00 2001
From: Abubakar Abid <a12d@stanford.edu>
Date: Tue, 26 Oct 2021 17:36:12 -0500
Subject: [PATCH 1/2] added interpretation

---
 gradio/interface.py         |  2 +-
 gradio/interpretation.py    |  2 +-
 test/test_interpretation.py | 87 ++++++++++++++++++++++++++++++++-----
 3 files changed, 78 insertions(+), 13 deletions(-)

diff --git a/gradio/interface.py b/gradio/interface.py
index 1881f9c853..d84bc1c9ba 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -424,7 +424,7 @@ class Interface:
                         scores.append(
                             input_component.get_interpretation_scores(
                                 raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
-                elif interp == "shap":
+                elif interp == "shap" or interp == "shapley":
                     try:
                         import shap
                     except (ImportError, ModuleNotFoundError):
diff --git a/gradio/interpretation.py b/gradio/interpretation.py
index 278aac0135..c9dd2922f4 100644
--- a/gradio/interpretation.py
+++ b/gradio/interpretation.py
@@ -49,7 +49,7 @@ def get_regression_or_classification_value(interface, original_output, perturbed
                 return 0
             return perturbed_output[0][original_label]
         else:
-            score = diff(perturbed_label, original_label)  # Intentionall inverted order of arguments.
+            score = diff(perturbed_label, original_label)  # Intentionally inverted order of arguments.
         return score
 
     else:
diff --git a/test/test_interpretation.py b/test/test_interpretation.py
index e068b20d5c..0a31f942e2 100644
--- a/test/test_interpretation.py
+++ b/test/test_interpretation.py
@@ -12,18 +12,15 @@ class TestDefault(unittest.TestCase):
         text_interface = Interface(max_word_len, "textbox", "label", interpretation="default")
         interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
         self.assertGreater(interpretation[0][1], 0)  # Checks to see if the first word has >0 score.
-        self.assertEqual(interpretation[-1][1], 0)  # Checks to see if the last word has 0 score.
+        self.assertEqual(interpretation[-1][1], 0)  # Checks to see if the last word has 0 score.        
 
-    ## Commented out since skimage is no longer a required dependency, this will fail in CircleCI TODO(abidlabs): have backup default segmentation
-    # def test_default_image(self):
-    #     max_pixel_value = lambda img: img.max()
-    #     img_interface = Interface(max_pixel_value, "image", "number", interpretation="default")
-    #     array = np.zeros((100,100))
-    #     array[0, 0] = 1
-    #     img = encode_array_to_base64(array)        
-    #     interpretation = img_interface.interpret([img])[0][0]
-    #     self.assertGreater(interpretation[0][0], 0)  # Checks to see if the top-left has >0 score.
-        
+class TestShapley(unittest.TestCase):
+    def test_shapley_text(self):
+        max_word_len = lambda text: max([len(word) for word in text.split(" ")])
+        text_interface = Interface(max_word_len, "textbox", "label", interpretation="shapley")
+        interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
+        self.assertGreater(interpretation[0][1], 0)  # Checks to see if the first word has >0 score.
+        self.assertEqual(interpretation[-1][1], 0)  # Checks to see if the last word has 0 score.        
 
 class TestCustom(unittest.TestCase):
     def test_custom_text(self):
@@ -42,5 +39,73 @@ class TestCustom(unittest.TestCase):
         self.assertEqual(result, expected_result)
         
 
+class TestHelperMethods(unittest.TestCase):
+    def test_diff(self):
+        diff = gradio.interpretation.diff(13, "2")
+        self.assertEquals(diff, 11)
+        diff = gradio.interpretation.diff("cat", "dog")
+        self.assertEquals(diff, 1)
+        diff = gradio.interpretation.diff("cat", "cat")
+        self.assertEquals(diff, 0)
+
+    def test_quantify_difference_with_textbox(self):
+        iface = Interface(lambda text: text, ["textbox"], ["textbox"])
+        diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test"])
+        self.assertEquals(diff, 0)
+        diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test_diff"])
+        self.assertEquals(diff, 1)
+
+    def test_quantify_difference_with_label(self):
+        iface = Interface(lambda text: len(text), ["textbox"], ["label"])
+        diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"])
+        self.assertEquals(diff, -7)
+        diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"])
+        self.assertEquals(diff, -100)
+
+    def test_quantify_difference_with_confidences(self):
+        iface = Interface(lambda text: len(text), ["textbox"], ["label"])
+        output_1 = {
+            "cat": 0.9,
+            "dog": 0.1
+        }
+        output_2 = {
+            "cat": 0.6,
+            "dog": 0.4
+        }
+        output_3 = {
+            "cat": 0.1,
+            "dog": 0.6
+        }
+        diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_2])
+        self.assertAlmostEquals(diff, 0.3)
+        diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_3])
+        self.assertAlmostEquals(diff, 0.8)
+
+    def test_get_regression_value(self):
+        iface = Interface(lambda text: text, ["textbox"], ["label"])
+        output_1 = {
+            "cat": 0.9,
+            "dog": 0.1
+        }
+        output_2 = {
+            "cat": float("nan"),
+            "dog": 0.4
+        }
+        output_3 = {
+            "cat": 0.1,
+            "dog": 0.6
+        }
+        diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_2])
+        self.assertEquals(diff, 0)
+        diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_3])
+        self.assertAlmostEquals(diff, 0.1)
+
+    def test_get_classification_value(self):
+        iface = Interface(lambda text: text, ["textbox"], ["label"])
+        diff = gradio.interpretation.get_regression_or_classification_value(iface, ["cat"], ["test"])
+        self.assertEquals(diff, 1)
+        diff = gradio.interpretation.get_regression_or_classification_value(iface, ["test"], ["test"])
+        self.assertEquals(diff, 0)
+
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file

From 966a30ada74315705128c53a080ed6caa6a6b1e2 Mon Sep 17 00:00:00 2001
From: Abubakar Abid <a12d@stanford.edu>
Date: Tue, 26 Oct 2021 17:47:49 -0500
Subject: [PATCH 2/2] added shap

---
 gradio.egg-info/requires.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/gradio.egg-info/requires.txt b/gradio.egg-info/requires.txt
index fa44b53bdd..d6ea781d61 100644
--- a/gradio.egg-info/requires.txt
+++ b/gradio.egg-info/requires.txt
@@ -13,4 +13,5 @@ Flask>=1.1.1
 Flask-Cors>=3.0.8
 flask-cachebuster
 Flask-Login
-IPython
\ No newline at end of file
+IPython
+shap