added shap interpretation for text, image, and audio

This commit is contained in:
Abubakar Abid 2021-07-12 12:44:22 -05:00
parent 4c2efc54a8
commit 9badae0570
6 changed files with 266 additions and 121 deletions

View File

@ -37,7 +37,7 @@ iface = gr.Interface(
)
],
"number",
interpretation="default",
# interpretation="default", # Removed interpretation for dataframes
examples=[
[10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]],
[80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]],

View File

@ -1,4 +1,4 @@
Metadata-Version: 1.0
Metadata-Version: 2.1
Name: gradio
Version: 2.1.1
Summary: Python library for easily interacting with trained machine learning models
@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid
Author-email: a12d@stanford.edu
License: Apache License 2.0
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
License-File: LICENSE
UNKNOWN

View File

@ -1,3 +1,4 @@
LICENSE
MANIFEST.in
README.md
setup.py
@ -26,32 +27,6 @@ gradio.egg-info/requires.txt
gradio.egg-info/top_level.txt
gradio/frontend/asset-manifest.json
gradio/frontend/index.html
gradio/frontend/static/bundle.css
gradio/frontend/static/bundle.css.map
gradio/frontend/static/bundle.js
gradio/frontend/static/bundle.js.LICENSE.txt
gradio/frontend/static/bundle.js.map
gradio/frontend/static/css/main.0094f11b.css
gradio/frontend/static/css/main.0094f11b.css.map
gradio/frontend/static/css/main.20be28ac.css
gradio/frontend/static/css/main.20be28ac.css.map
gradio/frontend/static/css/main.2b64a968.css
gradio/frontend/static/css/main.2b64a968.css.map
gradio/frontend/static/css/main.380e3222.css
gradio/frontend/static/css/main.380e3222.css.map
gradio/frontend/static/css/main.4aea80f8.css
gradio/frontend/static/css/main.4aea80f8.css.map
gradio/frontend/static/css/main.4f157d97.css
gradio/frontend/static/css/main.4f157d97.css.map
gradio/frontend/static/css/main.5c663906.css
gradio/frontend/static/css/main.5c663906.css.map
gradio/frontend/static/css/main.99922310.css
gradio/frontend/static/css/main.99922310.css.map
gradio/frontend/static/css/main.acb02c85.css
gradio/frontend/static/css/main.acb02c85.css.map
gradio/frontend/static/css/main.cbbf8898.css
gradio/frontend/static/css/main.cbbf8898.css.map
gradio/frontend/static/media/logo_loading.e93acd82.jpg
test/test_demos.py
test/test_inputs.py
test/test_interfaces.py

View File

@ -29,7 +29,7 @@ class InputComponent(Component):
Input Component. All input components subclass this.
"""
def __init__(self, label, requires_permissions=False):
self.interpret()
self.set_interpret_parameters()
super().__init__(label, requires_permissions)
def preprocess(self, x):
@ -44,7 +44,7 @@ class InputComponent(Component):
"""
return x
def interpret(self):
def set_interpret_parameters(self):
'''
Set any parameters for interpretation.
'''
@ -115,6 +115,7 @@ class Textbox(InputComponent):
}[type]
else:
self.test_input = default
self.interpret_by_tokens = True
super().__init__(label)
def get_template_context(self):
@ -147,7 +148,7 @@ class Textbox(InputComponent):
"""
return x
def interpret(self, separator=" ", replacement=None):
def set_interpret_parameters(self, separator=" ", replacement=None):
"""
Calculates interpretation score of characters in input by splitting input into tokens, then using a "leave one out" method to calculate the score of each token by removing each token and measuring the delta of the output value.
Parameters:
@ -158,7 +159,10 @@ class Textbox(InputComponent):
self.interpretation_replacement = replacement
return self
def get_interpretation_neighbors(self, x):
def tokenize(self, x):
"""
Tokenizes an input string by dividing into "words" delimited by self.interpretation_separator
"""
tokens = x.split(self.interpretation_separator)
leave_one_out_strings = []
for index in range(len(tokens)):
@ -168,9 +172,19 @@ class Textbox(InputComponent):
else:
leave_one_out_set[index] = self.interpretation_replacement
leave_one_out_strings.append(self.interpretation_separator.join(leave_one_out_set))
return leave_one_out_strings, {"tokens": tokens}, True
return tokens, leave_one_out_strings, None
def get_interpretation_scores(self, x, neighbors, scores, tokens):
def get_masked_inputs(self, tokens, binary_mask_matrix):
    """
    Constructs partially-masked sentences for SHAP interpretation.
    Parameters:
    tokens (List[str]): the tokens of the original input string.
    binary_mask_matrix: an iterable of rows; each row holds one truthy/falsy entry per token, and only tokens whose entry is truthy are kept.
    Returns:
    (List[str]): one string per row, built by joining the kept tokens with self.interpretation_separator.
    """
    # The token array does not depend on the mask row, so build it once.
    token_array = np.array(tokens)
    return [
        self.interpretation_separator.join(
            token_array[np.array(binary_mask_vector, dtype=bool)])
        for binary_mask_vector in binary_mask_matrix
    ]
