diff --git a/build/lib/gradio/explain.py b/build/lib/gradio/explain.py new file mode 100644 index 0000000000..68e78d1165 --- /dev/null +++ b/build/lib/gradio/explain.py @@ -0,0 +1,59 @@ +from gradio.inputs import Textbox +from gradio.inputs import Image + +from skimage.color import rgb2gray +from skimage.filters import sobel +from skimage.segmentation import slic +from skimage.util import img_as_float +from skimage import io +import numpy as np + + +def tokenize_text(text): + leave_one_out_tokens = [] + tokens = text.split() + leave_one_out_tokens.append(tokens) + for idx, _ in enumerate(tokens): + new_token_array = tokens.copy() + del new_token_array[idx] + leave_one_out_tokens.append(new_token_array) + return leave_one_out_tokens + +def tokenize_image(image): + img = img_as_float(image[::2, ::2]) + segments_slic = slic(img, n_segments=20, compactness=10, sigma=1) + leave_one_out_tokens = [] + for (i, segVal) in enumerate(np.unique(segments_slic)): + mask = np.copy(img) + mask[segments_slic == segVal] = 255 + leave_one_out_tokens.append(mask) + return leave_one_out_tokens + +def score(outputs): + print(outputs) + +def simple_explanation(interface, input_interfaces, + output_interfaces, input): + if isinstance(input_interfaces[0], Textbox): + leave_one_out_tokens = tokenize_text(input[0]) + outputs = [] + for input_text in leave_one_out_tokens: + input_text = " ".join(input_text) + print("Input Text: ", input_text) + output = interface.process(input_text) + outputs.extend(output) + print("Output: ", output) + score(outputs) + + elif isinstance(input_interfaces[0], Image): + leave_one_out_tokens = tokenize_image(input[0]) + outputs = [] + for input_text in leave_one_out_tokens: + input_text = " ".join(input_text) + print("Input Text: ", input_text) + output = interface.process(input_text) + outputs.extend(output) + print("Output: ", output) + score(outputs) + else: + print("Not valid input type") diff --git a/build/lib/gradio/interface.py b/build/lib/gradio/interface.py index d83d2de2cc..47956d2d3a 100644 --- a/build/lib/gradio/interface.py +++ b/build/lib/gradio/interface.py @@ -7,9 +7,13 @@ import tempfile import webbrowser from gradio.inputs import InputComponent +from gradio.inputs import Image +from gradio.inputs import Textbox from gradio.outputs import OutputComponent -from gradio import networking, strings, utils +from gradio import networking, strings, utils, processing_utils from distutils.version import StrictVersion +from skimage.segmentation import slic +from skimage.util import img_as_float import pkg_resources import requests import random @@ -20,6 +24,7 @@ import sys import weakref import analytics import os +import numpy as np PKG_VERSION_URL = "https://gradio.app/api/pkg-version" analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" @@ -46,7 +51,7 @@ class Interface: def __init__(self, fn, inputs, outputs, verbose=False, examples=None, live=False, show_input=True, show_output=True, - capture_session=False, title=None, description=None, + capture_session=False, explain_by=None, title=None, description=None, thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME, allow_screenshot=True, allow_flagging=True, flagging_dir="flagged", analytics_enabled=True): @@ -110,6 +115,7 @@ class Interface: self.show_output = show_output self.flag_hash = random.getrandbits(32) self.capture_session = capture_session + self.explain_by = explain_by self.session = None self.server_name = server_name self.title = title @@ -177,7 +183,8 @@ class Interface: "description": self.description, "thumbnail": self.thumbnail, "allow_screenshot": self.allow_screenshot, - "allow_flagging": self.allow_flagging + "allow_flagging": self.allow_flagging, + "allow_interpretation": self.explain_by is not None } try: param_names = inspect.getfullargspec(self.predict[0])[0] @@ -190,7 +197,6 @@ class Interface: iface[1]["label"] = ret_name except ValueError: pass - return config def process(self, raw_input, predict_fn=None): @@ -414,6 +420,88 @@ class Interface: return httpd, path_to_local_server, share_url + def tokenize_text(self, text): + leave_one_out_tokens = [] + tokens = text.split() + for idx, _ in enumerate(tokens): + new_token_array = tokens.copy() + del new_token_array[idx] + leave_one_out_tokens.append(new_token_array) + return tokens, leave_one_out_tokens + + def tokenize_image(self, image): + image = self.input_interfaces[0].preprocess(image) + segments_slic = slic(image, n_segments=20, compactness=10, sigma=1) + leave_one_out_tokens = [] + for (i, segVal) in enumerate(np.unique(segments_slic)): + mask = np.copy(image) + mask[segments_slic == segVal] = 255 + leave_one_out_tokens.append(mask) + return leave_one_out_tokens + + def score_text(self, tokens, leave_one_out_tokens, text): + original_label = "" + original_confidence = 0 + tokens = text.split() + + input_text = " ".join(tokens) + output = self.predict[0](input_text) + original_label = max(output, key=output.get) + original_confidence = output[original_label] + + scores = [] + for idx, input_text in enumerate(leave_one_out_tokens): + input_text = " ".join(input_text) + output = self.predict[0](input_text) + scores.append(original_confidence - output[original_label]) + + scores_by_char = [] + for idx, token in enumerate(tokens): + if idx != 0: + scores_by_char.append((" ", 0)) + for char in token: + scores_by_char.append((char, scores[idx])) + return scores_by_char + + def score_image(self, leave_one_out_tokens, image): + original_output = self.process(image) + original_label = original_output[0][0]['confidences'][0][ + 'label'] + original_confidence = original_output[0][0]['confidences'][0][ + 'confidence'] + output_scores = np.full(np.shape(self.input_interfaces[0].preprocess( + image[0])), 255) + for input_image in leave_one_out_tokens: + input_image_base64 = processing_utils.encode_array_to_base64( + input_image) + input_image_arr = [] + input_image_arr.append(input_image_base64) + output = self.process(input_image_arr) + np.set_printoptions(threshold=sys.maxsize) + if output[0][0]['confidences'][0]['label'] == original_label: + input_image[input_image == 255] = (original_confidence - + output[0][0][ + 'confidences'][0][ + 'confidence']) * 100 + mask = (output_scores == 255) + output_scores[mask] = input_image[mask] + return output_scores + + def simple_explanation(self, input): + if isinstance(self.input_interfaces[0], Textbox): + tokens, leave_one_out_tokens = self.tokenize_text(input[0]) + return [self.score_text(tokens, leave_one_out_tokens, input[0])] + elif isinstance(self.input_interfaces[0], Image): + leave_one_out_tokens = self.tokenize_image(input[0]) + return self.score_image(leave_one_out_tokens, input[0]) + else: + print("Not valid input type") + + def explain(self, input): + if self.explain_by == "default": + return self.simple_explanation(input) + else: + return self.explain_by(input) def reset_all(): for io in Interface.get_instances(): diff --git a/build/lib/gradio/networking.py b/build/lib/gradio/networking.py index e768f2252b..02a4402d89 100644 --- a/build/lib/gradio/networking.py +++ b/build/lib/gradio/networking.py @@ -194,6 +194,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n dict(zip(headers, output["inputs"] + output["outputs"])) ) + elif self.path == "/api/interpret/": + self._set_headers() + data_string = self.rfile.read( + int(self.headers["Content-Length"])) + msg = json.loads(data_string) + interpretation = interface.explain(msg["data"]) + self.wfile.write(json.dumps(interpretation).encode()) else: self.send_error(404, 'Path not found: {}'.format(self.path)) diff --git a/build/lib/gradio/static/js/all_io.js b/build/lib/gradio/static/js/all_io.js index d9958b5176..0b7575a5cd 100644 --- a/build/lib/gradio/static/js/all_io.js +++ b/build/lib/gradio/static/js/all_io.js @@ -80,9 +80,24 @@ var io_master_template = { $.ajax({type: "POST", url: "/api/flag/", data: JSON.stringify(post_data), - success: function(output){ - console.log("Flagging successful") - }, + }); + }, + interpret: function() { + var io = this; + this.target.find(".loading").removeClass("invisible"); + this.target.find(".loading_in_progress").show(); + var post_data = { + 'data': this.last_input + } + $.ajax({type: "POST", + url: "/api/interpret/", + data: JSON.stringify(post_data), + success: function(data) { + for (let [idx, interpretation] of data.entries()) { + io.input_interfaces[idx].show_interpretation(interpretation); + } + io.target.find(".loading_in_progress").hide(); + } }); } }; diff --git a/build/lib/gradio/static/js/gradio.js b/build/lib/gradio/static/js/gradio.js index ece85282d6..85c5a87da1 100644 --- a/build/lib/gradio/static/js/gradio.js +++ b/build/lib/gradio/static/js/gradio.js @@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {