mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
added shap interpretation for text, image, and audio
This commit is contained in:
parent
4c2efc54a8
commit
9badae0570
@ -37,7 +37,7 @@ iface = gr.Interface(
|
||||
)
|
||||
],
|
||||
"number",
|
||||
interpretation="default",
|
||||
# interpretation="default", # Removed interpretation for dataframes
|
||||
examples=[
|
||||
[10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]],
|
||||
[80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]],
|
||||
|
@ -1,4 +1,4 @@
|
||||
Metadata-Version: 1.0
|
||||
Metadata-Version: 2.1
|
||||
Name: gradio
|
||||
Version: 2.1.1
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid
|
||||
Author-email: a12d@stanford.edu
|
||||
License: Apache License 2.0
|
||||
Description: UNKNOWN
|
||||
Keywords: machine learning,visualization,reproducibility
|
||||
Platform: UNKNOWN
|
||||
License-File: LICENSE
|
||||
|
||||
UNKNOWN
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
LICENSE
|
||||
MANIFEST.in
|
||||
README.md
|
||||
setup.py
|
||||
@ -26,32 +27,6 @@ gradio.egg-info/requires.txt
|
||||
gradio.egg-info/top_level.txt
|
||||
gradio/frontend/asset-manifest.json
|
||||
gradio/frontend/index.html
|
||||
gradio/frontend/static/bundle.css
|
||||
gradio/frontend/static/bundle.css.map
|
||||
gradio/frontend/static/bundle.js
|
||||
gradio/frontend/static/bundle.js.LICENSE.txt
|
||||
gradio/frontend/static/bundle.js.map
|
||||
gradio/frontend/static/css/main.0094f11b.css
|
||||
gradio/frontend/static/css/main.0094f11b.css.map
|
||||
gradio/frontend/static/css/main.20be28ac.css
|
||||
gradio/frontend/static/css/main.20be28ac.css.map
|
||||
gradio/frontend/static/css/main.2b64a968.css
|
||||
gradio/frontend/static/css/main.2b64a968.css.map
|
||||
gradio/frontend/static/css/main.380e3222.css
|
||||
gradio/frontend/static/css/main.380e3222.css.map
|
||||
gradio/frontend/static/css/main.4aea80f8.css
|
||||
gradio/frontend/static/css/main.4aea80f8.css.map
|
||||
gradio/frontend/static/css/main.4f157d97.css
|
||||
gradio/frontend/static/css/main.4f157d97.css.map
|
||||
gradio/frontend/static/css/main.5c663906.css
|
||||
gradio/frontend/static/css/main.5c663906.css.map
|
||||
gradio/frontend/static/css/main.99922310.css
|
||||
gradio/frontend/static/css/main.99922310.css.map
|
||||
gradio/frontend/static/css/main.acb02c85.css
|
||||
gradio/frontend/static/css/main.acb02c85.css.map
|
||||
gradio/frontend/static/css/main.cbbf8898.css
|
||||
gradio/frontend/static/css/main.cbbf8898.css.map
|
||||
gradio/frontend/static/media/logo_loading.e93acd82.jpg
|
||||
test/test_demos.py
|
||||
test/test_inputs.py
|
||||
test/test_interfaces.py
|
||||
|
227
gradio/inputs.py
227
gradio/inputs.py
@ -29,7 +29,7 @@ class InputComponent(Component):
|
||||
Input Component. All input components subclass this.
|
||||
"""
|
||||
def __init__(self, label, requires_permissions=False):
|
||||
self.interpret()
|
||||
self.set_interpret_parameters()
|
||||
super().__init__(label, requires_permissions)
|
||||
|
||||
def preprocess(self, x):
|
||||
@ -44,7 +44,7 @@ class InputComponent(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def interpret(self):
|
||||
def set_interpret_parameters(self):
|
||||
'''
|
||||
Set any parameters for interpretation.
|
||||
'''
|
||||
@ -115,6 +115,7 @@ class Textbox(InputComponent):
|
||||
}[type]
|
||||
else:
|
||||
self.test_input = default
|
||||
self.interpret_by_tokens = True
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -147,7 +148,7 @@ class Textbox(InputComponent):
|
||||
"""
|
||||
return x
|
||||
|
||||
def interpret(self, separator=" ", replacement=None):
|
||||
def set_interpret_parameters(self, separator=" ", replacement=None):
|
||||
"""
|
||||
Calculates interpretation score of characters in input by splitting input into tokens, then using a "leave one out" method to calculate the score of each token by removing each token and measuring the delta of the output value.
|
||||
Parameters:
|
||||
@ -158,7 +159,10 @@ class Textbox(InputComponent):
|
||||
self.interpretation_replacement = replacement
|
||||
return self
|
||||
|
||||
def get_interpretation_neighbors(self, x):
|
||||
def tokenize(self, x):
|
||||
"""
|
||||
Tokenizes an input string by dividing into "words" delimited by self.interpretation_separator
|
||||
"""
|
||||
tokens = x.split(self.interpretation_separator)
|
||||
leave_one_out_strings = []
|
||||
for index in range(len(tokens)):
|
||||
@ -168,9 +172,19 @@ class Textbox(InputComponent):
|
||||
else:
|
||||
leave_one_out_set[index] = self.interpretation_replacement
|
||||
leave_one_out_strings.append(self.interpretation_separator.join(leave_one_out_set))
|
||||
return leave_one_out_strings, {"tokens": tokens}, True
|
||||
return tokens, leave_one_out_strings, None
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, tokens):
|
||||
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
||||
"""
|
||||
Constructs partially-masked sentences for SHAP interpretation
|
||||
"""
|
||||
masked_inputs = []
|
||||
for binary_mask_vector in binary_mask_matrix:
|
||||
masked_input = np.array(tokens)[np.array(binary_mask_vector, dtype=bool)]
|
||||
masked_inputs.append(self.interpretation_separator.join(masked_input))
|
||||
return masked_inputs
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, tokens, masks=None):
|
||||
"""
|
||||
Returns:
|
||||
(List[Tuple[str, float]]): Each tuple set represents a set of characters and their corresponding interpretation score.
|
||||
@ -216,6 +230,7 @@ class Number(InputComponent):
|
||||
'''
|
||||
self.default = default
|
||||
self.test_input = default if default is not None else 1
|
||||
self.interpret_by_tokens = False
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -244,7 +259,7 @@ class Number(InputComponent):
|
||||
"""
|
||||
return x
|
||||
|
||||
def interpret(self, steps=3, delta=1, delta_type="percent"):
|
||||
def set_interpret_parameters(self, steps=3, delta=1, delta_type="percent"):
|
||||
"""
|
||||
Calculates interpretation scores of numeric values close to the input number.
|
||||
Parameters:
|
||||
@ -266,7 +281,7 @@ class Number(InputComponent):
|
||||
delta = self.interpretation_delta
|
||||
negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist()
|
||||
positives = (x + np.arange(1, self.interpretation_steps+1) * delta).tolist()
|
||||
return negatives + positives, {}, False
|
||||
return negatives + positives, {}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
"""
|
||||
@ -305,6 +320,7 @@ class Slider(InputComponent):
|
||||
self.step = step
|
||||
self.default = minimum if default is None else default
|
||||
self.test_input = self.default
|
||||
self.interpret_by_tokens = False
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -329,7 +345,7 @@ class Slider(InputComponent):
|
||||
"""
|
||||
return x
|
||||
|
||||
def interpret(self, steps=8):
|
||||
def set_interpret_parameters(self, steps=8):
|
||||
"""
|
||||
Calculates interpretation scores of numeric values ranging between the minimum and maximum values of the slider.
|
||||
Parameters:
|
||||
@ -339,7 +355,7 @@ class Slider(InputComponent):
|
||||
return self
|
||||
|
||||
def get_interpretation_neighbors(self, x):
|
||||
return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}, False
|
||||
return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
"""
|
||||
@ -367,6 +383,7 @@ class Checkbox(InputComponent):
|
||||
"""
|
||||
self.test_input = True
|
||||
self.default = default
|
||||
self.interpret_by_tokens = False
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -388,14 +405,14 @@ class Checkbox(InputComponent):
|
||||
"""
|
||||
return x
|
||||
|
||||
def interpret(self):
|
||||
def set_interpret_parameters(self):
|
||||
"""
|
||||
Calculates interpretation score of the input by comparing the output against the output when the input is the inverse boolean value of x.
|
||||
"""
|
||||
return self
|
||||
|
||||
def get_interpretation_neighbors(self, x):
|
||||
return [not x], {}, False
|
||||
return [not x], {}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
"""
|
||||
@ -430,6 +447,7 @@ class CheckboxGroup(InputComponent):
|
||||
self.default = default
|
||||
self.type = type
|
||||
self.test_input = self.choices
|
||||
self.interpret_by_tokens = False
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -447,7 +465,7 @@ class CheckboxGroup(InputComponent):
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
def interpret(self):
|
||||
def set_interpret_parameters(self):
|
||||
"""
|
||||
Calculates interpretation score of each choice in the input by comparing the output against the outputs when each choice in the input is independently either removed or added.
|
||||
"""
|
||||
@ -462,7 +480,7 @@ class CheckboxGroup(InputComponent):
|
||||
else:
|
||||
leave_one_out_set.append(choice)
|
||||
leave_one_out_sets.append(leave_one_out_set)
|
||||
return leave_one_out_sets, {}, False
|
||||
return leave_one_out_sets, {}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
"""
|
||||
@ -514,6 +532,7 @@ class Radio(InputComponent):
|
||||
self.type = type
|
||||
self.test_input = self.choices[0]
|
||||
self.default = default if default is not None else self.choices[0]
|
||||
self.interpret_by_tokens = False
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -531,7 +550,7 @@ class Radio(InputComponent):
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
def interpret(self):
|
||||
def set_interpret_parameters(self):
|
||||
"""
|
||||
Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
|
||||
"""
|
||||
@ -540,7 +559,7 @@ class Radio(InputComponent):
|
||||
def get_interpretation_neighbors(self, x):
|
||||
choices = list(self.choices)
|
||||
choices.remove(x)
|
||||
return choices, {}, False
|
||||
return choices, {}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
"""
|
||||
@ -577,6 +596,7 @@ class Dropdown(InputComponent):
|
||||
self.type = type
|
||||
self.test_input = self.choices[0]
|
||||
self.default = default if default is not None else self.choices[0]
|
||||
self.interpret_by_tokens = False
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -594,7 +614,7 @@ class Dropdown(InputComponent):
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
def interpret(self):
|
||||
def set_interpret_parameters(self):
|
||||
"""
|
||||
Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected.
|
||||
"""
|
||||
@ -603,7 +623,7 @@ class Dropdown(InputComponent):
|
||||
def get_interpretation_neighbors(self, x):
|
||||
choices = list(self.choices)
|
||||
choices.remove(x)
|
||||
return choices, {}, False
|
||||
return choices, {}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
"""
|
||||
@ -647,6 +667,7 @@ class Image(InputComponent):
|
||||
self.type = type
|
||||
self.invert_colors = invert_colors
|
||||
self.test_input = test_data.BASE64_IMAGE
|
||||
self.interpret_by_tokens = True
|
||||
super().__init__(label, requires_permissions)
|
||||
|
||||
@classmethod
|
||||
@ -690,7 +711,7 @@ class Image(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x)
|
||||
|
||||
def interpret(self, segments=16):
|
||||
def set_interpret_parameters(self, segments=16):
|
||||
"""
|
||||
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
|
||||
Parameters:
|
||||
@ -699,29 +720,60 @@ class Image(InputComponent):
|
||||
self.interpretation_segments = segments
|
||||
return self
|
||||
|
||||
def get_interpretation_neighbors(self, x):
|
||||
def _segment_by_slic(self, x):
|
||||
"""
|
||||
Helper method that segments an image into superpixels using slic.
|
||||
Parameters:
|
||||
x: base64 representation of an image
|
||||
"""
|
||||
x = processing_utils.decode_base64_to_image(x)
|
||||
if self.shape is not None:
|
||||
x = processing_utils.resize_and_crop(x, self.shape)
|
||||
image = np.array(x)
|
||||
resized_and_cropped_image = np.array(x)
|
||||
try:
|
||||
from skimage.segmentation import slic
|
||||
except ImportError:
|
||||
print("Error: running default interpretation for images requires scikit-image, please install it first.")
|
||||
return
|
||||
segments_slic = slic(image, self.interpretation_segments, compactness=10, sigma=1)
|
||||
leave_one_out_tokens, masks = [], []
|
||||
replace_color = np.mean(image, axis=(0, 1))
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = replace_color
|
||||
leave_one_out_tokens.append(
|
||||
processing_utils.encode_array_to_base64(white_screen))
|
||||
masks.append(mask)
|
||||
return leave_one_out_tokens, {"masks": masks}, True
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ValueError("Error: running this interpretation for images requires scikit-image, please install it first.")
|
||||
segments_slic = slic(
|
||||
resized_and_cropped_image, self.interpretation_segments, compactness=10,
|
||||
sigma=1, start_label=0)
|
||||
return segments_slic, resized_and_cropped_image
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, masks):
|
||||
def tokenize(self, x):
|
||||
"""
|
||||
Segments image into tokens, masks, and leave-one-out-tokens
|
||||
Parameters:
|
||||
x: base64 representation of an image
|
||||
Returns:
|
||||
tokens: list of tokens, used by the get_masked_input() method
|
||||
leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
|
||||
masks: list of masks, used by the get_interpretation_neighbors() method
|
||||
"""
|
||||
segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
|
||||
tokens, masks, leave_one_out_tokens = [], [], []
|
||||
replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
|
||||
for (i, segment_value) in enumerate(np.unique(segments_slic)):
|
||||
mask = (segments_slic == segment_value)
|
||||
image_screen = np.copy(resized_and_cropped_image)
|
||||
image_screen[segments_slic == segment_value] = replace_color
|
||||
leave_one_out_tokens.append(
|
||||
processing_utils.encode_array_to_base64(image_screen))
|
||||
token = np.copy(resized_and_cropped_image)
|
||||
token[segments_slic != segment_value] = 0
|
||||
tokens.append(token)
|
||||
masks.append(mask)
|
||||
return tokens, leave_one_out_tokens, masks
|
||||
|
||||
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
||||
masked_inputs = []
|
||||
for binary_mask_vector in binary_mask_matrix:
|
||||
masked_input = np.zeros_like(tokens[0], dtype=int)
|
||||
for token, b in zip(tokens, binary_mask_vector):
|
||||
masked_input = masked_input + token*int(b)
|
||||
masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
|
||||
return masked_inputs
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
|
||||
"""
|
||||
Returns:
|
||||
(List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
|
||||
@ -825,6 +877,7 @@ class Audio(InputComponent):
|
||||
requires_permissions = source == "microphone"
|
||||
self.type = type
|
||||
self.test_input = test_data.BASE64_AUDIO
|
||||
self.interpret_by_tokens = True
|
||||
super().__init__(label, requires_permissions)
|
||||
|
||||
def get_template_context(self):
|
||||
@ -855,7 +908,7 @@ class Audio(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x, type="audio")
|
||||
|
||||
def interpret(self, segments=8):
|
||||
def set_interpret_parameters(self, segments=8):
|
||||
"""
|
||||
Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value.
|
||||
Parameters:
|
||||
@ -864,30 +917,66 @@ class Audio(InputComponent):
|
||||
self.interpretation_segments = segments
|
||||
return self
|
||||
|
||||
def get_interpretation_neighbors(self, x):
|
||||
def tokenize(self, x):
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
x = scipy.io.wavfile.read(file_obj.name)
|
||||
sample_rate, data = x
|
||||
leave_one_out_sets = []
|
||||
tokens = []
|
||||
masks = []
|
||||
duration = data.shape[0]
|
||||
boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist()
|
||||
boundaries = [round(boundary) for boundary in boundaries]
|
||||
for index in range(len(boundaries) - 1):
|
||||
leave_one_out_data = np.copy(data)
|
||||
start, stop = boundaries[index], boundaries[index + 1]
|
||||
masks.append((start, stop))
|
||||
# Handle the leave one outs
|
||||
leave_one_out_data = np.copy(data)
|
||||
leave_one_out_data[start:stop] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
scipy.io.wavfile.write(file, sample_rate, leave_one_out_data)
|
||||
out_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
|
||||
leave_one_out_sets.append(out_data)
|
||||
return leave_one_out_sets, {}, True
|
||||
# Handle the tokens
|
||||
token = np.copy(data)
|
||||
token[0:start] = 0
|
||||
token[stop:] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
scipy.io.wavfile.write(file, sample_rate, token)
|
||||
token_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
|
||||
tokens.append(token_data)
|
||||
return tokens, leave_one_out_sets, masks
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores):
|
||||
def get_masked_inputs(self, tokens, binary_mask_matrix):
|
||||
# create a "zero input" vector and get sample rate
|
||||
x = tokens[0]
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
sample_rate, data = scipy.io.wavfile.read(file_obj.name)
|
||||
zero_input = np.zeros_like(data, dtype=int)
|
||||
# decode all of the tokens
|
||||
token_data = []
|
||||
for token in tokens:
|
||||
file_obj = processing_utils.decode_base64_to_file(token)
|
||||
_, data = scipy.io.wavfile.read(file_obj.name)
|
||||
token_data.append(data)
|
||||
# construct the masked version
|
||||
masked_inputs = []
|
||||
for binary_mask_vector in binary_mask_matrix:
|
||||
masked_input = np.copy(zero_input)
|
||||
for t, b in zip(token_data, binary_mask_vector):
|
||||
masked_input = masked_input + t*int(b)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
scipy.io.wavfile.write(file, sample_rate, masked_input)
|
||||
masked_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav")
|
||||
masked_inputs.append(masked_data)
|
||||
return masked_inputs
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, masks=None, tokens=None):
|
||||
"""
|
||||
Returns:
|
||||
(List[float]): Each value represents the interpretation score corresponding to an evenly spaced subsection of audio.
|
||||
"""
|
||||
return scores
|
||||
return list(scores)
|
||||
|
||||
def embed(self, x):
|
||||
"""
|
||||
@ -1049,35 +1138,35 @@ class Dataframe(InputComponent):
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'pandas', 'numpy', 'array'.")
|
||||
|
||||
def interpret(self):
|
||||
"""
|
||||
Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
|
||||
"""
|
||||
return self
|
||||
# def set_interpret_parameters(self):
|
||||
# """
|
||||
# Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value.
|
||||
# """
|
||||
# return self
|
||||
|
||||
def get_interpretation_neighbors(self, x):
|
||||
x = pd.DataFrame(x)
|
||||
leave_one_out_sets = []
|
||||
shape = x.shape
|
||||
for i in range(shape[0]):
|
||||
for j in range(shape[1]):
|
||||
scalar = x.iloc[i, j]
|
||||
leave_one_out_df = x.copy()
|
||||
if is_bool_dtype(scalar):
|
||||
leave_one_out_df.iloc[i, j] = not scalar
|
||||
elif is_numeric_dtype(scalar):
|
||||
leave_one_out_df.iloc[i, j] = 0
|
||||
elif is_string_dtype(scalar):
|
||||
leave_one_out_df.iloc[i, j] = ""
|
||||
leave_one_out_sets.append(leave_one_out_df.values.tolist())
|
||||
return leave_one_out_sets, {"shape": x.shape}, True
|
||||
# def get_interpretation_neighbors(self, x):
|
||||
# x = pd.DataFrame(x)
|
||||
# leave_one_out_sets = []
|
||||
# shape = x.shape
|
||||
# for i in range(shape[0]):
|
||||
# for j in range(shape[1]):
|
||||
# scalar = x.iloc[i, j]
|
||||
# leave_one_out_df = x.copy()
|
||||
# if is_bool_dtype(scalar):
|
||||
# leave_one_out_df.iloc[i, j] = not scalar
|
||||
# elif is_numeric_dtype(scalar):
|
||||
# leave_one_out_df.iloc[i, j] = 0
|
||||
# elif is_string_dtype(scalar):
|
||||
# leave_one_out_df.iloc[i, j] = ""
|
||||
# leave_one_out_sets.append(leave_one_out_df.values.tolist())
|
||||
# return leave_one_out_sets, {"shape": x.shape}
|
||||
|
||||
def get_interpretation_scores(self, x, neighbors, scores, shape):
|
||||
"""
|
||||
Returns:
|
||||
(List[List[float]]): A 2D array where each value corrseponds to the interpretation score of each cell.
|
||||
"""
|
||||
return np.array(scores).reshape((shape)).tolist()
|
||||
# def get_interpretation_scores(self, x, neighbors, scores, shape):
|
||||
# """
|
||||
# Returns:
|
||||
# (List[List[float]]): A 2D array where each value corrseponds to the interpretation score of each cell.
|
||||
# """
|
||||
# return np.array(scores).reshape((shape)).tolist()
|
||||
|
||||
def embed(self, x):
|
||||
raise NotImplementedError("DataFrame doesn't currently support embeddings")
|
||||
|
@ -7,7 +7,7 @@ import gradio
|
||||
from gradio.inputs import InputComponent, get_input_instance
|
||||
from gradio.outputs import OutputComponent, get_output_instance
|
||||
from gradio import networking, strings, utils
|
||||
from gradio.interpretation import quantify_difference_in_label
|
||||
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
|
||||
from gradio.external import load_interface
|
||||
from gradio import encryptor
|
||||
import pkg_resources
|
||||
@ -66,7 +66,7 @@ class Interface:
|
||||
def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None,
|
||||
examples_per_page=10, live=False,
|
||||
layout="horizontal", show_input=True, show_output=True,
|
||||
capture_session=False, interpretation=None, theme=None, repeat_outputs_per_model=True,
|
||||
capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
|
||||
@ -84,6 +84,7 @@ class Interface:
|
||||
layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
|
||||
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
||||
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
|
||||
num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results.
|
||||
title (str): a title for the interface; if provided, appears above the input and output components.
|
||||
description (str): a description for the interface; if provided, appears above the input and output components.
|
||||
article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content.
|
||||
@ -145,6 +146,7 @@ class Interface:
|
||||
self.examples = examples
|
||||
else:
|
||||
raise ValueError("Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.")
|
||||
self.num_shap = num_shap
|
||||
self.examples_per_page = examples_per_page
|
||||
self.server_port = server_port
|
||||
self.simple_server = None
|
||||
@ -334,7 +336,7 @@ class Interface:
|
||||
interpretation for a certain set of UI component types, as well as the custom interpretation case.
|
||||
:param raw_input: a list of raw inputs to apply the interpretation(s) on.
|
||||
"""
|
||||
if self.interpretation == "default":
|
||||
if self.interpretation.lower() == "default":
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
original_output = self.run_prediction(processed_input)
|
||||
@ -342,7 +344,26 @@ class Interface:
|
||||
for i, x in enumerate(raw_input):
|
||||
input_component = self.input_components[i]
|
||||
neighbor_raw_input = list(raw_input)
|
||||
neighbor_values, interpret_kwargs, interpret_by_removal = input_component.get_interpretation_neighbors(x)
|
||||
if input_component.interpret_by_tokens:
|
||||
tokens, neighbor_values, masks = input_component.tokenize(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
neighbor_output = self.run_prediction(processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
|
||||
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
|
||||
alternative_outputs.append(alternative_output)
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
|
||||
else:
|
||||
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
@ -356,12 +377,42 @@ class Interface:
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
|
||||
alternative_outputs.append(alternative_output)
|
||||
if not interpret_by_removal:
|
||||
interface_scores = [-score for score in interface_scores]
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
|
||||
return scores, alternative_outputs
|
||||
elif self.interpretation.lower() == "shap":
|
||||
scores = []
|
||||
try:
|
||||
import shap
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
|
||||
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
original_output = self.run_prediction(processed_input)
|
||||
|
||||
for i, x in enumerate(raw_input): # iterate over reach interface
|
||||
input_component = self.input_components[i]
|
||||
tokens, _, masks = input_component.tokenize(x)
|
||||
|
||||
def get_masked_prediction(binary_mask): # construct a masked version of the input
|
||||
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
|
||||
preds = []
|
||||
for masked_x in masked_xs:
|
||||
processed_masked_input = copy.deepcopy(processed_input)
|
||||
processed_masked_input[i] = input_component.preprocess(masked_x)
|
||||
new_output = self.run_prediction(processed_masked_input)
|
||||
pred = get_regression_or_classification_value(self, original_output, new_output)
|
||||
preds.append(pred)
|
||||
return np.array(preds)
|
||||
|
||||
num_total_segments = len(tokens)
|
||||
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
|
||||
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
|
||||
scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
|
||||
return scores, []
|
||||
else:
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
|
@ -1,4 +1,5 @@
|
||||
from gradio.outputs import Label, Textbox
|
||||
import math
|
||||
|
||||
def diff(original, perturbed):
|
||||
try: # try computing numerical difference
|
||||
@ -24,7 +25,33 @@ def quantify_difference_in_label(interface, original_output, perturbed_output):
|
||||
else:
|
||||
score = diff(original_label, perturbed_label)
|
||||
return score
|
||||
|
||||
elif type(output_component) == Textbox:
|
||||
score = diff(post_original_output, post_perturbed_output)
|
||||
return score
|
||||
|
||||
else:
|
||||
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
|
||||
|
||||
def get_regression_or_classification_value(interface, original_output, perturbed_output):
|
||||
"""Used to combine regression/classification for Shap interpretation method."""
|
||||
output_component = interface.output_components[0]
|
||||
post_original_output = output_component.postprocess(original_output[0])
|
||||
post_perturbed_output = output_component.postprocess(perturbed_output[0])
|
||||
|
||||
if type(output_component) == Label:
|
||||
original_label = post_original_output["label"]
|
||||
perturbed_label = post_perturbed_output["label"]
|
||||
|
||||
# Handle different return types of Label interface
|
||||
if "confidences" in post_original_output:
|
||||
if math.isnan(perturbed_output[0][original_label]):
|
||||
return 0
|
||||
return perturbed_output[0][original_label]
|
||||
else:
|
||||
score = diff(perturbed_label, original_label) # Intentionall inverted order of arguments.
|
||||
return score
|
||||
|
||||
else:
|
||||
raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user