def get_interpretation_scores(self, x, neighbors, scores, tokens, masks=None):
"""
Returns:
(List[Tuple[str, float]]): Each tuple set represents a set of characters and their corresponding interpretation score.
@ -216,6 +230,7 @@ class Number(InputComponent):
'''
self.default = default
self.test_input = default if default is not None else 1
self.interpret_by_tokens = False
super().__init__(label)
def get_template_context(self):
@ -244,7 +259,7 @@ class Number(InputComponent):
"""
return x
def interpret(self, steps=3, delta=1, delta_type="percent"):
def set_interpret_parameters(self, steps=3, delta=1, delta_type="percent"):
"""
Calculates interpretation scores of numeric values close to the input number.
Parameters:
@ -266,7 +281,7 @@ class Number(InputComponent):
delta = self.interpretation_delta
negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
positives = (x + np.arange(1, self.interpretation_steps+1) * delta).tolist()
return negatives + positives, {}, False
return negatives + positives, {}
def get_interpretation_scores(self, x, neighbors, scores):
"""
@ -305,6 +320,7 @@ class Slider(InputComponent):
self.step = step
self.default = minimum if default is None else default
self.test_input = self.default
self.interpret_by_tokens = False
super().__init__(label)
def get_template_context(self):
@ -329,7 +345,7 @@ class Slider(InputComponent):
"""
return x
def interpret(self, steps=8):
def set_interpret_parameters(self, steps=8):
"""
Calculates interpretation scores of numeric values ranging between the minimum and maximum values of the slider.
Parameters:
@ -339,7 +355,7 @@ class Slider(InputComponent):
return self
def get_interpretation_neighbors(self, x):
return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}, False
return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}
def get_interpretation_scores(self, x, neighbors, scores):
"""
@ -367,6 +383,7 @@ class Checkbox(InputComponent):
"""
self.test_input = True
self.default = default
self.interpret_by_tokens = False
super().__init__(label)
def get_template_context(self):
@ -388,14 +405,14 @@ class Checkbox(InputComponent):
"""
return x
def interpret(self):
def set_interpret_parameters(self):
"""
Calculates interpretation score of the input by comparing the output against the output when the input is the inverse boolean value of x.
"""
return self
def get_interpretation_neighbors(self, x):
return [not x], {}, False
return [not x], {}
def get_interpretation_scores(self, x, neighbors, scores):
"""
@ -430,6 +447,7 @@ class CheckboxGroup(InputComponent):
self.default = default
self.type = type
self.test_input = self.choices
self.interpret_by_tokens = False
super().__init__(label)
def get_template_context(self):
@ -447,7 +465,7 @@ class CheckboxGroup(InputComponent):
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
def interpret(self):
def set_interpret_parameters(self):
"""
Calculates interpretation score of each choice in the input by comparing the output against the outputs when each choice in the input is independently either removed or added.
"""
@ -462,7 +480,7 @@ class CheckboxGroup(InputComponent):
else:
leave_one_out_set.append(choice)
leave_one_out_sets.append(leave_one_out_set)
return leave_one_out_sets, {}, False
return leave_one_out_sets, {}
def get_interpretation_scores(self, x, neighbors, scores):
"""
@ -514,6 +532,7 @@ class Radio(InputComponent):
self.type = type
self.test_input = self.choices[0]
self.default = default if default is not None else self.choices[0]
self.interpret_by_tokens = False
super().__init__(label)
def get_template_context(self):
@ -531,7 +550,7 @@ class Radio(InputComponent):
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
def interpret(self):
def set_interpret_parameters(self):
"""
Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
"""
@ -540,7 +559,7 @@ class Radio(InputComponent):
def get_interpretation_neighbors(self, x):
choices = list(self.choices)
choices.remove(x)
return choices, {}, False
return choices, {}
def get_interpretation_scores(self, x, neighbors, scores):
"""
@ -577,6 +596,7 @@ class Dropdown(InputComponent):
self.type = type
self.test_input = self.choices[0]
self.default = default if default is not None else self.choices[0]
self.interpret_by_tokens = False
super().__init__(label)
def get_template_context(self):
@ -594,7 +614,7 @@ class Dropdown(InputComponent):
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
def interpret(self):
def set_interpret_parameters(self):
"""
Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
"""
@ -603,7 +623,7 @@ class Dropdown(InputComponent):
def get_interpretation_neighbors(self, x):
choices = list(self.choices)
choices.remove(x)
return choices, {}, False
return choices, {}
def get_interpretation_scores(self, x, neighbors, scores):
"""
@ -647,6 +667,7 @@ class Image(InputComponent):
self.type = type
self.invert_colors = invert_colors
self.test_input = test_data.BASE64_IMAGE
self.interpret_by_tokens = True
super().__init__(label, requires_permissions)
@classmethod
@ -690,7 +711,7 @@ class Image(InputComponent):
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
def interpret(self, segments=16):
def set_interpret_parameters(self, segments=16):
"""
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
Parameters:
@ -699,29 +720,60 @@ class Image(InputComponent):
self.interpretation_segments = segments
return self
def get_interpretation_neighbors(self, x):
def _segment_by_slic(self, x):
"""
Helper method that segments an image into superpixels using slic.
Parameters:
x: base64 representation of an image
"""
x = processing_utils.decode_base64_to_image(x)
if self.shape is not None:
x = processing_utils.resize_and_crop(x, self.shape)
image = np.array(x)
resized_and_cropped_image = np.array(x)
try:
from skimage.segmentation import slic
except ImportError:
print("Error: running default interpretation for images requires scikit-image, please install it first.")
return
segments_slic = slic(image, self.interpretation_segments, compactness=10, sigma=1)
leave_one_out_tokens, masks = [], []
replace_color = np.mean(image, axis=(0, 1))
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = replace_color
leave_one_out_tokens.append(
processing_utils.encode_array_to_base64(white_screen))
masks.append(mask)
return leave_one_out_tokens, {"masks": masks}, True
except (ImportError, ModuleNotFoundError):
raise ValueError("Error: running this interpretation for images requires scikit-image, please install it first.")
segments_slic = slic(
resized_and_cropped_image, self.interpretation_segments, compactness=10,
sigma=1, start_label=0)
return segments_slic, resized_and_cropped_image
def get_interpretation_scores(self, x, neighbors, scores, masks):
def tokenize(self, x):
    """
    Segments image into tokens, masks, and leave-one-out-tokens
    Parameters:
    x: base64 representation of an image
    Returns:
    tokens: list of arrays, one per segment; each is a copy of the resized image with every pixel outside that segment set to 0 (consumed by the get_masked_inputs() method)
    leave_one_out_tokens: list of base64-encoded images, one per segment; each is the resized image with that segment filled with the image's mean color
    masks: list of boolean arrays, one per segment, marking the pixels that belong to that segment
    """
    segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
    tokens, masks, leave_one_out_tokens = [], [], []
    # Mean over the first two axes, used as the fill for a removed segment.
    replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
    for segment_value in np.unique(segments_slic):
        mask = (segments_slic == segment_value)
        # Leave-one-out image: only this segment is painted over.
        image_screen = np.copy(resized_and_cropped_image)
        image_screen[mask] = replace_color
        leave_one_out_tokens.append(
            processing_utils.encode_array_to_base64(image_screen))
        # Token image: everything except this segment is zeroed.
        token = np.copy(resized_and_cropped_image)
        token[~mask] = 0
        tokens.append(token)
        masks.append(mask)
    return tokens, leave_one_out_tokens, masks
