mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
added interpretation
This commit is contained in:
parent
14f2e46b19
commit
6b13238c9d
@ -424,7 +424,7 @@ class Interface:
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
|
||||
elif interp == "shap":
|
||||
elif interp == "shap" or interp == "shapley":
|
||||
try:
|
||||
import shap
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
|
@ -49,7 +49,7 @@ def get_regression_or_classification_value(interface, original_output, perturbed
|
||||
return 0
|
||||
return perturbed_output[0][original_label]
|
||||
else:
|
||||
score = diff(perturbed_label, original_label) # Intentionall inverted order of arguments.
|
||||
score = diff(perturbed_label, original_label) # Intentionally inverted order of arguments.
|
||||
return score
|
||||
|
||||
else:
|
||||
|
@ -12,18 +12,15 @@ class TestDefault(unittest.TestCase):
|
||||
text_interface = Interface(max_word_len, "textbox", "label", interpretation="default")
|
||||
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
|
||||
self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score.
|
||||
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
|
||||
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
|
||||
|
||||
## Commented out since skimage is no longer a required dependency, this will fail in CircleCI TODO(abidlabs): have backup default segmentation
|
||||
# def test_default_image(self):
|
||||
# max_pixel_value = lambda img: img.max()
|
||||
# img_interface = Interface(max_pixel_value, "image", "number", interpretation="default")
|
||||
# array = np.zeros((100,100))
|
||||
# array[0, 0] = 1
|
||||
# img = encode_array_to_base64(array)
|
||||
# interpretation = img_interface.interpret([img])[0][0]
|
||||
# self.assertGreater(interpretation[0][0], 0) # Checks to see if the top-left has >0 score.
|
||||
|
||||
class TestShapley(unittest.TestCase):
|
||||
def test_shapley_text(self):
|
||||
max_word_len = lambda text: max([len(word) for word in text.split(" ")])
|
||||
text_interface = Interface(max_word_len, "textbox", "label", interpretation="shapley")
|
||||
interpretation = text_interface.interpret(["quickest brown fox"])[0][0]
|
||||
self.assertGreater(interpretation[0][1], 0) # Checks to see if the first word has >0 score.
|
||||
self.assertEqual(interpretation[-1][1], 0) # Checks to see if the last word has 0 score.
|
||||
|
||||
class TestCustom(unittest.TestCase):
|
||||
def test_custom_text(self):
|
||||
@ -42,5 +39,73 @@ class TestCustom(unittest.TestCase):
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
|
||||
class TestHelperMethods(unittest.TestCase):
|
||||
def test_diff(self):
|
||||
diff = gradio.interpretation.diff(13, "2")
|
||||
self.assertEquals(diff, 11)
|
||||
diff = gradio.interpretation.diff("cat", "dog")
|
||||
self.assertEquals(diff, 1)
|
||||
diff = gradio.interpretation.diff("cat", "cat")
|
||||
self.assertEquals(diff, 0)
|
||||
|
||||
def test_quantify_difference_with_textbox(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["textbox"])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test"])
|
||||
self.assertEquals(diff, 0)
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["test"], ["test_diff"])
|
||||
self.assertEquals(diff, 1)
|
||||
|
||||
def test_quantify_difference_with_label(self):
|
||||
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["3"], ["10"])
|
||||
self.assertEquals(diff, -7)
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, ["0"], ["100"])
|
||||
self.assertEquals(diff, -100)
|
||||
|
||||
def test_quantify_difference_with_confidences(self):
|
||||
iface = Interface(lambda text: len(text), ["textbox"], ["label"])
|
||||
output_1 = {
|
||||
"cat": 0.9,
|
||||
"dog": 0.1
|
||||
}
|
||||
output_2 = {
|
||||
"cat": 0.6,
|
||||
"dog": 0.4
|
||||
}
|
||||
output_3 = {
|
||||
"cat": 0.1,
|
||||
"dog": 0.6
|
||||
}
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_2])
|
||||
self.assertAlmostEquals(diff, 0.3)
|
||||
diff = gradio.interpretation.quantify_difference_in_label(iface, [output_1], [output_3])
|
||||
self.assertAlmostEquals(diff, 0.8)
|
||||
|
||||
def test_get_regression_value(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["label"])
|
||||
output_1 = {
|
||||
"cat": 0.9,
|
||||
"dog": 0.1
|
||||
}
|
||||
output_2 = {
|
||||
"cat": float("nan"),
|
||||
"dog": 0.4
|
||||
}
|
||||
output_3 = {
|
||||
"cat": 0.1,
|
||||
"dog": 0.6
|
||||
}
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_2])
|
||||
self.assertEquals(diff, 0)
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, [output_1], [output_3])
|
||||
self.assertAlmostEquals(diff, 0.1)
|
||||
|
||||
def test_get_classification_value(self):
|
||||
iface = Interface(lambda text: text, ["textbox"], ["label"])
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, ["cat"], ["test"])
|
||||
self.assertEquals(diff, 1)
|
||||
diff = gradio.interpretation.get_regression_or_classification_value(iface, ["test"], ["test"])
|
||||
self.assertEquals(diff, 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user