mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-13 11:57:29 +08:00
added support for different interpretation methods for different inputs, or even optional interpretation; fixed interpretaiton tests
This commit is contained in:
parent
013c686202
commit
1360f423ae
@ -113,6 +113,14 @@ class Interface:
|
||||
|
||||
if repeat_outputs_per_model:
|
||||
self.output_components *= len(fn)
|
||||
|
||||
if interpretation is None or isinstance(interpretation, list) or callable(interpretation):
|
||||
self.interpretation = interpretation
|
||||
elif isinstance(interpretation, str):
|
||||
self.interpretation = [interpretation.lower() for _ in self.input_components]
|
||||
else:
|
||||
raise ValueError("Invalid value for parameter: interpretation")
|
||||
|
||||
self.predict = fn
|
||||
self.function_names = [func.__name__ for func in fn]
|
||||
self.__name__ = ", ".join(self.function_names)
|
||||
@ -124,7 +132,6 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.interpretation = interpretation
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
@ -336,84 +343,87 @@ class Interface:
|
||||
interpretation for a certain set of UI component types, as well as the custom interpretation case.
|
||||
:param raw_input: a list of raw inputs to apply the interpretation(s) on.
|
||||
"""
|
||||
if self.interpretation.lower() == "default":
|
||||
if isinstance(self.interpretation, list): # Either "default" or "shap"
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
original_output = self.run_prediction(processed_input)
|
||||
scores, alternative_outputs = [], []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_component = self.input_components[i]
|
||||
neighbor_raw_input = list(raw_input)
|
||||
if input_component.interpret_by_tokens:
|
||||
tokens, neighbor_values, masks = input_component.tokenize(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
neighbor_output = self.run_prediction(processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
|
||||
for i, (x, interp) in enumerate(zip(raw_input, self.interpretation)):
|
||||
print(i, interp)
|
||||
if interp=="default":
|
||||
input_component = self.input_components[i]
|
||||
neighbor_raw_input = list(raw_input)
|
||||
if input_component.interpret_by_tokens:
|
||||
tokens, neighbor_values, masks = input_component.tokenize(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
neighbor_output = self.run_prediction(processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
|
||||
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
|
||||
alternative_outputs.append(alternative_output)
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
|
||||
alternative_outputs.append(alternative_output)
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens))
|
||||
else:
|
||||
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
neighbor_output = self.run_prediction(processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
|
||||
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
|
||||
alternative_outputs.append(alternative_output)
|
||||
interface_scores = [-score for score in interface_scores]
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
|
||||
elif interp == "shap":
|
||||
try:
|
||||
import shap
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
|
||||
input_component = self.input_components[i]
|
||||
if not(input_component.interpret_by_tokens):
|
||||
raise ValueError("Input component {} does not support `shap` interpretation".format(input_component))
|
||||
|
||||
tokens, _, masks = input_component.tokenize(x)
|
||||
|
||||
def get_masked_prediction(binary_mask): # construct a masked version of the input
|
||||
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
|
||||
preds = []
|
||||
for masked_x in masked_xs:
|
||||
processed_masked_input = copy.deepcopy(processed_input)
|
||||
processed_masked_input[i] = input_component.preprocess(masked_x)
|
||||
new_output = self.run_prediction(processed_masked_input)
|
||||
pred = get_regression_or_classification_value(self, original_output, new_output)
|
||||
preds.append(pred)
|
||||
return np.array(preds)
|
||||
|
||||
num_total_segments = len(tokens)
|
||||
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
|
||||
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
|
||||
scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
|
||||
alternative_outputs.append([])
|
||||
elif interp is None:
|
||||
scores.append(None)
|
||||
alternative_outputs.append([])
|
||||
else:
|
||||
neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x)
|
||||
interface_scores = []
|
||||
alternative_output = []
|
||||
for neighbor_input in neighbor_values:
|
||||
neighbor_raw_input[i] = neighbor_input
|
||||
processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
neighbor_output = self.run_prediction(processed_neighbor_input)
|
||||
processed_neighbor_output = [output_component.postprocess(
|
||||
neighbor_output[i]) for i, output_component in enumerate(self.output_components)]
|
||||
|
||||
alternative_output.append(processed_neighbor_output)
|
||||
interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output))
|
||||
alternative_outputs.append(alternative_output)
|
||||
interface_scores = [-score for score in interface_scores]
|
||||
scores.append(
|
||||
input_component.get_interpretation_scores(
|
||||
raw_input[i], neighbor_values, interface_scores, **interpret_kwargs))
|
||||
return scores, alternative_outputs
|
||||
elif self.interpretation.lower() == "shap":
|
||||
scores = []
|
||||
try:
|
||||
import shap
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`")
|
||||
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
original_output = self.run_prediction(processed_input)
|
||||
|
||||
for i, x in enumerate(raw_input): # iterate over reach interface
|
||||
input_component = self.input_components[i]
|
||||
tokens, _, masks = input_component.tokenize(x)
|
||||
|
||||
def get_masked_prediction(binary_mask): # construct a masked version of the input
|
||||
masked_xs = input_component.get_masked_inputs(tokens, binary_mask)
|
||||
preds = []
|
||||
for masked_x in masked_xs:
|
||||
processed_masked_input = copy.deepcopy(processed_input)
|
||||
processed_masked_input[i] = input_component.preprocess(masked_x)
|
||||
new_output = self.run_prediction(processed_masked_input)
|
||||
pred = get_regression_or_classification_value(self, original_output, new_output)
|
||||
preds.append(pred)
|
||||
return np.array(preds)
|
||||
|
||||
num_total_segments = len(tokens)
|
||||
explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments)))
|
||||
shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True)
|
||||
scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens))
|
||||
return scores, []
|
||||
else:
|
||||
raise ValueError("Uknown intepretation method: {}".format(interp))
|
||||
return scores, alternative_outputs
|
||||
else: # custom interpretation function
|
||||
processed_input = [input_component.preprocess(raw_input[i])
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
interpreter = self.interpretation
|
||||
|
Loading…
x
Reference in New Issue
Block a user