def get_masked_inputs(self, tokens, binary_mask_matrix):
    """
    Builds one masked image per row of binary_mask_matrix.
    Parameters:
    tokens: list of image arrays as returned by tokenize(); tokens[0] supplies the shape of the blank canvas.
    binary_mask_matrix: iterable of rows, one entry per token; each entry is converted with int() and used as a multiplier for the matching token.
    Returns:
    (List): the base64 encoding of each composed image, in row order.
    """
    encoded_images = []
    for mask_row in binary_mask_matrix:
        # Start from an all-zero integer canvas and add each weighted token.
        composite = sum(
            (segment * int(keep) for segment, keep in zip(tokens, mask_row)),
            np.zeros_like(tokens[0], dtype=int))
        encoded_images.append(processing_utils.encode_array_to_base64(composite))
    return encoded_images
def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
"""
Returns:
(List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
@ -825,6 +877,7 @@ class Audio(InputComponent):
requires_permissions = source == "microphone"
self.type = type
self.test_input = test_data.BASE64_AUDIO
self.interpret_by_tokens = True
super().__init__(label, requires_permissions)
def get_template_context(self):
@ -855,7 +908,7 @@ class Audio(InputComponent):
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x, type="audio")
def interpret(self, segments=8):
def set_interpret_parameters(self, segments=8):
"""
Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value.
Parameters:
@ -864,30 +917,66 @@ class Audio(InputComponent):
self.interpretation_segments = segments
return self
def get_interpretation_neighbors(self, x):
def tokenize(self, x):
file_obj = processing_utils.decode_base64_to_file(x)
x = scipy.io.wavfile.read(file_obj.name)
sample_rate, data = x
leave_one_out_sets = []
tokens = []
masks = []
duration = data.shape[0]
boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist()
boundaries = [round(boundary) for boundary in boundaries]
for index in range(len(boundaries) - 1):
leave_one_out_data = np.copy(data)
start, stop = boundaries[index], boundaries[index + 1]
masks.append((start, stop))
# Handle the leave one outs
leave_one_out_data = np.copy(data)
leave_one_out_data[start:stop] = 0
file = tempfile.NamedTemporaryFile(delete=False)
scipy.io.wavfile.write(file, sample_rate, leave_one_out_data)
out_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
leave_one_out_sets.append(out_data)
return leave_one_out_sets, {}, True
# Handle the tokens
token = np.copy(data)
token[0:start] = 0
token[stop:] = 0
file = tempfile.NamedTemporaryFile(delete=False)
scipy.io.wavfile.write(file, sample_rate, token)
token_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
tokens.append(token_data)
return tokens, leave_one_out_sets, masks
def get_interpretation_scores(self, x, neighbors, scores):
def get_masked_inputs(self, tokens, binary_mask_matrix):
    """
    Builds one masked audio clip per row of binary_mask_matrix.
    Parameters:
    tokens: list of base64-encoded wav clips as returned by tokenize(); must be non-empty.
    binary_mask_matrix: iterable of rows, one entry per token; each entry is converted with int() and used as a multiplier for the matching token.
    Returns:
    (List): the base64 encoding of each composed wav clip, in row order.
    """
    # Decode every token exactly once, keeping its sample rate and samples.
    decoded_tokens = []
    for token in tokens:
        # Keep a reference to the file object while it is being read.
        file_obj = processing_utils.decode_base64_to_file(token)
        decoded_tokens.append(scipy.io.wavfile.read(file_obj.name))
    # The first token provides the sample rate and the shape of the "zero input".
    sample_rate = decoded_tokens[0][0]
    token_data = [data for _, data in decoded_tokens]
    # NOTE(review): dtype=int makes the written wav use numpy's default integer
    # width rather than the tokens' own sample dtype -- confirm this is intended.
    zero_input = np.zeros_like(token_data[0], dtype=int)
    # construct the masked version
    masked_inputs = []
    for binary_mask_vector in binary_mask_matrix:
        masked_input = np.copy(zero_input)
        for t, b in zip(token_data, binary_mask_vector):
            masked_input = masked_input + t*int(b)
        # Close the handle before the file is re-read by name; delete=False
        # keeps the file on disk after the with-block exits.
        with tempfile.NamedTemporaryFile(delete=False) as file:
            scipy.io.wavfile.write(file, sample_rate, masked_input)
        masked_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
        masked_inputs.append(masked_data)
    return masked_inputs
