added shap interpretation for text, image, and audio
Commit 9badae0570 (parent 4c2efc54a8)
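The hunks below add a SHAP-based interpretation mode alongside the existing "default" one. A minimal usage sketch of the new options, assuming the API added in this commit (`interpretation="shap"` and `num_shap`); the `sentiment_classifier` function below is a placeholder model, not part of the commit:

```python
import gradio as gr

def sentiment_classifier(text):
    # placeholder model: returns a label -> confidence dict for a Label output
    positive = min(1.0, text.lower().count("good") / 3)
    return {"positive": positive, "negative": 1 - positive}

iface = gr.Interface(
    fn=sentiment_classifier,
    inputs=gr.inputs.Textbox(),
    outputs=gr.outputs.Label(),
    interpretation="shap",   # new option alongside "default"
    num_shap=2.0,            # multiplier on the number of SHAP samples per token
)
iface.launch()
```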
@@ -37,7 +37,7 @@ iface = gr.Interface(
         )
     ],
     "number",
-    interpretation="default",
+    # interpretation="default",  # Removed interpretation for dataframes
     examples=[
         [10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]],
         [80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]],
@@ -1,4 +1,4 @@
-Metadata-Version: 1.0
+Metadata-Version: 2.1
 Name: gradio
 Version: 2.1.1
 Summary: Python library for easily interacting with trained machine learning models
@@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
 Author: Abubakar Abid
 Author-email: a12d@stanford.edu
 License: Apache License 2.0
-Description: UNKNOWN
 Keywords: machine learning,visualization,reproducibility
 Platform: UNKNOWN
+License-File: LICENSE
+
+UNKNOWN
+
@@ -1,3 +1,4 @@
+LICENSE
 MANIFEST.in
 README.md
 setup.py
@@ -26,32 +27,6 @@ gradio.egg-info/requires.txt
 gradio.egg-info/top_level.txt
 gradio/frontend/asset-manifest.json
 gradio/frontend/index.html
-gradio/frontend/static/bundle.css
-gradio/frontend/static/bundle.css.map
-gradio/frontend/static/bundle.js
-gradio/frontend/static/bundle.js.LICENSE.txt
-gradio/frontend/static/bundle.js.map
-gradio/frontend/static/css/main.0094f11b.css
-gradio/frontend/static/css/main.0094f11b.css.map
-gradio/frontend/static/css/main.20be28ac.css
-gradio/frontend/static/css/main.20be28ac.css.map
-gradio/frontend/static/css/main.2b64a968.css
-gradio/frontend/static/css/main.2b64a968.css.map
-gradio/frontend/static/css/main.380e3222.css
-gradio/frontend/static/css/main.380e3222.css.map
-gradio/frontend/static/css/main.4aea80f8.css
-gradio/frontend/static/css/main.4aea80f8.css.map
-gradio/frontend/static/css/main.4f157d97.css
-gradio/frontend/static/css/main.4f157d97.css.map
-gradio/frontend/static/css/main.5c663906.css
-gradio/frontend/static/css/main.5c663906.css.map
-gradio/frontend/static/css/main.99922310.css
-gradio/frontend/static/css/main.99922310.css.map
-gradio/frontend/static/css/main.acb02c85.css
-gradio/frontend/static/css/main.acb02c85.css.map
-gradio/frontend/static/css/main.cbbf8898.css
-gradio/frontend/static/css/main.cbbf8898.css.map
-gradio/frontend/static/media/logo_loading.e93acd82.jpg
 test/test_demos.py
 test/test_inputs.py
 test/test_interfaces.py
gradio/inputs.py (229 changed lines)
@@ -29,7 +29,7 @@ class InputComponent(Component):
     Input Component. All input components subclass this.
     """
     def __init__(self, label, requires_permissions=False):
-        self.interpret()
+        self.set_interpret_parameters()
        super().__init__(label, requires_permissions)

    def preprocess(self, x):
@@ -44,7 +44,7 @@ class InputComponent(Component):
        """
        return x

-    def interpret(self):
+    def set_interpret_parameters(self):
        '''
        Set any parameters for interpretation.
        '''
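These hunks rename `interpret()` to `set_interpret_parameters()`, and the hunks that follow add an `interpret_by_tokens` flag plus `tokenize()` / `get_masked_inputs()` hooks. A minimal sketch of what a token-based input component is expected to provide under this contract; the class name and its whitespace token scheme are illustrative, not part of the commit:

```python
import numpy as np
from gradio.inputs import InputComponent

class MyTokenizedInput(InputComponent):  # hypothetical component, for illustration only
    def __init__(self, label):
        self.interpret_by_tokens = True   # tells Interface to use tokenize()/get_masked_inputs()
        super().__init__(label)

    def set_interpret_parameters(self, separator=" "):
        # store any knobs the interpretation needs; return self so calls can be chained
        self.separator = separator
        return self

    def tokenize(self, x):
        # return (tokens, leave-one-out versions, masks)
        tokens = x.split(self.separator)
        leave_one_out = [self.separator.join(tokens[:i] + tokens[i + 1:])
                         for i in range(len(tokens))]
        return tokens, leave_one_out, None

    def get_masked_inputs(self, tokens, binary_mask_matrix):
        # rebuild one input per binary mask row (1 keeps a token, 0 drops it)
        return [self.separator.join(np.array(tokens)[np.array(row, dtype=bool)])
                for row in binary_mask_matrix]

    def get_interpretation_scores(self, x, neighbors, scores, masks=None, tokens=None):
        # map per-token scores back onto the raw input for the UI
        return list(zip(tokens, scores))
```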
@@ -115,6 +115,7 @@ class Textbox(InputComponent):
            }[type]
        else:
            self.test_input = default
+        self.interpret_by_tokens = True
        super().__init__(label)

    def get_template_context(self):
@@ -147,7 +148,7 @@ class Textbox(InputComponent):
        """
        return x

-    def interpret(self, separator=" ", replacement=None):
+    def set_interpret_parameters(self, separator=" ", replacement=None):
        """
        Calculates interpretation score of characters in input by splitting input into tokens, then using a "leave one out" method to calculate the score of each token by removing each token and measuring the delta of the output value.
        Parameters:
@@ -157,8 +158,11 @@ class Textbox(InputComponent):
        self.interpretation_separator = separator
        self.interpretation_replacement = replacement
        return self

-    def get_interpretation_neighbors(self, x):
+    def tokenize(self, x):
+        """
+        Tokenizes an input string by dividing into "words" delimited by self.interpretation_separator
+        """
        tokens = x.split(self.interpretation_separator)
        leave_one_out_strings = []
        for index in range(len(tokens)):
@@ -168,9 +172,19 @@ class Textbox(InputComponent):
            else:
                leave_one_out_set[index] = self.interpretation_replacement
            leave_one_out_strings.append(self.interpretation_separator.join(leave_one_out_set))
-        return leave_one_out_strings, {"tokens": tokens}, True
+        return tokens, leave_one_out_strings, None
+
+    def get_masked_inputs(self, tokens, binary_mask_matrix):
+        """
+        Constructs partially-masked sentences for SHAP interpretation
+        """
+        masked_inputs = []
+        for binary_mask_vector in binary_mask_matrix:
+            masked_input = np.array(tokens)[np.array(binary_mask_vector, dtype=bool)]
+            masked_inputs.append(self.interpretation_separator.join(masked_input))
+        return masked_inputs

-    def get_interpretation_scores(self, x, neighbors, scores, tokens):
+    def get_interpretation_scores(self, x, neighbors, scores, tokens, masks=None):
        """
        Returns:
        (List[Tuple[str, float]]): Each tuple set represents a set of characters and their corresponding interpretation score.
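To see what these Textbox hooks produce, here is a small standalone sketch that mirrors the tokenize / get_masked_inputs logic on a plain string; it re-implements the logic with numpy rather than importing gradio, so everything outside the two list comprehensions is illustrative:

```python
import numpy as np

separator = " "
text = "this movie was not good"
tokens = text.split(separator)  # ['this', 'movie', 'was', 'not', 'good']

# leave-one-out strings, as tokenize() builds for the "default" interpretation
leave_one_out = [separator.join(tokens[:i] + tokens[i + 1:]) for i in range(len(tokens))]
# e.g. 'movie was not good', 'this was not good', ...

# partially-masked sentences, as get_masked_inputs() builds for SHAP
binary_mask_matrix = [[1, 1, 0, 0, 1],
                      [0, 1, 1, 1, 0]]
masked = [separator.join(np.array(tokens)[np.array(row, dtype=bool)])
          for row in binary_mask_matrix]
print(masked)  # ['this movie good', 'movie was not']
```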
@@ -216,6 +230,7 @@ class Number(InputComponent):
        '''
        self.default = default
        self.test_input = default if default is not None else 1
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -244,7 +259,7 @@ class Number(InputComponent):
        """
        return x

-    def interpret(self, steps=3, delta=1, delta_type="percent"):
+    def set_interpret_parameters(self, steps=3, delta=1, delta_type="percent"):
        """
        Calculates interpretation scores of numeric values close to the input number.
        Parameters:
@@ -266,7 +281,7 @@ class Number(InputComponent):
            delta = self.interpretation_delta
        negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
        positives = (x + np.arange(1, self.interpretation_steps+1) * delta).tolist()
-        return negatives + positives, {}, False
+        return negatives + positives, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
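Number keeps the neighbor-based ("default") path: get_interpretation_neighbors returns evenly spaced values around the input, now without the third interpret_by_removal flag, whose role moved to the interpret_by_tokens attribute. A quick numeric check of the neighbor computation above, assuming steps=3 and delta=1:

```python
import numpy as np

x, steps, delta = 10, 3, 1
negatives = (x + np.arange(-steps, 0) * delta).tolist()    # [7, 8, 9]
positives = (x + np.arange(1, steps + 1) * delta).tolist() # [11, 12, 13]
neighbors = negatives + positives
print(neighbors)  # the six perturbed inputs the Interface re-runs the model on
```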
@@ -305,6 +320,7 @@ class Slider(InputComponent):
        self.step = step
        self.default = minimum if default is None else default
        self.test_input = self.default
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -329,7 +345,7 @@ class Slider(InputComponent):
        """
        return x

-    def interpret(self, steps=8):
+    def set_interpret_parameters(self, steps=8):
        """
        Calculates interpretation scores of numeric values ranging between the minimum and maximum values of the slider.
        Parameters:
@@ -339,7 +355,7 @@ class Slider(InputComponent):
        return self

    def get_interpretation_neighbors(self, x):
-        return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}, False
+        return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -367,6 +383,7 @@ class Checkbox(InputComponent):
        """
        self.test_input = True
        self.default = default
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -388,14 +405,14 @@ class Checkbox(InputComponent):
        """
        return x

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of the input by comparing the output against the output when the input is the inverse boolean value of x.
        """
        return self

    def get_interpretation_neighbors(self, x):
-        return [not x], {}, False
+        return [not x], {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -430,6 +447,7 @@ class CheckboxGroup(InputComponent):
        self.default = default
        self.type = type
        self.test_input = self.choices
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -447,7 +465,7 @@ class CheckboxGroup(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of each choice in the input by comparing the output against the outputs when each choice in the input is independently either removed or added.
        """
@@ -462,7 +480,7 @@ class CheckboxGroup(InputComponent):
            else:
                leave_one_out_set.append(choice)
            leave_one_out_sets.append(leave_one_out_set)
-        return leave_one_out_sets, {}, False
+        return leave_one_out_sets, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -514,6 +532,7 @@ class Radio(InputComponent):
        self.type = type
        self.test_input = self.choices[0]
        self.default = default if default is not None else self.choices[0]
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -531,7 +550,7 @@ class Radio(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
        """
@@ -540,7 +559,7 @@ class Radio(InputComponent):
    def get_interpretation_neighbors(self, x):
        choices = list(self.choices)
        choices.remove(x)
-        return choices, {}, False
+        return choices, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -577,6 +596,7 @@ class Dropdown(InputComponent):
        self.type = type
        self.test_input = self.choices[0]
        self.default = default if default is not None else self.choices[0]
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -594,7 +614,7 @@ class Dropdown(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
        """
@@ -603,7 +623,7 @@ class Dropdown(InputComponent):
    def get_interpretation_neighbors(self, x):
        choices = list(self.choices)
        choices.remove(x)
-        return choices, {}, False
+        return choices, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -647,6 +667,7 @@ class Image(InputComponent):
        self.type = type
        self.invert_colors = invert_colors
        self.test_input = test_data.BASE64_IMAGE
+        self.interpret_by_tokens = True
        super().__init__(label, requires_permissions)

    @classmethod
@@ -690,7 +711,7 @@ class Image(InputComponent):
    def preprocess_example(self, x):
        return processing_utils.encode_file_to_base64(x)

-    def interpret(self, segments=16):
+    def set_interpret_parameters(self, segments=16):
        """
        Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
        Parameters:
@@ -699,29 +720,60 @@ class Image(InputComponent):
        self.interpretation_segments = segments
        return self

-    def get_interpretation_neighbors(self, x):
+    def _segment_by_slic(self, x):
+        """
+        Helper method that segments an image into superpixels using slic.
+        Parameters:
+        x: base64 representation of an image
+        """
        x = processing_utils.decode_base64_to_image(x)
        if self.shape is not None:
            x = processing_utils.resize_and_crop(x, self.shape)
-        image = np.array(x)
+        resized_and_cropped_image = np.array(x)
        try:
            from skimage.segmentation import slic
-        except ImportError:
-            print("Error: running default interpretation for images requires scikit-image, please install it first.")
-            return
-        segments_slic = slic(image, self.interpretation_segments, compactness=10, sigma=1)
-        leave_one_out_tokens, masks = [], []
-        replace_color = np.mean(image, axis=(0, 1))
-        for (i, segVal) in enumerate(np.unique(segments_slic)):
-            mask = segments_slic == segVal
-            white_screen = np.copy(image)
-            white_screen[segments_slic == segVal] = replace_color
-            leave_one_out_tokens.append(
-                processing_utils.encode_array_to_base64(white_screen))
-            masks.append(mask)
-        return leave_one_out_tokens, {"masks": masks}, True
+        except (ImportError, ModuleNotFoundError):
+            raise ValueError("Error: running this interpretation for images requires scikit-image, please install it first.")
+        segments_slic = slic(
+            resized_and_cropped_image, self.interpretation_segments, compactness=10,
+            sigma=1, start_label=0)
+        return segments_slic, resized_and_cropped_image

-    def get_interpretation_scores(self, x, neighbors, scores, masks):
+    def tokenize(self, x):
+        """
+        Segments image into tokens, masks, and leave-one-out-tokens
+        Parameters:
+        x: base64 representation of an image
+        Returns:
+        tokens: list of tokens, used by the get_masked_input() method
+        leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
+        masks: list of masks, used by the get_interpretation_neighbors() method
+        """
+        segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
+        tokens, masks, leave_one_out_tokens = [], [], []
+        replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
+        for (i, segment_value) in enumerate(np.unique(segments_slic)):
+            mask = (segments_slic == segment_value)
+            image_screen = np.copy(resized_and_cropped_image)
+            image_screen[segments_slic == segment_value] = replace_color
+            leave_one_out_tokens.append(
+                processing_utils.encode_array_to_base64(image_screen))
+            token = np.copy(resized_and_cropped_image)
+            token[segments_slic != segment_value] = 0
+            tokens.append(token)
+            masks.append(mask)
+        return tokens, leave_one_out_tokens, masks
+
+    def get_masked_inputs(self, tokens, binary_mask_matrix):
+        masked_inputs = []
+        for binary_mask_vector in binary_mask_matrix:
+            masked_input = np.zeros_like(tokens[0], dtype=int)
+            for token, b in zip(tokens, binary_mask_vector):
+                masked_input = masked_input + token*int(b)
+            masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
+        return masked_inputs
+
+    def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
        """
        Returns:
        (List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
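The Image component now segments the picture into superpixels once (`_segment_by_slic`) and derives both leave-one-out images and additive "tokens" from the segments. A standalone sketch of the same idea with scikit-image on a plain numpy array; the random image, segment count, and mask vector are arbitrary placeholders:

```python
import numpy as np
from skimage.segmentation import slic

image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)  # stand-in image
segments = slic(image, n_segments=16, compactness=10, sigma=1, start_label=0)
replace_color = np.mean(image, axis=(0, 1))

tokens, leave_one_out, masks = [], [], []
for segment_value in np.unique(segments):
    mask = segments == segment_value
    # leave-one-out: paint this superpixel with the mean color
    removed = np.copy(image)
    removed[mask] = replace_color
    leave_one_out.append(removed)
    # token: keep only this superpixel, zero everything else
    token = np.copy(image)
    token[~mask] = 0
    tokens.append(token)
    masks.append(mask)

# a SHAP-masked input is just the sum of the kept tokens
binary_mask_vector = np.random.randint(0, 2, len(tokens))
masked_input = np.zeros_like(image, dtype=int)
for t, b in zip(tokens, binary_mask_vector):
    masked_input += t.astype(int) * int(b)
```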
@@ -825,6 +877,7 @@ class Audio(InputComponent):
        requires_permissions = source == "microphone"
        self.type = type
        self.test_input = test_data.BASE64_AUDIO
+        self.interpret_by_tokens = True
        super().__init__(label, requires_permissions)

    def get_template_context(self):
@@ -855,7 +908,7 @@ class Audio(InputComponent):
    def preprocess_example(self, x):
        return processing_utils.encode_file_to_base64(x, type="audio")

-    def interpret(self, segments=8):
+    def set_interpret_parameters(self, segments=8):
        """
        Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value.
        Parameters:
@@ -864,30 +917,66 @@ class Audio(InputComponent):
        self.interpretation_segments = segments
        return self

-    def get_interpretation_neighbors(self, x):
+    def tokenize(self, x):
        file_obj = processing_utils.decode_base64_to_file(x)
        x = scipy.io.wavfile.read(file_obj.name)
        sample_rate, data = x
        leave_one_out_sets = []
+        tokens = []
+        masks = []
        duration = data.shape[0]
        boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist()
        boundaries = [round(boundary) for boundary in boundaries]
        for index in range(len(boundaries) - 1):
-            leave_one_out_data = np.copy(data)
            start, stop = boundaries[index], boundaries[index + 1]
+            masks.append((start, stop))
+            # Handle the leave one outs
+            leave_one_out_data = np.copy(data)
            leave_one_out_data[start:stop] = 0
            file = tempfile.NamedTemporaryFile(delete=False)
            scipy.io.wavfile.write(file, sample_rate, leave_one_out_data)
            out_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
            leave_one_out_sets.append(out_data)
-        return leave_one_out_sets, {}, True
+            # Handle the tokens
+            token = np.copy(data)
+            token[0:start] = 0
+            token[stop:] = 0
+            file = tempfile.NamedTemporaryFile(delete=False)
+            scipy.io.wavfile.write(file, sample_rate, token)
+            token_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
+            tokens.append(token_data)
+        return tokens, leave_one_out_sets, masks

-    def get_interpretation_scores(self, x, neighbors, scores):
+    def get_masked_inputs(self, tokens, binary_mask_matrix):
+        # create a "zero input" vector and get sample rate
+        x = tokens[0]
+        file_obj = processing_utils.decode_base64_to_file(x)
+        sample_rate, data = scipy.io.wavfile.read(file_obj.name)
+        zero_input = np.zeros_like(data, dtype=int)
+        # decode all of the tokens
+        token_data = []
+        for token in tokens:
+            file_obj = processing_utils.decode_base64_to_file(token)
+            _, data = scipy.io.wavfile.read(file_obj.name)
+            token_data.append(data)
+        # construct the masked version
+        masked_inputs = []
+        for binary_mask_vector in binary_mask_matrix:
+            masked_input = np.copy(zero_input)
+            for t, b in zip(token_data, binary_mask_vector):
+                masked_input = masked_input + t*int(b)
+            file = tempfile.NamedTemporaryFile(delete=False)
+            scipy.io.wavfile.write(file, sample_rate, masked_input)
+            masked_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
+            masked_inputs.append(masked_data)
+        return masked_inputs
+
+    def get_interpretation_scores(self, x, neighbors, scores, masks=None, tokens=None):
        """
        Returns:
        (List[float]): Each value represents the interpretation score corresponding to an evenly spaced subsection of audio.
        """
-        return scores
+        return list(scores)

    def embed(self, x):
        """
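The audio path follows the same pattern on the raw waveform: each of the `segments` time slices becomes a token, and SHAP masks are applied by summing the kept slices. A numpy-only sketch of the slicing logic, with a synthetic waveform and mask vector as placeholders:

```python
import numpy as np

sample_rate = 8000
data = np.random.randint(-2**15, 2**15, size=2 * sample_rate, dtype=np.int16)  # 2 s of fake audio
segments = 8

boundaries = np.linspace(0, data.shape[0], segments + 1).astype(int)
tokens, leave_one_out, masks = [], [], []
for start, stop in zip(boundaries[:-1], boundaries[1:]):
    masks.append((start, stop))
    removed = np.copy(data)
    removed[start:stop] = 0               # leave-one-out version: silence this slice
    leave_one_out.append(removed)
    token = np.zeros_like(data)
    token[start:stop] = data[start:stop]  # token: only this slice is audible
    tokens.append(token)

# a SHAP-masked input is the sum of the kept tokens
binary_mask_vector = [1, 0, 1, 1, 0, 0, 1, 0]
masked_input = sum(t * int(b) for t, b in zip(tokens, binary_mask_vector))
```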
@@ -1049,35 +1138,35 @@ class Dataframe(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'pandas', 'numpy', 'array'.")

-    def interpret(self):
-        """
-        Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
-        """
-        return self
+    # def set_interpret_parameters(self):
+    #     """
+    #     Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
+    #     """
+    #     return self

-    def get_interpretation_neighbors(self, x):
-        x = pd.DataFrame(x)
-        leave_one_out_sets = []
-        shape = x.shape
-        for i in range(shape[0]):
-            for j in range(shape[1]):
-                scalar = x.iloc[i, j]
-                leave_one_out_df = x.copy()
-                if is_bool_dtype(scalar):
-                    leave_one_out_df.iloc[i, j] = not scalar
-                elif is_numeric_dtype(scalar):
-                    leave_one_out_df.iloc[i, j] = 0
-                elif is_string_dtype(scalar):
-                    leave_one_out_df.iloc[i, j] = ""
-                leave_one_out_sets.append(leave_one_out_df.values.tolist())
-        return leave_one_out_sets, {"shape": x.shape}, True
+    # def get_interpretation_neighbors(self, x):
+    #     x = pd.DataFrame(x)
+    #     leave_one_out_sets = []
+    #     shape = x.shape
+    #     for i in range(shape[0]):
+    #         for j in range(shape[1]):
+    #             scalar = x.iloc[i, j]
+    #             leave_one_out_df = x.copy()
+    #             if is_bool_dtype(scalar):
+    #                 leave_one_out_df.iloc[i, j] = not scalar
+    #             elif is_numeric_dtype(scalar):
+    #                 leave_one_out_df.iloc[i, j] = 0
+    #             elif is_string_dtype(scalar):
+    #                 leave_one_out_df.iloc[i, j] = ""
+    #             leave_one_out_sets.append(leave_one_out_df.values.tolist())
+    #     return leave_one_out_sets, {"shape": x.shape}

-    def get_interpretation_scores(self, x, neighbors, scores, shape):
-        """
-        Returns:
-        (List[List[float]]): A 2D array where each value corrseponds to the interpretation score of each cell.
-        """
-        return np.array(scores).reshape((shape)).tolist()
+    # def get_interpretation_scores(self, x, neighbors, scores, shape):
+    #     """
+    #     Returns:
+    #     (List[List[float]]): A 2D array where each value corrseponds to the interpretation score of each cell.
+    #     """
+    #     return np.array(scores).reshape((shape)).tolist()

    def embed(self, x):
        raise NotImplementedError("DataFrame doesn't currently support embeddings")
@@ -7,7 +7,7 @@ import gradio
from gradio.inputs import InputComponent, get_input_instance
from gradio.outputs import OutputComponent, get_output_instance
from gradio import networking, strings, utils
-from gradio.interpretation import quantify_difference_in_label
+from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.external import load_interface
from gradio import encryptor
import pkg_resources
@@ -66,8 +66,8 @@ class Interface:
    def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None,
                 examples_per_page=10, live=False,
                 layout="horizontal", show_input=True, show_output=True,
-                 capture_session=False, interpretation=None, theme=None, repeat_outputs_per_model=True,
+                 capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
                 title=None, description=None, article=None, thumbnail=None,
                 css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
                 allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
                 show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True):
@@ -84,6 +84,7 @@ class Interface:
        layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
        capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
        interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
+        num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results.
        title (str): a title for the interface; if provided, appears above the input and output components.
        description (str): a description for the interface; if provided, appears above the input and output components.
        article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content.
@@ -145,6 +146,7 @@ class Interface:
            self.examples = examples
        else:
            raise ValueError("Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.")
+        self.num_shap = num_shap
        self.examples_per_page = examples_per_page
        self.server_port = server_port
        self.simple_server = None
@@ -334,7 +336,7 @@ class Interface:
        interpretation for a certain set of UI component types, as well as the custom interpretation case.
        :param raw_input: a list of raw inputs to apply the interpretation(s) on.
        """
-        if self.interpretation == "default":
+        if self.interpretation.lower() == "default":
            processed_input = [input_component.preprocess(raw_input[i])
                               for i, input_component in enumerate(self.input_components)]
            original_output = self.run_prediction(processed_input)
@@ -342,26 +344,75 @@ class Interface:
            for i, x in enumerate(raw_input):
                input_component = self.input_components[i]
                neighbor_raw_input = list(raw_input)
-                neighbor_values, interpret_kwargs, interpret_by_removal = input_component.get_interpretation_neighbors(x)
-                interface_scores = []
-                alternative_output = []
-                for neighbor_input in neighbor_values:
-                    neighbor_raw_input[i] = neighbor_input
-                    processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
-                                                for i, input_component in enumerate(self.input_components)]
-                    neighbor_output = self.run_prediction(processed_neighbor_input)
-                    processed_neighbor_output = [output_component.postprocess(
-                        neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
+                if input_component.interpret_by_tokens:
+                    tokens, neighbor_values, masks = input_component.tokenize(x)
+                    interface_scores = []
+                    alternative_output = []
+                    for neighbor_input in neighbor_values:
+                        neighbor_raw_input[i] = neighbor_input
+                        processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
+                                                    for i, input_component in enumerate(self.input_components)]
+                        neighbor_output = self.run_prediction(processed_neighbor_input)
+                        processed_neighbor_output = [output_component.postprocess(
+                            neighbor_output[i]) for i, output_component in enumerate(self.output_components)]

                        alternative_output.append(processed_neighbor_output)
                        interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
                    alternative_outputs.append(alternative_output)
-                if not interpret_by_removal:
+                    scores.append(
+                        input_component.get_interpretation_scores(
+                            raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
+                else:
+                    neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
+                    interface_scores = []
+                    alternative_output = []
+                    for neighbor_input in neighbor_values:
+                        neighbor_raw_input[i] = neighbor_input
+                        processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
+                                                    for i, input_component in enumerate(self.input_components)]
+                        neighbor_output = self.run_prediction(processed_neighbor_input)
+                        processed_neighbor_output = [output_component.postprocess(
+                            neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
+
+                        alternative_output.append(processed_neighbor_output)
+                        interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
+                    alternative_outputs.append(alternative_output)
                    interface_scores = [-score for score in interface_scores]
                    scores.append(
                        input_component.get_interpretation_scores(
                            raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
            return scores, alternative_outputs
+        elif self.interpretation.lower() == "shap":
+            scores = []
+            try:
+                import shap
+            except (ImportError, ModuleNotFoundError):
+                raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
+
+            processed_input = [input_component.preprocess(raw_input[i])
+                               for i, input_component in enumerate(self.input_components)]
+            original_output = self.run_prediction(processed_input)
+
+            for i, x in enumerate(raw_input):  # iterate over reach interface
+                input_component = self.input_components[i]
+                tokens, _, masks = input_component.tokenize(x)
+
+                def get_masked_prediction(binary_mask):  # construct a masked version of the input
+                    masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
+                    preds = []
+                    for masked_x in masked_xs:
+                        processed_masked_input = copy.deepcopy(processed_input)
+                        processed_masked_input[i] = input_component.preprocess(masked_x)
+                        new_output = self.run_prediction(processed_masked_input)
+                        pred = get_regression_or_classification_value(self, original_output, new_output)
+                        preds.append(pred)
+                    return np.array(preds)
+
+                num_total_segments = len(tokens)
+                explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
+                shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
+                scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
+            return scores, []
        else:
            processed_input = [input_component.preprocess(raw_input[i])
                               for i, input_component in enumerate(self.input_components)]
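The SHAP branch wraps the model in a function of binary masks and hands it to shap.KernelExplainer, with the all-zeros vector as the background and the all-ones vector as the instance to explain. A stripped-down sketch of that pattern outside of gradio; the toy scoring function below stands in for run_prediction plus get_regression_or_classification_value:

```python
import numpy as np
import shap

tokens = ["this", "movie", "was", "not", "good"]

def get_masked_prediction(binary_mask_matrix):
    # one score per mask row: here a toy score based on which tokens survive the mask
    preds = []
    for row in binary_mask_matrix:
        kept = [t for t, b in zip(tokens, row) if int(b)]
        preds.append(float("good" in kept) - float("not" in kept))
    return np.array(preds)

num_segments = len(tokens)
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_segments)))
shap_values = explainer.shap_values(np.ones((1, num_segments)),
                                    nsamples=int(2.0 * num_segments),  # num_shap * number of tokens
                                    silent=True)
print(shap_values[0])  # one attribution per token
```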
@@ -1,4 +1,5 @@
from gradio.outputs import Label, Textbox
+import math

def diff(original, perturbed):
    try: # try computing numerical difference
@@ -24,7 +25,33 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
        else:
            score = diff(original_label, perturbed_label)
        return score

    elif type(output_component) == Textbox:
        score = diff(post_original_output, post_perturbed_output)
        return score
+
+    else:
+        raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
+
+def get_regression_or_classification_value(interface, original_output, perturbed_output):
+    """Used to combine regression/classification for Shap interpretation method."""
+    output_component = interface.output_components[0]
+    post_original_output = output_component.postprocess(original_output[0])
+    post_perturbed_output = output_component.postprocess(perturbed_output[0])
+
+    if type(output_component) == Label:
+        original_label = post_original_output["label"]
+        perturbed_label = post_perturbed_output["label"]
+
+        # Handle different return types of Label interface
+        if "confidences" in post_original_output:
+            if math.isnan(perturbed_output[0][original_label]):
+                return 0
+            return perturbed_output[0][original_label]
+        else:
+            score = diff(perturbed_label, original_label)  # Intentionall inverted order of arguments.
+            return score
+
+    else:
+        raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
+
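For Label outputs with confidences, the new helper reads off the perturbed confidence assigned to the originally predicted label, and that scalar is what the SHAP explainer attributes across tokens. A tiny simplified illustration with hand-made prediction dictionaries standing in for the model's Label output (the helper name and label-selection step here are simplifications, not the committed code):

```python
import math

def label_score(original_pred, perturbed_pred):
    # simplified analogue of get_regression_or_classification_value for a Label output
    # whose prediction is a dict mapping label -> confidence
    original_label = max(original_pred, key=original_pred.get)
    value = perturbed_pred[original_label]
    return 0 if math.isnan(value) else value

original = {"positive": 0.9, "negative": 0.1}   # prediction on the full input
perturbed = {"positive": 0.4, "negative": 0.6}  # prediction with some tokens masked out
print(label_score(original, perturbed))  # 0.4 -> the scalar SHAP explains
```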