added shap interpretation for text, image, and audio

Abubakar Abid 2021-07-12 12:44:22 -05:00
parent 4c2efc54a8
commit 9badae0570
6 changed files with 266 additions and 121 deletions
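A hedged usage sketch (not part of the diff below): with this commit, SHAP-based interpretation can be requested by name on an Interface, and the new num_shap argument scales how many masked samples are evaluated. The sentiment_classifier function and its labels are placeholders.

import gradio as gr

def sentiment_classifier(text):
    # Stand-in for a real model; returns label -> confidence.
    return {"positive": 0.7, "negative": 0.3}

gr.Interface(
    fn=sentiment_classifier,
    inputs="textbox",
    outputs="label",
    interpretation="shap",  # new option handled in the Interface changes below
    num_shap=5.0,           # new multiplier on the number of SHAP samples
).launch()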

View File

@@ -37,7 +37,7 @@ iface = gr.Interface(
        )
    ],
    "number",
-   interpretation="default",
+   # interpretation="default", # Removed interpretation for dataframes
    examples=[
        [10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]],
        [80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]],
View File

@@ -1,4 +1,4 @@
-Metadata-Version: 1.0
+Metadata-Version: 2.1
Name: gradio
Version: 2.1.1
Summary: Python library for easily interacting with trained machine learning models
@@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid
Author-email: a12d@stanford.edu
License: Apache License 2.0
-Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
+License-File: LICENSE
+
+UNKNOWN
View File

@@ -1,3 +1,4 @@
+LICENSE
MANIFEST.in
README.md
setup.py
@@ -26,32 +27,6 @@ gradio.egg-info/requires.txt
gradio.egg-info/top_level.txt
gradio/frontend/asset-manifest.json
gradio/frontend/index.html
-gradio/frontend/static/bundle.css
-gradio/frontend/static/bundle.css.map
-gradio/frontend/static/bundle.js
-gradio/frontend/static/bundle.js.LICENSE.txt
-gradio/frontend/static/bundle.js.map
-gradio/frontend/static/css/main.0094f11b.css
-gradio/frontend/static/css/main.0094f11b.css.map
-gradio/frontend/static/css/main.20be28ac.css
-gradio/frontend/static/css/main.20be28ac.css.map
-gradio/frontend/static/css/main.2b64a968.css
-gradio/frontend/static/css/main.2b64a968.css.map
-gradio/frontend/static/css/main.380e3222.css
-gradio/frontend/static/css/main.380e3222.css.map
-gradio/frontend/static/css/main.4aea80f8.css
-gradio/frontend/static/css/main.4aea80f8.css.map
-gradio/frontend/static/css/main.4f157d97.css
-gradio/frontend/static/css/main.4f157d97.css.map
-gradio/frontend/static/css/main.5c663906.css
-gradio/frontend/static/css/main.5c663906.css.map
-gradio/frontend/static/css/main.99922310.css
-gradio/frontend/static/css/main.99922310.css.map
-gradio/frontend/static/css/main.acb02c85.css
-gradio/frontend/static/css/main.acb02c85.css.map
-gradio/frontend/static/css/main.cbbf8898.css
-gradio/frontend/static/css/main.cbbf8898.css.map
-gradio/frontend/static/media/logo_loading.e93acd82.jpg
test/test_demos.py
test/test_inputs.py
test/test_interfaces.py
View File

@@ -29,7 +29,7 @@ class InputComponent(Component):
    Input Component. All input components subclass this.
    """
    def __init__(self, label, requires_permissions=False):
-        self.interpret()
+        self.set_interpret_parameters()
        super().__init__(label, requires_permissions)

    def preprocess(self, x):
@@ -44,7 +44,7 @@ class InputComponent(Component):
        """
        return x

-    def interpret(self):
+    def set_interpret_parameters(self):
        '''
        Set any parameters for interpretation.
        '''
@@ -115,6 +115,7 @@ class Textbox(InputComponent):
            }[type]
        else:
            self.test_input = default
+        self.interpret_by_tokens = True
        super().__init__(label)

    def get_template_context(self):
@@ -147,7 +148,7 @@ class Textbox(InputComponent):
        """
        return x

-    def interpret(self, separator=" ", replacement=None):
+    def set_interpret_parameters(self, separator=" ", replacement=None):
        """
        Calculates interpretation score of characters in input by splitting input into tokens, then using a "leave one out" method to calculate the score of each token by removing each token and measuring the delta of the output value.
        Parameters:
@@ -157,8 +158,11 @@ class Textbox(InputComponent):
        self.interpretation_separator = separator
        self.interpretation_replacement = replacement
        return self

-    def get_interpretation_neighbors(self, x):
+    def tokenize(self, x):
+        """
+        Tokenizes an input string by dividing into "words" delimited by self.interpretation_separator
+        """
        tokens = x.split(self.interpretation_separator)
        leave_one_out_strings = []
        for index in range(len(tokens)):
@@ -168,9 +172,19 @@ class Textbox(InputComponent):
            else:
                leave_one_out_set[index] = self.interpretation_replacement
            leave_one_out_strings.append(self.interpretation_separator.join(leave_one_out_set))
-        return leave_one_out_strings, {"tokens": tokens}, True
+        return tokens, leave_one_out_strings, None
+
+    def get_masked_inputs(self, tokens, binary_mask_matrix):
+        """
+        Constructs partially-masked sentences for SHAP interpretation
+        """
+        masked_inputs = []
+        for binary_mask_vector in binary_mask_matrix:
+            masked_input = np.array(tokens)[np.array(binary_mask_vector, dtype=bool)]
+            masked_inputs.append(self.interpretation_separator.join(masked_input))
+        return masked_inputs

-    def get_interpretation_scores(self, x, neighbors, scores, tokens):
+    def get_interpretation_scores(self, x, neighbors, scores, tokens, masks=None):
        """
        Returns:
        (List[Tuple[str, float]]): Each tuple set represents a set of characters and their corresponding interpretation score.
@@ -216,6 +230,7 @@ class Number(InputComponent):
        '''
        self.default = default
        self.test_input = default if default is not None else 1
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -244,7 +259,7 @@ class Number(InputComponent):
        """
        return x

-    def interpret(self, steps=3, delta=1, delta_type="percent"):
+    def set_interpret_parameters(self, steps=3, delta=1, delta_type="percent"):
        """
        Calculates interpretation scores of numeric values close to the input number.
        Parameters:
@@ -266,7 +281,7 @@ class Number(InputComponent):
        delta = self.interpretation_delta
        negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
        positives = (x + np.arange(1, self.interpretation_steps+1) * delta).tolist()
-        return negatives + positives, {}, False
+        return negatives + positives, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -305,6 +320,7 @@ class Slider(InputComponent):
        self.step = step
        self.default = minimum if default is None else default
        self.test_input = self.default
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -329,7 +345,7 @@ class Slider(InputComponent):
        """
        return x

-    def interpret(self, steps=8):
+    def set_interpret_parameters(self, steps=8):
        """
        Calculates interpretation scores of numeric values ranging between the minimum and maximum values of the slider.
        Parameters:
@@ -339,7 +355,7 @@ class Slider(InputComponent):
        return self

    def get_interpretation_neighbors(self, x):
-        return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}, False
+        return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -367,6 +383,7 @@ class Checkbox(InputComponent):
        """
        self.test_input = True
        self.default = default
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -388,14 +405,14 @@ class Checkbox(InputComponent):
        """
        return x

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of the input by comparing the output against the output when the input is the inverse boolean value of x.
        """
        return self

    def get_interpretation_neighbors(self, x):
-        return [not x], {}, False
+        return [not x], {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -430,6 +447,7 @@ class CheckboxGroup(InputComponent):
        self.default = default
        self.type = type
        self.test_input = self.choices
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -447,7 +465,7 @@ class CheckboxGroup(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of each choice in the input by comparing the output against the outputs when each choice in the input is independently either removed or added.
        """
@@ -462,7 +480,7 @@ class CheckboxGroup(InputComponent):
            else:
                leave_one_out_set.append(choice)
            leave_one_out_sets.append(leave_one_out_set)
-        return leave_one_out_sets, {}, False
+        return leave_one_out_sets, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -514,6 +532,7 @@ class Radio(InputComponent):
        self.type = type
        self.test_input = self.choices[0]
        self.default = default if default is not None else self.choices[0]
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -531,7 +550,7 @@ class Radio(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
        """
@@ -540,7 +559,7 @@ class Radio(InputComponent):
    def get_interpretation_neighbors(self, x):
        choices = list(self.choices)
        choices.remove(x)
-        return choices, {}, False
+        return choices, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -577,6 +596,7 @@ class Dropdown(InputComponent):
        self.type = type
        self.test_input = self.choices[0]
        self.default = default if default is not None else self.choices[0]
+        self.interpret_by_tokens = False
        super().__init__(label)

    def get_template_context(self):
@@ -594,7 +614,7 @@ class Dropdown(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")

-    def interpret(self):
+    def set_interpret_parameters(self):
        """
        Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
        """
@@ -603,7 +623,7 @@ class Dropdown(InputComponent):
    def get_interpretation_neighbors(self, x):
        choices = list(self.choices)
        choices.remove(x)
-        return choices, {}, False
+        return choices, {}

    def get_interpretation_scores(self, x, neighbors, scores):
        """
@@ -647,6 +667,7 @@ class Image(InputComponent):
        self.type = type
        self.invert_colors = invert_colors
        self.test_input = test_data.BASE64_IMAGE
+        self.interpret_by_tokens = True
        super().__init__(label, requires_permissions)

    @classmethod
@@ -690,7 +711,7 @@ class Image(InputComponent):
    def preprocess_example(self, x):
        return processing_utils.encode_file_to_base64(x)

-    def interpret(self, segments=16):
+    def set_interpret_parameters(self, segments=16):
        """
        Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
        Parameters:
@@ -699,29 +720,60 @@ class Image(InputComponent):
        self.interpretation_segments = segments
        return self

-    def get_interpretation_neighbors(self, x):
+    def _segment_by_slic(self, x):
+        """
+        Helper method that segments an image into superpixels using slic.
+        Parameters:
+            x: base64 representation of an image
+        """
        x = processing_utils.decode_base64_to_image(x)
        if self.shape is not None:
            x = processing_utils.resize_and_crop(x, self.shape)
-        image = np.array(x)
+        resized_and_cropped_image = np.array(x)
        try:
            from skimage.segmentation import slic
-        except ImportError:
-            print("Error: running default interpretation for images requires scikit-image, please install it first.")
-            return
-        segments_slic = slic(image, self.interpretation_segments, compactness=10, sigma=1)
-        leave_one_out_tokens, masks = [], []
-        replace_color = np.mean(image, axis=(0, 1))
-        for (i, segVal) in enumerate(np.unique(segments_slic)):
-            mask = segments_slic == segVal
-            white_screen = np.copy(image)
-            white_screen[segments_slic == segVal] = replace_color
-            leave_one_out_tokens.append(
-                processing_utils.encode_array_to_base64(white_screen))
-            masks.append(mask)
-        return leave_one_out_tokens, {"masks": masks}, True
-
-    def get_interpretation_scores(self, x, neighbors, scores, masks):
+        except (ImportError, ModuleNotFoundError):
+            raise ValueError("Error: running this interpretation for images requires scikit-image, please install it first.")
+        segments_slic = slic(
+            resized_and_cropped_image, self.interpretation_segments, compactness=10,
+            sigma=1, start_label=0)
+        return segments_slic, resized_and_cropped_image
+
+    def tokenize(self, x):
+        """
+        Segments image into tokens, masks, and leave-one-out-tokens
+        Parameters:
+            x: base64 representation of an image
+        Returns:
+            tokens: list of tokens, used by the get_masked_input() method
+            leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
+            masks: list of masks, used by the get_interpretation_neighbors() method
+        """
+        segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
+        tokens, masks, leave_one_out_tokens = [], [], []
+        replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
+        for (i, segment_value) in enumerate(np.unique(segments_slic)):
+            mask = (segments_slic == segment_value)
+            image_screen = np.copy(resized_and_cropped_image)
+            image_screen[segments_slic == segment_value] = replace_color
+            leave_one_out_tokens.append(
+                processing_utils.encode_array_to_base64(image_screen))
+            token = np.copy(resized_and_cropped_image)
+            token[segments_slic != segment_value] = 0
+            tokens.append(token)
+            masks.append(mask)
+        return tokens, leave_one_out_tokens, masks
+
+    def get_masked_inputs(self, tokens, binary_mask_matrix):
+        masked_inputs = []
+        for binary_mask_vector in binary_mask_matrix:
+            masked_input = np.zeros_like(tokens[0], dtype=int)
+            for token, b in zip(tokens, binary_mask_vector):
+                masked_input = masked_input + token*int(b)
+            masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
+        return masked_inputs
+
+    def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
        """
        Returns:
        (List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
@@ -825,6 +877,7 @@ class Audio(InputComponent):
        requires_permissions = source == "microphone"
        self.type = type
        self.test_input = test_data.BASE64_AUDIO
+        self.interpret_by_tokens = True
        super().__init__(label, requires_permissions)

    def get_template_context(self):
@@ -855,7 +908,7 @@ class Audio(InputComponent):
    def preprocess_example(self, x):
        return processing_utils.encode_file_to_base64(x, type="audio")

-    def interpret(self, segments=8):
+    def set_interpret_parameters(self, segments=8):
        """
        Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value.
        Parameters:
@@ -864,30 +917,66 @@ class Audio(InputComponent):
        self.interpretation_segments = segments
        return self

-    def get_interpretation_neighbors(self, x):
+    def tokenize(self, x):
        file_obj = processing_utils.decode_base64_to_file(x)
        x = scipy.io.wavfile.read(file_obj.name)
        sample_rate, data = x
        leave_one_out_sets = []
+        tokens = []
+        masks = []
        duration = data.shape[0]
        boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist()
        boundaries = [round(boundary) for boundary in boundaries]
        for index in range(len(boundaries) - 1):
-            leave_one_out_data = np.copy(data)
            start, stop = boundaries[index], boundaries[index + 1]
+            masks.append((start, stop))
+            # Handle the leave one outs
+            leave_one_out_data = np.copy(data)
            leave_one_out_data[start:stop] = 0
            file = tempfile.NamedTemporaryFile(delete=False)
            scipy.io.wavfile.write(file, sample_rate, leave_one_out_data)
            out_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
            leave_one_out_sets.append(out_data)
-        return leave_one_out_sets, {}, True
+            # Handle the tokens
+            token = np.copy(data)
+            token[0:start] = 0
+            token[stop:] = 0
+            file = tempfile.NamedTemporaryFile(delete=False)
+            scipy.io.wavfile.write(file, sample_rate, token)
+            token_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
+            tokens.append(token_data)
+        return tokens, leave_one_out_sets, masks

-    def get_interpretation_scores(self, x, neighbors, scores):
+    def get_masked_inputs(self, tokens, binary_mask_matrix):
+        # create a "zero input" vector and get sample rate
+        x = tokens[0]
+        file_obj = processing_utils.decode_base64_to_file(x)
+        sample_rate, data = scipy.io.wavfile.read(file_obj.name)
+        zero_input = np.zeros_like(data, dtype=int)
+        # decode all of the tokens
+        token_data = []
+        for token in tokens:
+            file_obj = processing_utils.decode_base64_to_file(token)
+            _, data = scipy.io.wavfile.read(file_obj.name)
+            token_data.append(data)
+        # construct the masked version
+        masked_inputs = []
+        for binary_mask_vector in binary_mask_matrix:
+            masked_input = np.copy(zero_input)
+            for t, b in zip(token_data, binary_mask_vector):
+                masked_input = masked_input + t*int(b)
+            file = tempfile.NamedTemporaryFile(delete=False)
+            scipy.io.wavfile.write(file, sample_rate, masked_input)
+            masked_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
+            masked_inputs.append(masked_data)
+        return masked_inputs
+
+    def get_interpretation_scores(self, x, neighbors, scores, masks=None, tokens=None):
        """
        Returns:
        (List[float]): Each value represents the interpretation score corresponding to an evenly spaced subsection of audio.
        """
-        return scores
+        return list(scores)

    def embed(self, x):
        """
@@ -1049,35 +1138,35 @@ class Dataframe(InputComponent):
        else:
            raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'pandas', 'numpy', 'array'.")

-    def interpret(self):
-        """
-        Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
-        """
-        return self
-
-    def get_interpretation_neighbors(self, x):
-        x = pd.DataFrame(x)
-        leave_one_out_sets = []
-        shape = x.shape
-        for i in range(shape[0]):
-            for j in range(shape[1]):
-                scalar = x.iloc[i, j]
-                leave_one_out_df = x.copy()
-                if is_bool_dtype(scalar):
-                    leave_one_out_df.iloc[i, j] = not scalar
-                elif is_numeric_dtype(scalar):
-                    leave_one_out_df.iloc[i, j] = 0
-                elif is_string_dtype(scalar):
-                    leave_one_out_df.iloc[i, j] = ""
-                leave_one_out_sets.append(leave_one_out_df.values.tolist())
-        return leave_one_out_sets, {"shape": x.shape}, True
-
-    def get_interpretation_scores(self, x, neighbors, scores, shape):
-        """
-        Returns:
-        (List[List[float]]): A 2D array where each value corresponds to the interpretation score of each cell.
-        """
-        return np.array(scores).reshape((shape)).tolist()
+    # def set_interpret_parameters(self):
+    #     """
+    #     Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
+    #     """
+    #     return self
+
+    # def get_interpretation_neighbors(self, x):
+    #     x = pd.DataFrame(x)
+    #     leave_one_out_sets = []
+    #     shape = x.shape
+    #     for i in range(shape[0]):
+    #         for j in range(shape[1]):
+    #             scalar = x.iloc[i, j]
+    #             leave_one_out_df = x.copy()
+    #             if is_bool_dtype(scalar):
+    #                 leave_one_out_df.iloc[i, j] = not scalar
+    #             elif is_numeric_dtype(scalar):
+    #                 leave_one_out_df.iloc[i, j] = 0
+    #             elif is_string_dtype(scalar):
+    #                 leave_one_out_df.iloc[i, j] = ""
+    #             leave_one_out_sets.append(leave_one_out_df.values.tolist())
+    #     return leave_one_out_sets, {"shape": x.shape}

+    # def get_interpretation_scores(self, x, neighbors, scores, shape):
+    #     """
+    #     Returns:
+    #     (List[List[float]]): A 2D array where each value corresponds to the interpretation score of each cell.
+    #     """
+    #     return np.array(scores).reshape((shape)).tolist()

    def embed(self, x):
        raise NotImplementedError("DataFrame doesn't currently support embeddings")
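As a standalone sketch (not the gradio code itself), the following illustrates the tokenize()/get_masked_inputs() contract that the Textbox changes above implement: tokens are separator-delimited words, leave-one-out strings drop one token at a time, and each binary mask row selects which tokens survive in a reconstructed string. The sample text and masks are invented.

separator = " "
text = "this movie was great"

tokens = text.split(separator)                      # ["this", "movie", "was", "great"]
leave_one_out_strings = [
    separator.join(tokens[:i] + tokens[i + 1:])     # drop one token at a time
    for i in range(len(tokens))
]

binary_mask_matrix = [[1, 1, 0, 1], [0, 0, 1, 1]]   # one row per masked sample
masked_inputs = [
    separator.join(t for t, keep in zip(tokens, mask) if keep)
    for mask in binary_mask_matrix
]
print(leave_one_out_strings)  # ['movie was great', 'this was great', 'this movie great', 'this movie was']
print(masked_inputs)          # ['this movie great', 'was great']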
View File

@@ -7,7 +7,7 @@ import gradio
from gradio.inputs import InputComponent, get_input_instance
from gradio.outputs import OutputComponent, get_output_instance
from gradio import networking, strings, utils
-from gradio.interpretation import quantify_difference_in_label
+from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.external import load_interface
from gradio import encryptor
import pkg_resources
@@ -66,8 +66,8 @@ class Interface:
    def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None,
                 examples_per_page=10, live=False,
                 layout="horizontal", show_input=True, show_output=True,
-                 capture_session=False, interpretation=None, theme=None, repeat_outputs_per_model=True,
+                 capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
                 title=None, description=None, article=None, thumbnail=None,
                 css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
                 allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
                 show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True):
@@ -84,6 +84,7 @@ class Interface:
        layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
        capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
        interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
+        num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results.
        title (str): a title for the interface; if provided, appears above the input and output components.
        description (str): a description for the interface; if provided, appears above the input and output components.
        article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content.
@@ -145,6 +146,7 @@ class Interface:
            self.examples = examples
        else:
            raise ValueError("Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.")
+        self.num_shap = num_shap
        self.examples_per_page = examples_per_page
        self.server_port = server_port
        self.simple_server = None
@@ -334,7 +336,7 @@ class Interface:
        interpretation for a certain set of UI component types, as well as the custom interpretation case.
        :param raw_input: a list of raw inputs to apply the interpretation(s) on.
        """
-        if self.interpretation == "default":
+        if self.interpretation.lower() == "default":
            processed_input = [input_component.preprocess(raw_input[i])
                               for i, input_component in enumerate(self.input_components)]
            original_output = self.run_prediction(processed_input)
@@ -342,26 +344,75 @@ class Interface:
            for i, x in enumerate(raw_input):
                input_component = self.input_components[i]
                neighbor_raw_input = list(raw_input)
-                neighbor_values, interpret_kwargs, interpret_by_removal = input_component.get_interpretation_neighbors(x)
-                interface_scores = []
-                alternative_output = []
-                for neighbor_input in neighbor_values:
-                    neighbor_raw_input[i] = neighbor_input
-                    processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
-                                                for i, input_component in enumerate(self.input_components)]
-                    neighbor_output = self.run_prediction(processed_neighbor_input)
-                    processed_neighbor_output = [output_component.postprocess(
-                        neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
-                    alternative_output.append(processed_neighbor_output)
-                    interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
-                alternative_outputs.append(alternative_output)
-                if not interpret_by_removal:
-                    interface_scores = [-score for score in interface_scores]
-                scores.append(
-                    input_component.get_interpretation_scores(
-                        raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
+                if input_component.interpret_by_tokens:
+                    tokens, neighbor_values, masks = input_component.tokenize(x)
+                    interface_scores = []
+                    alternative_output = []
+                    for neighbor_input in neighbor_values:
+                        neighbor_raw_input[i] = neighbor_input
+                        processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
+                                                    for i, input_component in enumerate(self.input_components)]
+                        neighbor_output = self.run_prediction(processed_neighbor_input)
+                        processed_neighbor_output = [output_component.postprocess(
+                            neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
+                        alternative_output.append(processed_neighbor_output)
+                        interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
+                    alternative_outputs.append(alternative_output)
+                    scores.append(
+                        input_component.get_interpretation_scores(
+                            raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
+                else:
+                    neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
+                    interface_scores = []
+                    alternative_output = []
+                    for neighbor_input in neighbor_values:
+                        neighbor_raw_input[i] = neighbor_input
+                        processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
+                                                    for i, input_component in enumerate(self.input_components)]
+                        neighbor_output = self.run_prediction(processed_neighbor_input)
+                        processed_neighbor_output = [output_component.postprocess(
+                            neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
+                        alternative_output.append(processed_neighbor_output)
+                        interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
+                    alternative_outputs.append(alternative_output)
+                    interface_scores = [-score for score in interface_scores]
+                    scores.append(
+                        input_component.get_interpretation_scores(
+                            raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
            return scores, alternative_outputs
+        elif self.interpretation.lower() == "shap":
+            scores = []
+            try:
+                import shap
+            except (ImportError, ModuleNotFoundError):
+                raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
+            processed_input = [input_component.preprocess(raw_input[i])
+                               for i, input_component in enumerate(self.input_components)]
+            original_output = self.run_prediction(processed_input)
+            for i, x in enumerate(raw_input):  # iterate over each interface
+                input_component = self.input_components[i]
+                tokens, _, masks = input_component.tokenize(x)
+                def get_masked_prediction(binary_mask):  # construct a masked version of the input
+                    masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
+                    preds = []
+                    for masked_x in masked_xs:
+                        processed_masked_input = copy.deepcopy(processed_input)
+                        processed_masked_input[i] = input_component.preprocess(masked_x)
+                        new_output = self.run_prediction(processed_masked_input)
+                        pred = get_regression_or_classification_value(self, original_output, new_output)
+                        preds.append(pred)
+                    return np.array(preds)
+                num_total_segments = len(tokens)
+                explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
+                shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
+                scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
+            return scores, []
        else:
            processed_input = [input_component.preprocess(raw_input[i])
                               for i, input_component in enumerate(self.input_components)]
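A minimal, self-contained sketch of the masking idea the "shap" branch above relies on (it assumes the `shap` package is installed; the toy tokens and scoring function are invented): KernelExplainer perturbs a vector of binary mask variables, and the wrapped function scores the correspondingly masked input, mirroring get_masked_prediction() and get_masked_inputs().

import numpy as np
import shap  # assumed available via `pip install shap`

tokens = ["great", "fun", "terrible", "plot"]

def score_sentence(words):
    # Toy stand-in for run_prediction(): positive minus negative word count.
    positive, negative = {"great", "fun"}, {"terrible"}
    return sum(w in positive for w in words) - sum(w in negative for w in words)

def masked_prediction(binary_mask_matrix):
    # One row per sample; 1 keeps a token, 0 drops it (mirrors get_masked_inputs()).
    preds = []
    for mask in binary_mask_matrix:
        kept = [t for t, keep in zip(tokens, mask) if keep]
        preds.append(score_sentence(kept))
    return np.array(preds)

explainer = shap.KernelExplainer(masked_prediction, np.zeros((1, len(tokens))))
shap_values = explainer.shap_values(np.ones((1, len(tokens))), nsamples=2 * len(tokens))
print(dict(zip(tokens, shap_values[0])))  # per-token attribution scores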
View File

@@ -1,4 +1,5 @@
from gradio.outputs import Label, Textbox
+import math

def diff(original, perturbed):
    try: # try computing numerical difference
@@ -24,7 +25,33 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
        else:
            score = diff(original_label, perturbed_label)
        return score

    elif type(output_component) == Textbox:
        score = diff(post_original_output, post_perturbed_output)
        return score
+
+    else:
+        raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
+
+
+def get_regression_or_classification_value(interface, original_output, perturbed_output):
+    """Used to combine regression/classification for Shap interpretation method."""
+    output_component = interface.output_components[0]
+    post_original_output = output_component.postprocess(original_output[0])
+    post_perturbed_output = output_component.postprocess(perturbed_output[0])
+
+    if type(output_component) == Label:
+        original_label = post_original_output["label"]
+        perturbed_label = post_perturbed_output["label"]
+
+        # Handle different return types of Label interface
+        if "confidences" in post_original_output:
+            if math.isnan(perturbed_output[0][original_label]):
+                return 0
+            return perturbed_output[0][original_label]
+        else:
+            score = diff(perturbed_label, original_label)  # Intentionally inverted order of arguments.
+            return score
+    else:
+        raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
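Hypothetical values (not from the commit) illustrating the Label-with-confidences branch of get_regression_or_classification_value(): the score handed back to the SHAP explainer is the perturbed model's confidence in the label that the original input predicted, with NaN mapped to 0.

original_output = [{"positive": 0.9, "negative": 0.1}]   # raw Label output for the original input
perturbed_output = [{"positive": 0.4, "negative": 0.6}]  # raw Label output for a masked input

original_label = "positive"                   # what Label.postprocess() reports for the original
score = perturbed_output[0][original_label]   # 0.4; higher means the masked input still supports that label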