def get_interpretation_scores(self, x, neighbors, scores, masks=None, tokens=None):
"""
Returns:
(List[float]): Each value represents the interpretation score corresponding to an evenly spaced subsection of audio.
"""
return scores
return list(scores)
def embed(self, x):
"""
@ -1049,35 +1138,35 @@ class Dataframe(InputComponent):
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'pandas', 'numpy', 'array'.")
def interpret(self):
"""
Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
"""
return self
# def set_interpret_parameters(self):
# """
# Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
# """
# return self
def get_interpretation_neighbors(self, x):
x = pd.DataFrame(x)
leave_one_out_sets = []
shape = x.shape
for i in range(shape[0]):
for j in range(shape[1]):
scalar = x.iloc[i, j]
leave_one_out_df = x.copy()
if is_bool_dtype(scalar):
leave_one_out_df.iloc[i, j] = not scalar
elif is_numeric_dtype(scalar):
leave_one_out_df.iloc[i, j] = 0
elif is_string_dtype(scalar):
leave_one_out_df.iloc[i, j] = ""
leave_one_out_sets.append(leave_one_out_df.values.tolist())
return leave_one_out_sets, {"shape": x.shape}, True
# def get_interpretation_neighbors(self, x):
# x = pd.DataFrame(x)
# leave_one_out_sets = []
# shape = x.shape
# for i in range(shape[0]):
# for j in range(shape[1]):
# scalar = x.iloc[i, j]
# leave_one_out_df = x.copy()
# if is_bool_dtype(scalar):
# leave_one_out_df.iloc[i, j] = not scalar
# elif is_numeric_dtype(scalar):
# leave_one_out_df.iloc[i, j] = 0
# elif is_string_dtype(scalar):
# leave_one_out_df.iloc[i, j] = ""
# leave_one_out_sets.append(leave_one_out_df.values.tolist())
# return leave_one_out_sets, {"shape": x.shape}
def get_interpretation_scores(self, x, neighbors, scores, shape):
"""
Returns:
(List[List[float]]): A 2D array where each value corresponds to the interpretation score of each cell.
"""
return np.array(scores).reshape((shape)).tolist()
# def get_interpretation_scores(self, x, neighbors, scores, shape):
# """
# Returns:
# (List[List[float]]): A 2D array where each value corresponds to the interpretation score of each cell.
# """
# return np.array(scores).reshape((shape)).tolist()
def embed(self, x):
raise NotImplementedError("DataFrame doesn't currently support embeddings")

