added support for different interpretation methods for different inputs, or even optional interpretation; fixed interpretation tests

Abubakar Abid 2021-07-14 17:37:45 -05:00
parent 013c686202
commit 1360f423ae


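A hedged usage sketch of what this change enables: the `interpretation` argument of `Interface` can now be a per-input list, with `None` disabling interpretation for a given input. The components and toy function below are illustrative, not taken from the diff (gradio 2.x-style `gr.inputs` names assumed):

import gradio as gr

def classify(text, cutoff):
    # toy scorer, for illustration only
    score = min(1.0, len(text) / 20)
    return {"positive": score, "negative": 1 - score}

iface = gr.Interface(
    fn=classify,
    inputs=[gr.inputs.Textbox(), gr.inputs.Slider(minimum=0, maximum=1)],
    outputs="label",
    # "shap" for the textbox (requires `pip install shap`), no interpretation for the slider
    interpretation=["shap", None],
)
# iface.launch()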
@@ -113,6 +113,14 @@ class Interface:
if repeat_outputs_per_model:
self.output_components *= len(fn)
if interpretation is None or isinstance(interpretation, list) or callable(interpretation):
self.interpretation = interpretation
elif isinstance(interpretation, str):
self.interpretation = [interpretation.lower() for _ in self.input_components]
else:
raise ValueError("Invalid value for parameter: interpretation")
self.predict = fn
self.function_names = [func.__name__ for func in fn]
self.__name__ = ", ".join(self.function_names)
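A minimal sketch of the string-broadcast branch added above, with stand-in strings in place of real input components:

input_components = ["textbox", "image", "slider"]  # stand-ins for self.input_components
interpretation = "Shap"
per_input = [interpretation.lower() for _ in input_components]
print(per_input)  # ['shap', 'shap', 'shap']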
@@ -124,7 +132,6 @@ class Interface:
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.interpretation = interpretation
self.session = None
self.server_name = server_name
self.title = title
@@ -336,84 +343,87 @@ class Interface:
interpretation for a certain set of UI component types, as well as the custom interpretation case.
:param raw_input: a list of raw inputs to apply the interpretation(s) on.
"""
if self.interpretation.lower() == "default":
if isinstance(self.interpretation, list):  # per-input methods: "default", "shap", or None
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]
original_output = self.run_prediction(processed_input)
scores, alternative_outputs = [], []
for i, x in enumerate(raw_input):
input_component = self.input_components[i]
neighbor_raw_input = list(raw_input)
if input_component.interpret_by_tokens:
tokens, neighbor_values, masks = input_component.tokenize(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(self.input_components)]
neighbor_output = self.run_prediction(processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
for i, (x, interp) in enumerate(zip(raw_input, self.interpretation)):
if interp == "default":
input_component = self.input_components[i]
neighbor_raw_input = list(raw_input)
if input_component.interpret_by_tokens:
tokens, neighbor_values, masks = input_component.tokenize(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(self.input_components)]
neighbor_output = self.run_prediction(processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
alternative_output.append(processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
alternative_outputs.append(alternative_output)
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
alternative_output.append(processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
alternative_outputs.append(alternative_output)
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
else:
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(self.input_components)]
neighbor_output = self.run_prediction(processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
alternative_output.append(processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
alternative_outputs.append(alternative_output)
interface_scores = [-score for score in interface_scores]
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
elif interp == "shap":
try:
import shap
except (ImportError, ModuleNotFoundError):
raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
input_component = self.input_components[i]
if not input_component.interpret_by_tokens:
raise ValueError("Input component {} does not support `shap` interpretation".format(input_component))
tokens, _, masks = input_component.tokenize(x)
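# KernelExplainer scores each token/segment: predictions are recomputed on masked
# variants of this input and the change in the output value is attributed per segment.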
def get_masked_prediction(binary_mask): # construct a masked version of the input
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
preds = []
for masked_x in masked_xs:
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = self.run_prediction(processed_masked_input)
pred = get_regression_or_classification_value(self, original_output, new_output)
preds.append(pred)
return np.array(preds)
num_total_segments = len(tokens)
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
alternative_outputs.append([])
elif interp is None:
scores.append(None)
alternative_outputs.append([])
else:
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
interface_scores = []
alternative_output = []
for neighbor_input in neighbor_values:
neighbor_raw_input[i] = neighbor_input
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
for i, input_component in enumerate(self.input_components)]
neighbor_output = self.run_prediction(processed_neighbor_input)
processed_neighbor_output = [output_component.postprocess(
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
alternative_output.append(processed_neighbor_output)
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
alternative_outputs.append(alternative_output)
interface_scores = [-score for score in interface_scores]
scores.append(
input_component.get_interpretation_scores(
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
return scores, alternative_outputs
elif self.interpretation.lower() == "shap":
scores = []
try:
import shap
except (ImportError, ModuleNotFoundError):
raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]
original_output = self.run_prediction(processed_input)
for i, x in enumerate(raw_input): # iterate over each input component
input_component = self.input_components[i]
tokens, _, masks = input_component.tokenize(x)
def get_masked_prediction(binary_mask): # construct a masked version of the input
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
preds = []
for masked_x in masked_xs:
processed_masked_input = copy.deepcopy(processed_input)
processed_masked_input[i] = input_component.preprocess(masked_x)
new_output = self.run_prediction(processed_masked_input)
pred = get_regression_or_classification_value(self, original_output, new_output)
preds.append(pred)
return np.array(preds)
num_total_segments = len(tokens)
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
return scores, []
else:
raise ValueError("Uknown intepretation method: {}".format(interp))
return scores, alternative_outputs
else: # custom interpretation function
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]
interpreter = self.interpretation