diff --git a/demo/tax_calculator.py b/demo/tax_calculator.py index 020ab25de2..866cef9f92 100644 --- a/demo/tax_calculator.py +++ b/demo/tax_calculator.py @@ -37,7 +37,7 @@ iface = gr.Interface( ) ], "number", - interpretation="default", + # interpretation="default", # Removed interpretation for dataframes examples=[ [10000, "Married", [["Car", 5000, False], ["Laptop", 800, True]]], [80000, "Single", [["Suit", 800, True], ["Watch", 1800, False]]], diff --git a/gradio.egg-info/PKG-INFO b/gradio.egg-info/PKG-INFO index 89f823f6d6..70598f2a22 100644 --- a/gradio.egg-info/PKG-INFO +++ b/gradio.egg-info/PKG-INFO @@ -1,4 +1,4 @@ -Metadata-Version: 1.0 +Metadata-Version: 2.1 Name: gradio Version: 2.1.1 Summary: Python library for easily interacting with trained machine learning models @@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI Author: Abubakar Abid Author-email: a12d@stanford.edu License: Apache License 2.0 -Description: UNKNOWN Keywords: machine learning,visualization,reproducibility Platform: UNKNOWN +License-File: LICENSE + +UNKNOWN + diff --git a/gradio.egg-info/SOURCES.txt b/gradio.egg-info/SOURCES.txt index 0ca8ad0f2d..7dfee2f1cf 100644 --- a/gradio.egg-info/SOURCES.txt +++ b/gradio.egg-info/SOURCES.txt @@ -1,3 +1,4 @@ +LICENSE MANIFEST.in README.md setup.py @@ -26,32 +27,6 @@ gradio.egg-info/requires.txt gradio.egg-info/top_level.txt gradio/frontend/asset-manifest.json gradio/frontend/index.html -gradio/frontend/static/bundle.css -gradio/frontend/static/bundle.css.map -gradio/frontend/static/bundle.js -gradio/frontend/static/bundle.js.LICENSE.txt -gradio/frontend/static/bundle.js.map -gradio/frontend/static/css/main.0094f11b.css -gradio/frontend/static/css/main.0094f11b.css.map -gradio/frontend/static/css/main.20be28ac.css -gradio/frontend/static/css/main.20be28ac.css.map -gradio/frontend/static/css/main.2b64a968.css -gradio/frontend/static/css/main.2b64a968.css.map -gradio/frontend/static/css/main.380e3222.css -gradio/frontend/static/css/main.380e3222.css.map -gradio/frontend/static/css/main.4aea80f8.css -gradio/frontend/static/css/main.4aea80f8.css.map -gradio/frontend/static/css/main.4f157d97.css -gradio/frontend/static/css/main.4f157d97.css.map -gradio/frontend/static/css/main.5c663906.css -gradio/frontend/static/css/main.5c663906.css.map -gradio/frontend/static/css/main.99922310.css -gradio/frontend/static/css/main.99922310.css.map -gradio/frontend/static/css/main.acb02c85.css -gradio/frontend/static/css/main.acb02c85.css.map -gradio/frontend/static/css/main.cbbf8898.css -gradio/frontend/static/css/main.cbbf8898.css.map -gradio/frontend/static/media/logo_loading.e93acd82.jpg test/test_demos.py test/test_inputs.py test/test_interfaces.py diff --git a/gradio/inputs.py b/gradio/inputs.py index 8c7c1b112e..cd800d031d 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -29,7 +29,7 @@ class InputComponent(Component): Input Component. All input components subclass this. """ def __init__(self, label, requires_permissions=False): - self.interpret() + self.set_interpret_parameters() super().__init__(label, requires_permissions) def preprocess(self, x): @@ -44,7 +44,7 @@ class InputComponent(Component): """ return x - def interpret(self): + def set_interpret_parameters(self): ''' Set any parameters for interpretation. ''' @@ -115,6 +115,7 @@ class Textbox(InputComponent): }[type] else: self.test_input = default + self.interpret_by_tokens = True super().__init__(label) def get_template_context(self): @@ -147,7 +148,7 @@ class Textbox(InputComponent): """ return x - def interpret(self, separator=" ", replacement=None): + def set_interpret_parameters(self, separator=" ", replacement=None): """ Calculates interpretation score of characters in input by splitting input into tokens, then using a "leave one out" method to calculate the score of each token by removing each token and measuring the delta of the output value. Parameters: @@ -157,8 +158,11 @@ class Textbox(InputComponent): self.interpretation_separator = separator self.interpretation_replacement = replacement return self - - def get_interpretation_neighbors(self, x): + + def tokenize(self, x): + """ + Tokenizes an input string by dividing into "words" delimited by self.interpretation_separator + """ tokens = x.split(self.interpretation_separator) leave_one_out_strings = [] for index in range(len(tokens)): @@ -168,9 +172,19 @@ class Textbox(InputComponent): else: leave_one_out_set[index] = self.interpretation_replacement leave_one_out_strings.append(self.interpretation_separator.join(leave_one_out_set)) - return leave_one_out_strings, {"tokens": tokens}, True + return tokens, leave_one_out_strings, None + + def get_masked_inputs(self, tokens, binary_mask_matrix): + """ + Constructs partially-masked sentences for SHAP interpretation + """ + masked_inputs = [] + for binary_mask_vector in binary_mask_matrix: + masked_input = np.array(tokens)[np.array(binary_mask_vector, dtype=bool)] + masked_inputs.append(self.interpretation_separator.join(masked_input)) + return masked_inputs - def get_interpretation_scores(self, x, neighbors, scores, tokens): + def get_interpretation_scores(self, x, neighbors, scores, tokens, masks=None): """ Returns: (List[Tuple[str, float]]): Each tuple set represents a set of characters and their corresponding interpretation score. @@ -216,6 +230,7 @@ class Number(InputComponent): ''' self.default = default self.test_input = default if default is not None else 1 + self.interpret_by_tokens = False super().__init__(label) def get_template_context(self): @@ -244,7 +259,7 @@ class Number(InputComponent): """ return x - def interpret(self, steps=3, delta=1, delta_type="percent"): + def set_interpret_parameters(self, steps=3, delta=1, delta_type="percent"): """ Calculates interpretation scores of numeric values close to the input number. Parameters: @@ -266,7 +281,7 @@ class Number(InputComponent): delta = self.interpretation_delta negatives = (x + np.arange(-self.interpretation_steps, 0) * delta).tolist() positives = (x + np.arange(1, self.interpretation_steps+1) * delta).tolist() - return negatives + positives, {}, False + return negatives + positives, {} def get_interpretation_scores(self, x, neighbors, scores): """ @@ -305,6 +320,7 @@ class Slider(InputComponent): self.step = step self.default = minimum if default is None else default self.test_input = self.default + self.interpret_by_tokens = False super().__init__(label) def get_template_context(self): @@ -329,7 +345,7 @@ class Slider(InputComponent): """ return x - def interpret(self, steps=8): + def set_interpret_parameters(self, steps=8): """ Calculates interpretation scores of numeric values ranging between the minimum and maximum values of the slider. Parameters: @@ -339,7 +355,7 @@ class Slider(InputComponent): return self def get_interpretation_neighbors(self, x): - return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {}, False + return np.linspace(self.minimum, self.maximum, self.interpretation_steps).tolist(), {} def get_interpretation_scores(self, x, neighbors, scores): """ @@ -367,6 +383,7 @@ class Checkbox(InputComponent): """ self.test_input = True self.default = default + self.interpret_by_tokens = False super().__init__(label) def get_template_context(self): @@ -388,14 +405,14 @@ class Checkbox(InputComponent): """ return x - def interpret(self): + def set_interpret_parameters(self): """ Calculates interpretation score of the input by comparing the output against the output when the input is the inverse boolean value of x. """ return self def get_interpretation_neighbors(self, x): - return [not x], {}, False + return [not x], {} def get_interpretation_scores(self, x, neighbors, scores): """ @@ -430,6 +447,7 @@ class CheckboxGroup(InputComponent): self.default = default self.type = type self.test_input = self.choices + self.interpret_by_tokens = False super().__init__(label) def get_template_context(self): @@ -447,7 +465,7 @@ class CheckboxGroup(InputComponent): else: raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.") - def interpret(self): + def set_interpret_parameters(self): """ Calculates interpretation score of each choice in the input by comparing the output against the outputs when each choice in the input is independently either removed or added. """ @@ -462,7 +480,7 @@ class CheckboxGroup(InputComponent): else: leave_one_out_set.append(choice) leave_one_out_sets.append(leave_one_out_set) - return leave_one_out_sets, {}, False + return leave_one_out_sets, {} def get_interpretation_scores(self, x, neighbors, scores): """ @@ -514,6 +532,7 @@ class Radio(InputComponent): self.type = type self.test_input = self.choices[0] self.default = default if default is not None else self.choices[0] + self.interpret_by_tokens = False super().__init__(label) def get_template_context(self): @@ -531,7 +550,7 @@ class Radio(InputComponent): else: raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.") - def interpret(self): + def set_interpret_parameters(self): """ Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected. """ @@ -540,7 +559,7 @@ class Radio(InputComponent): def get_interpretation_neighbors(self, x): choices = list(self.choices) choices.remove(x) - return choices, {}, False + return choices, {} def get_interpretation_scores(self, x, neighbors, scores): """ @@ -577,6 +596,7 @@ class Dropdown(InputComponent): self.type = type self.test_input = self.choices[0] self.default = default if default is not None else self.choices[0] + self.interpret_by_tokens = False super().__init__(label) def get_template_context(self): @@ -594,7 +614,7 @@ class Dropdown(InputComponent): else: raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.") - def interpret(self): + def set_interpret_parameters(self): """ Calculates interpretation score of each choice by comparing the output against each of the outputs when alternative choices are selected. """ @@ -603,7 +623,7 @@ class Dropdown(InputComponent): def get_interpretation_neighbors(self, x): choices = list(self.choices) choices.remove(x) - return choices, {}, False + return choices, {} def get_interpretation_scores(self, x, neighbors, scores): """ @@ -647,6 +667,7 @@ class Image(InputComponent): self.type = type self.invert_colors = invert_colors self.test_input = test_data.BASE64_IMAGE + self.interpret_by_tokens = True super().__init__(label, requires_permissions) @classmethod @@ -690,7 +711,7 @@ class Image(InputComponent): def preprocess_example(self, x): return processing_utils.encode_file_to_base64(x) - def interpret(self, segments=16): + def set_interpret_parameters(self, segments=16): """ Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value. Parameters: @@ -699,29 +720,60 @@ class Image(InputComponent): self.interpretation_segments = segments return self - def get_interpretation_neighbors(self, x): + def _segment_by_slic(self, x): + """ + Helper method that segments an image into superpixels using slic. + Parameters: + x: base64 representation of an image + """ x = processing_utils.decode_base64_to_image(x) if self.shape is not None: x = processing_utils.resize_and_crop(x, self.shape) - image = np.array(x) + resized_and_cropped_image = np.array(x) try: from skimage.segmentation import slic - except ImportError: - print("Error: running default interpretation for images requires scikit-image, please install it first.") - return - segments_slic = slic(image, self.interpretation_segments, compactness=10, sigma=1) - leave_one_out_tokens, masks = [], [] - replace_color = np.mean(image, axis=(0, 1)) - for (i, segVal) in enumerate(np.unique(segments_slic)): - mask = segments_slic == segVal - white_screen = np.copy(image) - white_screen[segments_slic == segVal] = replace_color - leave_one_out_tokens.append( - processing_utils.encode_array_to_base64(white_screen)) - masks.append(mask) - return leave_one_out_tokens, {"masks": masks}, True + except (ImportError, ModuleNotFoundError): + raise ValueError("Error: running this interpretation for images requires scikit-image, please install it first.") + segments_slic = slic( + resized_and_cropped_image, self.interpretation_segments, compactness=10, + sigma=1, start_label=0) + return segments_slic, resized_and_cropped_image - def get_interpretation_scores(self, x, neighbors, scores, masks): + def tokenize(self, x): + """ + Segments image into tokens, masks, and leave-one-out-tokens + Parameters: + x: base64 representation of an image + Returns: + tokens: list of tokens, used by the get_masked_input() method + leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method + masks: list of masks, used by the get_interpretation_neighbors() method + """ + segments_slic, resized_and_cropped_image = self._segment_by_slic(x) + tokens, masks, leave_one_out_tokens = [], [], [] + replace_color = np.mean(resized_and_cropped_image, axis=(0, 1)) + for (i, segment_value) in enumerate(np.unique(segments_slic)): + mask = (segments_slic == segment_value) + image_screen = np.copy(resized_and_cropped_image) + image_screen[segments_slic == segment_value] = replace_color + leave_one_out_tokens.append( + processing_utils.encode_array_to_base64(image_screen)) + token = np.copy(resized_and_cropped_image) + token[segments_slic != segment_value] = 0 + tokens.append(token) + masks.append(mask) + return tokens, leave_one_out_tokens, masks + + def get_masked_inputs(self, tokens, binary_mask_matrix): + masked_inputs = [] + for binary_mask_vector in binary_mask_matrix: + masked_input = np.zeros_like(tokens[0], dtype=int) + for token, b in zip(tokens, binary_mask_vector): + masked_input = masked_input + token*int(b) + masked_inputs.append(processing_utils.encode_array_to_base64(masked_input)) + return masked_inputs + + def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None): """ Returns: (List[List[float]]): A 2D array representing the interpretation score of each pixel of the image. @@ -825,6 +877,7 @@ class Audio(InputComponent): requires_permissions = source == "microphone" self.type = type self.test_input = test_data.BASE64_AUDIO + self.interpret_by_tokens = True super().__init__(label, requires_permissions) def get_template_context(self): @@ -855,7 +908,7 @@ class Audio(InputComponent): def preprocess_example(self, x): return processing_utils.encode_file_to_base64(x, type="audio") - def interpret(self, segments=8): + def set_interpret_parameters(self, segments=8): """ Calculates interpretation score of audio subsections by splitting the audio into subsections, then using a "leave one out" method to calculate the score of each subsection by removing the subsection and measuring the delta of the output value. Parameters: @@ -864,30 +917,66 @@ class Audio(InputComponent): self.interpretation_segments = segments return self - def get_interpretation_neighbors(self, x): + def tokenize(self, x): file_obj = processing_utils.decode_base64_to_file(x) x = scipy.io.wavfile.read(file_obj.name) sample_rate, data = x leave_one_out_sets = [] + tokens = [] + masks = [] duration = data.shape[0] boundaries = np.linspace(0, duration, self.interpretation_segments + 1).tolist() boundaries = [round(boundary) for boundary in boundaries] for index in range(len(boundaries) - 1): - leave_one_out_data = np.copy(data) start, stop = boundaries[index], boundaries[index + 1] + masks.append((start, stop)) + # Handle the leave one outs + leave_one_out_data = np.copy(data) leave_one_out_data[start:stop] = 0 file = tempfile.NamedTemporaryFile(delete=False) scipy.io.wavfile.write(file, sample_rate, leave_one_out_data) out_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav") leave_one_out_sets.append(out_data) - return leave_one_out_sets, {}, True + # Handle the tokens + token = np.copy(data) + token[0:start] = 0 + token[stop:] = 0 + file = tempfile.NamedTemporaryFile(delete=False) + scipy.io.wavfile.write(file, sample_rate, token) + token_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav") + tokens.append(token_data) + return tokens, leave_one_out_sets, masks - def get_interpretation_scores(self, x, neighbors, scores): + def get_masked_inputs(self, tokens, binary_mask_matrix): + # create a "zero input" vector and get sample rate + x = tokens[0] + file_obj = processing_utils.decode_base64_to_file(x) + sample_rate, data = scipy.io.wavfile.read(file_obj.name) + zero_input = np.zeros_like(data, dtype=int) + # decode all of the tokens + token_data = [] + for token in tokens: + file_obj = processing_utils.decode_base64_to_file(token) + _, data = scipy.io.wavfile.read(file_obj.name) + token_data.append(data) + # construct the masked version + masked_inputs = [] + for binary_mask_vector in binary_mask_matrix: + masked_input = np.copy(zero_input) + for t, b in zip(token_data, binary_mask_vector): + masked_input = masked_input + t*int(b) + file = tempfile.NamedTemporaryFile(delete=False) + scipy.io.wavfile.write(file, sample_rate, masked_input) + masked_data = processing_utils.encode_file_to_base64(file.name, type="audio", ext="wav") + masked_inputs.append(masked_data) + return masked_inputs + + def get_interpretation_scores(self, x, neighbors, scores, masks=None, tokens=None): """ Returns: (List[float]): Each value represents the interpretation score corresponding to an evenly spaced subsection of audio. """ - return scores + return list(scores) def embed(self, x): """ @@ -1049,35 +1138,35 @@ class Dataframe(InputComponent): else: raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'pandas', 'numpy', 'array'.") - def interpret(self): - """ - Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value. - """ - return self + # def set_interpret_parameters(self): + # """ + # Calculates interpretation score of each cell in the Dataframe by using a "leave one out" method to calculate the score of each cell by removing the cell and measuring the delta of the output value. + # """ + # return self - def get_interpretation_neighbors(self, x): - x = pd.DataFrame(x) - leave_one_out_sets = [] - shape = x.shape - for i in range(shape[0]): - for j in range(shape[1]): - scalar = x.iloc[i, j] - leave_one_out_df = x.copy() - if is_bool_dtype(scalar): - leave_one_out_df.iloc[i, j] = not scalar - elif is_numeric_dtype(scalar): - leave_one_out_df.iloc[i, j] = 0 - elif is_string_dtype(scalar): - leave_one_out_df.iloc[i, j] = "" - leave_one_out_sets.append(leave_one_out_df.values.tolist()) - return leave_one_out_sets, {"shape": x.shape}, True + # def get_interpretation_neighbors(self, x): + # x = pd.DataFrame(x) + # leave_one_out_sets = [] + # shape = x.shape + # for i in range(shape[0]): + # for j in range(shape[1]): + # scalar = x.iloc[i, j] + # leave_one_out_df = x.copy() + # if is_bool_dtype(scalar): + # leave_one_out_df.iloc[i, j] = not scalar + # elif is_numeric_dtype(scalar): + # leave_one_out_df.iloc[i, j] = 0 + # elif is_string_dtype(scalar): + # leave_one_out_df.iloc[i, j] = "" + # leave_one_out_sets.append(leave_one_out_df.values.tolist()) + # return leave_one_out_sets, {"shape": x.shape} - def get_interpretation_scores(self, x, neighbors, scores, shape): - """ - Returns: - (List[List[float]]): A 2D array where each value corrseponds to the interpretation score of each cell. - """ - return np.array(scores).reshape((shape)).tolist() + # def get_interpretation_scores(self, x, neighbors, scores, shape): + # """ + # Returns: + # (List[List[float]]): A 2D array where each value corrseponds to the interpretation score of each cell. + # """ + # return np.array(scores).reshape((shape)).tolist() def embed(self, x): raise NotImplementedError("DataFrame doesn't currently support embeddings") diff --git a/gradio/interface.py b/gradio/interface.py index f7d3353bb4..9772374f38 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -7,7 +7,7 @@ import gradio from gradio.inputs import InputComponent, get_input_instance from gradio.outputs import OutputComponent, get_output_instance from gradio import networking, strings, utils -from gradio.interpretation import quantify_difference_in_label +from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value from gradio.external import load_interface from gradio import encryptor import pkg_resources @@ -66,8 +66,8 @@ class Interface: def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None, examples_per_page=10, live=False, layout="horizontal", show_input=True, show_output=True, - capture_session=False, interpretation=None, theme=None, repeat_outputs_per_model=True, - title=None, description=None, article=None, thumbnail=None, + capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True, + title=None, description=None, article=None, thumbnail=None, css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900, allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False, show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True): @@ -84,6 +84,7 @@ class Interface: layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically. capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x) interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter. + num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. title (str): a title for the interface; if provided, appears above the input and output components. description (str): a description for the interface; if provided, appears above the input and output components. article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content. @@ -145,6 +146,7 @@ class Interface: self.examples = examples else: raise ValueError("Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.") + self.num_shap = num_shap self.examples_per_page = examples_per_page self.server_port = server_port self.simple_server = None @@ -334,7 +336,7 @@ class Interface: interpretation for a certain set of UI component types, as well as the custom interpretation case. :param raw_input: a list of raw inputs to apply the interpretation(s) on. """ - if self.interpretation == "default": + if self.interpretation.lower() == "default": processed_input = [input_component.preprocess(raw_input[i]) for i, input_component in enumerate(self.input_components)] original_output = self.run_prediction(processed_input) @@ -342,26 +344,75 @@ class Interface: for i, x in enumerate(raw_input): input_component = self.input_components[i] neighbor_raw_input = list(raw_input) - neighbor_values, interpret_kwargs, interpret_by_removal = input_component.get_interpretation_neighbors(x) - interface_scores = [] - alternative_output = [] - for neighbor_input in neighbor_values: - neighbor_raw_input[i] = neighbor_input - processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i]) - for i, input_component in enumerate(self.input_components)] - neighbor_output = self.run_prediction(processed_neighbor_input) - processed_neighbor_output = [output_component.postprocess( - neighbor_output[i]) for i, output_component in enumerate(self.output_components)] + if input_component.interpret_by_tokens: + tokens, neighbor_values, masks = input_component.tokenize(x) + interface_scores = [] + alternative_output = [] + for neighbor_input in neighbor_values: + neighbor_raw_input[i] = neighbor_input + processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i]) + for i, input_component in enumerate(self.input_components)] + neighbor_output = self.run_prediction(processed_neighbor_input) + processed_neighbor_output = [output_component.postprocess( + neighbor_output[i]) for i, output_component in enumerate(self.output_components)] - alternative_output.append(processed_neighbor_output) - interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output)) - alternative_outputs.append(alternative_output) - if not interpret_by_removal: + alternative_output.append(processed_neighbor_output) + interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output)) + alternative_outputs.append(alternative_output) + scores.append( + input_component.get_interpretation_scores( + raw_input[i], neighbor_values, interface_scores, masks=masks, tokens=tokens)) + else: + neighbor_values, interpret_kwargs = input_component.get_interpretation_neighbors(x) + interface_scores = [] + alternative_output = [] + for neighbor_input in neighbor_values: + neighbor_raw_input[i] = neighbor_input + processed_neighbor_input = [input_component.preprocess(neighbor_raw_input[i]) + for i, input_component in enumerate(self.input_components)] + neighbor_output = self.run_prediction(processed_neighbor_input) + processed_neighbor_output = [output_component.postprocess( + neighbor_output[i]) for i, output_component in enumerate(self.output_components)] + + alternative_output.append(processed_neighbor_output) + interface_scores.append(quantify_difference_in_label(self, original_output, neighbor_output)) + alternative_outputs.append(alternative_output) interface_scores = [-score for score in interface_scores] - scores.append( - input_component.get_interpretation_scores( - raw_input[i], neighbor_values, interface_scores, **interpret_kwargs)) - return scores, alternative_outputs + scores.append( + input_component.get_interpretation_scores( + raw_input[i], neighbor_values, interface_scores, **interpret_kwargs)) + return scores, alternative_outputs + elif self.interpretation.lower() == "shap": + scores = [] + try: + import shap + except (ImportError, ModuleNotFoundError): + raise ValueError("The package `shap` is required for this interpretation method. Try: `pip install shap`") + + processed_input = [input_component.preprocess(raw_input[i]) + for i, input_component in enumerate(self.input_components)] + original_output = self.run_prediction(processed_input) + + for i, x in enumerate(raw_input): # iterate over reach interface + input_component = self.input_components[i] + tokens, _, masks = input_component.tokenize(x) + + def get_masked_prediction(binary_mask): # construct a masked version of the input + masked_xs = input_component.get_masked_inputs(tokens, binary_mask) + preds = [] + for masked_x in masked_xs: + processed_masked_input = copy.deepcopy(processed_input) + processed_masked_input[i] = input_component.preprocess(masked_x) + new_output = self.run_prediction(processed_masked_input) + pred = get_regression_or_classification_value(self, original_output, new_output) + preds.append(pred) + return np.array(preds) + + num_total_segments = len(tokens) + explainer = shap.KernelExplainer(get_masked_prediction, np.zeros((1, num_total_segments))) + shap_values = explainer.shap_values(np.ones((1, num_total_segments)), nsamples=int(self.num_shap*num_total_segments), silent=True) + scores.append(input_component.get_interpretation_scores(raw_input[i], None, shap_values[0], masks=masks, tokens=tokens)) + return scores, [] else: processed_input = [input_component.preprocess(raw_input[i]) for i, input_component in enumerate(self.input_components)] diff --git a/gradio/interpretation.py b/gradio/interpretation.py index 17e6cbdf95..fab0a88ed8 100644 --- a/gradio/interpretation.py +++ b/gradio/interpretation.py @@ -1,4 +1,5 @@ from gradio.outputs import Label, Textbox +import math def diff(original, perturbed): try: # try computing numerical difference @@ -24,7 +25,33 @@ def quantify_difference_in_label(interface, original_output, perturbed_output): else: score = diff(original_label, perturbed_label) return score + elif type(output_component) == Textbox: score = diff(post_original_output, post_perturbed_output) return score + + else: + raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component)) + +def get_regression_or_classification_value(interface, original_output, perturbed_output): + """Used to combine regression/classification for Shap interpretation method.""" + output_component = interface.output_components[0] + post_original_output = output_component.postprocess(original_output[0]) + post_perturbed_output = output_component.postprocess(perturbed_output[0]) + + if type(output_component) == Label: + original_label = post_original_output["label"] + perturbed_label = post_perturbed_output["label"] + + # Handle different return types of Label interface + if "confidences" in post_original_output: + if math.isnan(perturbed_output[0][original_label]): + return 0 + return perturbed_output[0][original_label] + else: + score = diff(perturbed_label, original_label) # Intentionall inverted order of arguments. + return score + + else: + raise ValueError("This interpretation method doesn't support the Output component: {}".format(output_component))