View File

@ -7,7 +7,7 @@ import gradio
from gradio.inputs import InputComponent, get_input_instance
from gradio.outputs import OutputComponent, get_output_instance
from gradio import networking, strings, utils
from gradio.interpretation import quantify_difference_in_label
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.external import load_interface
from gradio import encryptor
import pkg_resources
@ -66,7 +66,7 @@ class Interface:
def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None,
examples_per_page=10, live=False,
layout="horizontal", show_input=True, show_output=True,
capture_session=False, interpretation=None, theme=None, repeat_outputs_per_model=True,
capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
title=None, description=None, article=None, thumbnail=None,
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
@ -84,6 +84,7 @@ class Interface:
layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results.
title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components.
article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content.
@ -145,6 +146,7 @@ class Interface:
self.examples = examples
else:
raise ValueError("Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.")
self.num_shap = num_shap
self.examples_per_page = examples_per_page
self.server_port = server_port
self.simple_server = None
@ -334,7 +336,7 @@ class Interface:
interpretation for a certain set of UI component types, as well as the custom interpretation case.
:param raw_input: a list of raw inputs to apply the interpretation(s) on.
"""
if self.interpretation == "default":
if self.interpretation.lower() == "default":
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]
original_output = self.run_prediction(processed_input)
@ -342,7 +344,26 @@ class Interface:
for i, x in enumerate(raw_input):
input_component = self.input_components[i]
neighbor_raw_input = list(raw_input)
neighbor_values, interpret_kwargs, interpret_by_removal = input_component.get_interpretation_neighbors(x)
if input_component.interpret_by_tokens:
tokens, neighbor_values, masks = input_component.tokenize(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(self.input_components)]
neighbor_output = self.run_prediction(processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
alternative_output.append(processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
alternative_outputs.append(alternative_output)
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
else:
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
@ -356,12 +377,42 @@ class Interface:
alternative_output.append(processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
alternative_outputs.append(alternative_output)
if not interpret_by_removal:
interface_scores = [-score for score in interface_scores]
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
return scores, alternative_outputs
elif self.interpretation.lower() == "shap":
scores = []
try:
import shap
except (ImportError, ModuleNotFoundError):
raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]
original_output = self.run_prediction(processed_input)
for i, x in enumerate(raw_input): # iterate over each input component
input_component = self.input_components[i]
tokens, _, masks = input_component.tokenize(x)
def get_masked_prediction(binary_mask): # construct a masked version of the input
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
preds = []
for masked_x in masked_xs:
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = self.run_prediction(processed_masked_input)
pred = get_regression_or_classification_value(self, original_output, new_output)
preds.append(pred)
return np.array(preds)
num_total_segments = len(tokens)
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
return scores, []
else:
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]

View File

@ -1,4 +1,5 @@
from gradio.outputs import Label, Textbox
import math
def diff(original, perturbed):
try: # try computing numerical difference
@ -24,7 +25,33 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
else:
score = diff(original_label, perturbed_label)
return score
elif type(output_component) == Textbox:
score = diff(post_original_output, post_perturbed_output)
return score
else:
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
def get_regression_or_classification_value(interface, original_output, perturbed_output):
    """
    Used to combine regression/classification for Shap interpretation method.

    Both outputs are post-processed with the interface's first output component.
    If the original post-processed output contains "confidences", the value is the
    raw perturbed output's entry for the originally predicted label (0 when that
    entry is NaN). Otherwise it is diff(perturbed_label, original_label).
    Raises ValueError when the first output component is not a Label.
    """
    component = interface.output_components[0]
    original_post = component.postprocess(original_output[0])
    perturbed_post = component.postprocess(perturbed_output[0])
    if type(component) != Label:
        raise ValueError("This interpretation method doesn't support the Output component: {}".format(component))
    original_label = original_post["label"]
    perturbed_label = perturbed_post["label"]
    # Handle different return types of Label interface
    if "confidences" not in original_post:
        # Intentionally inverted order of arguments.
        return diff(perturbed_label, original_label)
    confidence = perturbed_output[0][original_label]
    return 0 if math.isnan(confidence) else confidence