From b5b5c03ec6b3d7cc9821f7aec929bffbf62fd552 Mon Sep 17 00:00:00 2001
From: Ali Abid
Date: Wed, 16 Sep 2020 16:43:37 -0700
Subject: [PATCH] changes

---
 build/lib/gradio/interface.py                 | 80 ++++++++++---------
 .../static/css/interfaces/input/image.css     | 10 +--
 build/lib/gradio/static/js/all_io.js          |  1 +
 .../static/js/interfaces/input/image.js       | 16 ++++
 build/lib/gradio/static/js/utils.js           | 37 ++++++---
 demo/image_classifier.py                      | 47 +++++++++++
 demo/sentiment_analysis.py                    |  4 +-
 gradio/interface.py                           | 80 ++++++++++---------
 gradio/static/css/interfaces/input/image.css  | 10 +--
 gradio/static/js/all_io.js                    |  1 +
 gradio/static/js/interfaces/input/image.js    | 16 ++++
 gradio/static/js/utils.js                     | 37 ++++++---
 12 files changed, 234 insertions(+), 105 deletions(-)
 create mode 100644 demo/image_classifier.py

diff --git a/build/lib/gradio/interface.py b/build/lib/gradio/interface.py
index 47956d2d3a..15de825f2c 100644
--- a/build/lib/gradio/interface.py
+++ b/build/lib/gradio/interface.py
@@ -14,6 +14,8 @@ from gradio import networking, strings, utils, processing_utils
 from distutils.version import StrictVersion
 from skimage.segmentation import slic
 from skimage.util import img_as_float
+from gradio import processing_utils
+import PIL
 import pkg_resources
 import requests
 import random
@@ -216,7 +218,7 @@ class Interface:
         durations = []
         for predict_fn in self.predict:
             start = time.time()
-            if self.capture_session and not (self.session is None):
+            if self.capture_session and self.session is not None:
                 graph, sess = self.session
                 with graph.as_default():
                     with sess.as_default():
@@ -430,13 +432,14 @@ class Interface:
         return tokens, leave_one_out_tokens
 
     def tokenize_image(self, image):
-        image = self.input_interfaces[0].preprocess(image)
+        image = np.array(processing_utils.decode_base64_to_image(image))
         segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
         leave_one_out_tokens = []
         for (i, segVal) in enumerate(np.unique(segments_slic)):
-            mask = np.copy(image)
-            mask[segments_slic == segVal] = 255
-            leave_one_out_tokens.append(mask)
+            mask = segments_slic == segVal
+            white_screen = np.copy(image)
+            white_screen[segments_slic == segVal] = 255
+            leave_one_out_tokens.append((mask, white_screen))
         return leave_one_out_tokens
 
     def score_text(self, tokens, leave_one_out_tokens, text):
@@ -445,14 +448,18 @@ class Interface:
         tokens = text.split()
 
         input_text = " ".join(tokens)
-        output = self.predict[0](input_text)
-        original_label = max(output, key=output.get)
+        original_output = self.process([input_text])
+        output = {result["label"] : result["confidence"]
+                  for result in original_output[0][0]['confidences']}
+        original_label = original_output[0][0]["label"]
         original_confidence = output[original_label]
 
         scores = []
         for idx, input_text in enumerate(leave_one_out_tokens):
             input_text = " ".join(input_text)
-            output = self.predict[0](input_text)
+            raw_output = self.process([input_text])
+            output = {result["label"] : result["confidence"]
+                      for result in raw_output[0][0]['confidences']}
             scores.append(original_confidence - output[original_label])
 
         scores_by_char = []
@@ -464,44 +471,45 @@
         return scores_by_char
 
     def score_image(self, leave_one_out_tokens, image):
-        original_output = self.process(image)
-        original_label = original_output[0][0]['confidences'][0][
-            'label']
-        original_confidence = original_output[0][0]['confidences'][0][
-            'confidence']
-        output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
-            image[0])), 255)
-        for input_image in leave_one_out_tokens:
+        original_output = self.process([image])
+        output = {result["label"] : result["confidence"]
+                  for result in original_output[0][0]['confidences']}
+        original_label = original_output[0][0]["label"]
+        original_confidence = output[original_label]
+
+        image_interface = self.input_interfaces[0]
+        shape = processing_utils.decode_base64_to_image(image).size
+        output_scores = np.full((shape[1], shape[0]), 0.0)
+
+        for mask, input_image in leave_one_out_tokens:
             input_image_base64 = processing_utils.encode_array_to_base64(
                 input_image)
-            input_image_arr = []
-            input_image_arr.append(input_image_base64)
-            output = self.process(input_image_arr)
-            np.set_printoptions(threshold=sys.maxsize)
-            if output[0][0]['confidences'][0]['label'] == original_label:
-                input_image[input_image == 255] = (original_confidence -
-                                                   output[0][0][
-                                                       'confidences'][0][
-                                                       'confidence']) * 100
-                mask = (output_scores == 255)
-                output_scores[mask] = input_image[mask]
-        return output_scores
+            raw_output = self.process([input_image_base64])
+            output = {result["label"] : result["confidence"]
+                      for result in raw_output[0][0]['confidences']}
+            score = original_confidence - output[original_label]
+            output_scores += score * mask
+        max_val = np.max(np.abs(output_scores))
+        if max_val > 0:
+            output_scores = output_scores / max_val
+        return output_scores.tolist()
 
-    def simple_explanation(self, input):
+    def simple_explanation(self, x):
         if isinstance(self.input_interfaces[0], Textbox):
-            tokens, leave_one_out_tokens = self.tokenize_text(input[0])
-            return [self.score_text(tokens, leave_one_out_tokens, input[0])]
+            tokens, leave_one_out_tokens = self.tokenize_text(x[0])
+            return [self.score_text(tokens, leave_one_out_tokens, x[0])]
         elif isinstance(self.input_interfaces[0], Image):
-            leave_one_out_tokens = self.tokenize_image(input[0])
-            return self.score_image(leave_one_out_tokens, input[0])
+            leave_one_out_tokens = self.tokenize_image(x[0])
+            return [self.score_image(leave_one_out_tokens, x[0])]
         else:
             print("Not valid input type")
 
-    def explain(self, input):
+    def explain(self, x):
         if self.explain_by == "default":
-            return self.simple_explanation(input)
+            return self.simple_explanation(x)
         else:
-            return self.explain_by(input)
+            preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
+            return self.explain_by(*preprocessed_x)
 
 
 def reset_all():
     for io in Interface.get_instances():
diff --git a/build/lib/gradio/static/css/interfaces/input/image.css b/build/lib/gradio/static/css/interfaces/input/image.css
index e6fab55694..0ac808bcbe 100644
--- a/build/lib/gradio/static/css/interfaces/input/image.css
+++ b/build/lib/gradio/static/css/interfaces/input/image.css
@@ -25,14 +25,10 @@
     flex-direction: column;
     border: none;
     opacity: 1;
+    transition: opacity 0.2s ease;
 }
-.saliency > div {
-    display: flex;
-    flex-grow: 1;
-}
-.saliency > div > div {
-    flex-grow: 1;
-    background-color: #e67e22;
+.saliency:hover {
+    opacity: 0.4;
 }
 .image_preview {
     width: 100%;
diff --git a/build/lib/gradio/static/js/all_io.js b/build/lib/gradio/static/js/all_io.js
index 0b7575a5cd..b228764da2 100644
--- a/build/lib/gradio/static/js/all_io.js
+++ b/build/lib/gradio/static/js/all_io.js
@@ -94,6 +94,7 @@ var io_master_template = {
       data: JSON.stringify(post_data),
       success: function(data) {
         for (let [idx, interpretation] of data.entries()) {
+          console.log(idx)
           io.input_interfaces[idx].show_interpretation(interpretation);
         }
         io.target.find(".loading_in_progress").hide();
diff --git a/build/lib/gradio/static/js/interfaces/input/image.js b/build/lib/gradio/static/js/interfaces/input/image.js
index 85754fc074..8dfaf32518 100644
--- a/build/lib/gradio/static/js/interfaces/input/image.js
+++ b/build/lib/gradio/static/js/interfaces/input/image.js
@@ -29,6 +29,9 @@ const image_input = {
+          <div class="saliency_holder hide">
+            <canvas class="saliency"></canvas>
+          </div>
@@ -180,6 +183,19 @@ const image_input = {
         this.cropper.destroy();
       }
     }
+    this.target.find(".saliency_holder").addClass("hide");
+  },
+  show_interpretation: function(data) {
+    if (this.target.find(".image_preview").attr("src")) {
+      var img = this.target.find(".image_preview")[0];
+      var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
+      var width = size.width;
+      var height = size.height;
+      this.target.find(".saliency_holder").removeClass("hide").html(`
+        <canvas class="saliency" height=${height} width=${width}></canvas>`);
+      var ctx = this.target.find(".saliency")[0].getContext('2d');
+      paintSaliency(data, ctx, width, height);
+    }
   },
   state: "NO_IMAGE",
   image_data: null,
diff --git a/build/lib/gradio/static/js/utils.js b/build/lib/gradio/static/js/utils.js
index 61b8e6bd05..01083d8798 100644
--- a/build/lib/gradio/static/js/utils.js
+++ b/build/lib/gradio/static/js/utils.js
@@ -54,22 +54,19 @@ function toStringIfObject(input) {
   return input;
 }
 
-function paintSaliency(data, width, height, ctx) {
+function paintSaliency(data, ctx, width, height) {
   var cell_width = width / data[0].length
   var cell_height = height / data.length
   var r = 0
   data.forEach(function(row) {
     var c = 0
     row.forEach(function(cell) {
-      if (cell < 0.25) {
-        ctx.fillStyle = "white";
-      } else if (cell < 0.5) {
-        ctx.fillStyle = "#add8ed";
-      } else if (cell < 0.75) {
-        ctx.fillStyle = "#5aa7d3";
+      if (cell < 0) {
+        var color = [7,47,95];
       } else {
-        ctx.fillStyle = "#072F5F";
+        var color = [112,62,8];
       }
+      ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
       ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
       c++;
     })
@@ -77,6 +74,29 @@
   })
 }
 
+function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
+  var doRatio = width / height;
+  var cRatio = containerWidth / containerHeight;
+  var targetWidth = 0;
+  var targetHeight = 0;
+  var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
+
+  if (test) {
+    targetWidth = containerWidth;
+    targetHeight = targetWidth / doRatio;
+  } else {
+    targetHeight = containerHeight;
+    targetWidth = targetHeight * doRatio;
+  }
+
+  return {
+    width: targetWidth,
+    height: targetHeight,
+    x: (containerWidth - targetWidth) / 2,
+    y: (containerHeight - targetHeight) / 2
+  };
+}
+
 // val should be in the range [0.0, 1.0]
 // rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
 function interpolate(val, rgb1, rgb2) {
@@ -88,7 +108,6 @@
   return rgb;
 }
 
-// quick helper function to convert the array into something we can use for css
 function colorToString(rgb) {
   return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
 }
diff --git a/demo/image_classifier.py b/demo/image_classifier.py
new file mode 100644
index 0000000000..0628a472fa
--- /dev/null
+++ b/demo/image_classifier.py
@@ -0,0 +1,47 @@
+import gradio as gr
+import tensorflow as tf
+# from vis.utils import utils
+# from vis.visualization import visualize_cam
+
+import numpy as np
+from PIL import Image
+import requests
+from urllib.request import urlretrieve
+
+# Download human-readable labels for ImageNet.
+response = requests.get("https://git.io/JJkYN")
+labels = response.text.split("\n")
+
+mobile_net = tf.keras.applications.MobileNetV2()
+
+
+def image_classifier(im):
+    arr = np.expand_dims(im, axis=0)
+    arr = tf.keras.applications.mobilenet.preprocess_input(arr)
+    prediction = mobile_net.predict(arr).flatten()
+    return {labels[i]: float(prediction[i]) for i in range(1000)}
+
+def image_explain(im):
+    model.layers[-1].activation = keras.activations.linear
+    model = utils.apply_modifications(model)
+    penultimate_layer_idx = 2
+    class_idx = class_idxs_sorted[0]
+    seed_input = img
+    grad_top1 = visualize_cam(model, layer_idx, class_idx, seed_input,
+                              penultimate_layer_idx = penultimate_layer_idx,#None,
+                              backprop_modifier = None,
+                              grad_modifier = None)
+    print(grad_top_1)
+    return grad_top1
+
+
+imagein = gr.inputs.Image(shape=(224, 224))
+label = gr.outputs.Label(num_top_classes=3)
+
+gr.Interface(image_classifier, imagein, label,
+             capture_session=True,
+             explain_by="default",
+             examples=[
+                 ["images/cheetah1.jpg"],
+                 ["images/lion.jpg"]
+             ]).launch();
\ No newline at end of file
diff --git a/demo/sentiment_analysis.py b/demo/sentiment_analysis.py
index a7fdf3088c..49c8a8c257 100644
--- a/demo/sentiment_analysis.py
+++ b/demo/sentiment_analysis.py
@@ -5,7 +5,9 @@
 nltk.download('vader_lexicon')
 sid = SentimentIntensityAnalyzer()
 
 def sentiment_analysis(text):
-    return sid.polarity_scores(text)
+    scores = sid.polarity_scores(text)
+    del scores["compound"]
+    return scores
 
 io = gr.Interface(sentiment_analysis, "textbox", "label", explain_by="default")
diff --git a/gradio/interface.py b/gradio/interface.py
index 47956d2d3a..15de825f2c 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -14,6 +14,8 @@ from gradio import networking, strings, utils, processing_utils
 from distutils.version import StrictVersion
 from skimage.segmentation import slic
 from skimage.util import img_as_float
+from gradio import processing_utils
+import PIL
 import pkg_resources
 import requests
 import random
@@ -216,7 +218,7 @@ class Interface:
         durations = []
         for predict_fn in self.predict:
             start = time.time()
-            if self.capture_session and not (self.session is None):
+            if self.capture_session and self.session is not None:
                 graph, sess = self.session
                 with graph.as_default():
                     with sess.as_default():
@@ -430,13 +432,14 @@ class Interface:
         return tokens, leave_one_out_tokens
 
     def tokenize_image(self, image):
-        image = self.input_interfaces[0].preprocess(image)
+        image = np.array(processing_utils.decode_base64_to_image(image))
        segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
         leave_one_out_tokens = []
         for (i, segVal) in enumerate(np.unique(segments_slic)):
-            mask = np.copy(image)
-            mask[segments_slic == segVal] = 255
-            leave_one_out_tokens.append(mask)
+            mask = segments_slic == segVal
+            white_screen = np.copy(image)
+            white_screen[segments_slic == segVal] = 255
+            leave_one_out_tokens.append((mask, white_screen))
         return leave_one_out_tokens
 
     def score_text(self, tokens, leave_one_out_tokens, text):
@@ -445,14 +448,18 @@ class Interface:
         tokens = text.split()
 
         input_text = " ".join(tokens)
-        output = self.predict[0](input_text)
-        original_label = max(output, key=output.get)
+        original_output = self.process([input_text])
+        output = {result["label"] : result["confidence"]
+                  for result in original_output[0][0]['confidences']}
+        original_label = original_output[0][0]["label"]
         original_confidence = output[original_label]
 
         scores = []
         for idx, input_text in enumerate(leave_one_out_tokens):
             input_text = " ".join(input_text)
-            output = self.predict[0](input_text)
+            raw_output = self.process([input_text])
+            output = {result["label"] : result["confidence"]
+                      for result in raw_output[0][0]['confidences']}
             scores.append(original_confidence - output[original_label])
 
         scores_by_char = []
@@ -464,44 +471,45 @@
         return scores_by_char
 
     def score_image(self, leave_one_out_tokens, image):
-        original_output = self.process(image)
-        original_label = original_output[0][0]['confidences'][0][
-            'label']
-        original_confidence = original_output[0][0]['confidences'][0][
-            'confidence']
-        output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
-            image[0])), 255)
-        for input_image in leave_one_out_tokens:
+        original_output = self.process([image])
+        output = {result["label"] : result["confidence"]
+                  for result in original_output[0][0]['confidences']}
+        original_label = original_output[0][0]["label"]
+        original_confidence = output[original_label]
+
+        image_interface = self.input_interfaces[0]
+        shape = processing_utils.decode_base64_to_image(image).size
+        output_scores = np.full((shape[1], shape[0]), 0.0)
+
+        for mask, input_image in leave_one_out_tokens:
             input_image_base64 = processing_utils.encode_array_to_base64(
                 input_image)
-            input_image_arr = []
-            input_image_arr.append(input_image_base64)
-            output = self.process(input_image_arr)
-            np.set_printoptions(threshold=sys.maxsize)
-            if output[0][0]['confidences'][0]['label'] == original_label:
-                input_image[input_image == 255] = (original_confidence -
-                                                   output[0][0][
-                                                       'confidences'][0][
-                                                       'confidence']) * 100
-                mask = (output_scores == 255)
-                output_scores[mask] = input_image[mask]
-        return output_scores
+            raw_output = self.process([input_image_base64])
+            output = {result["label"] : result["confidence"]
+                      for result in raw_output[0][0]['confidences']}
+            score = original_confidence - output[original_label]
+            output_scores += score * mask
+        max_val = np.max(np.abs(output_scores))
+        if max_val > 0:
+            output_scores = output_scores / max_val
+        return output_scores.tolist()
 
-    def simple_explanation(self, input):
+    def simple_explanation(self, x):
         if isinstance(self.input_interfaces[0], Textbox):
-            tokens, leave_one_out_tokens = self.tokenize_text(input[0])
-            return [self.score_text(tokens, leave_one_out_tokens, input[0])]
+            tokens, leave_one_out_tokens = self.tokenize_text(x[0])
+            return [self.score_text(tokens, leave_one_out_tokens, x[0])]
         elif isinstance(self.input_interfaces[0], Image):
-            leave_one_out_tokens = self.tokenize_image(input[0])
-            return self.score_image(leave_one_out_tokens, input[0])
+            leave_one_out_tokens = self.tokenize_image(x[0])
+            return [self.score_image(leave_one_out_tokens, x[0])]
         else:
             print("Not valid input type")
 
-    def explain(self, input):
+    def explain(self, x):
         if self.explain_by == "default":
-            return self.simple_explanation(input)
+            return self.simple_explanation(x)
         else:
-            return self.explain_by(input)
+            preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
+            return self.explain_by(*preprocessed_x)
 
 
 def reset_all():
     for io in Interface.get_instances():
diff --git a/gradio/static/css/interfaces/input/image.css b/gradio/static/css/interfaces/input/image.css
index e6fab55694..0ac808bcbe 100644
--- a/gradio/static/css/interfaces/input/image.css
+++ b/gradio/static/css/interfaces/input/image.css
@@ -25,14 +25,10 @@
     flex-direction: column;
     border: none;
     opacity: 1;
+    transition: opacity 0.2s ease;
 }
-.saliency > div {
-    display: flex;
-    flex-grow: 1;
-}
-.saliency > div > div {
-    flex-grow: 1;
-    background-color: #e67e22;
+.saliency:hover {
+    opacity: 0.4;
 }
 .image_preview {
     width: 100%;
diff --git a/gradio/static/js/all_io.js b/gradio/static/js/all_io.js
index 0b7575a5cd..b228764da2 100644
--- a/gradio/static/js/all_io.js
+++ b/gradio/static/js/all_io.js
@@ -94,6 +94,7 @@ var io_master_template = {
       data: JSON.stringify(post_data),
       success: function(data) {
         for (let [idx, interpretation] of data.entries()) {
+          console.log(idx)
           io.input_interfaces[idx].show_interpretation(interpretation);
         }
         io.target.find(".loading_in_progress").hide();
diff --git a/gradio/static/js/interfaces/input/image.js b/gradio/static/js/interfaces/input/image.js
index 85754fc074..8dfaf32518 100644
--- a/gradio/static/js/interfaces/input/image.js
+++ b/gradio/static/js/interfaces/input/image.js
@@ -29,6 +29,9 @@ const image_input = {
+          <div class="saliency_holder hide">
+            <canvas class="saliency"></canvas>
+          </div>
@@ -180,6 +183,19 @@ const image_input = {
         this.cropper.destroy();
       }
     }
+    this.target.find(".saliency_holder").addClass("hide");
+  },
+  show_interpretation: function(data) {
+    if (this.target.find(".image_preview").attr("src")) {
+      var img = this.target.find(".image_preview")[0];
+      var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
+      var width = size.width;
+      var height = size.height;
+      this.target.find(".saliency_holder").removeClass("hide").html(`
+        <canvas class="saliency" height=${height} width=${width}></canvas>`);
+      var ctx = this.target.find(".saliency")[0].getContext('2d');
+      paintSaliency(data, ctx, width, height);
+    }
   },
   state: "NO_IMAGE",
   image_data: null,
diff --git a/gradio/static/js/utils.js b/gradio/static/js/utils.js
index 61b8e6bd05..01083d8798 100644
--- a/gradio/static/js/utils.js
+++ b/gradio/static/js/utils.js
@@ -54,22 +54,19 @@ function toStringIfObject(input) {
   return input;
 }
 
-function paintSaliency(data, width, height, ctx) {
+function paintSaliency(data, ctx, width, height) {
   var cell_width = width / data[0].length
   var cell_height = height / data.length
   var r = 0
   data.forEach(function(row) {
     var c = 0
     row.forEach(function(cell) {
-      if (cell < 0.25) {
-        ctx.fillStyle = "white";
-      } else if (cell < 0.5) {
-        ctx.fillStyle = "#add8ed";
-      } else if (cell < 0.75) {
-        ctx.fillStyle = "#5aa7d3";
+      if (cell < 0) {
+        var color = [7,47,95];
       } else {
-        ctx.fillStyle = "#072F5F";
+        var color = [112,62,8];
       }
+      ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
       ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
       c++;
     })
@@ -77,6 +74,29 @@
   })
 }
 
+function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
+  var doRatio = width / height;
+  var cRatio = containerWidth / containerHeight;
+  var targetWidth = 0;
+  var targetHeight = 0;
+  var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
+
+  if (test) {
+    targetWidth = containerWidth;
+    targetHeight = targetWidth / doRatio;
+  } else {
+    targetHeight = containerHeight;
+    targetWidth = targetHeight * doRatio;
+  }
+
+  return {
+    width: targetWidth,
+    height: targetHeight,
+    x: (containerWidth - targetWidth) / 2,
+    y: (containerHeight - targetHeight) / 2
+  };
+}
+
 // val should be in the range [0.0, 1.0]
 // rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
 function interpolate(val, rgb1, rgb2) {
@@ -88,7 +108,6 @@
   return rgb;
 }
 
-// quick helper function to convert the array into something we can use for css
 function colorToString(rgb) {
   return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
 }
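
For readers tracing the interpretation changes in interface.py: score_text implements a leave-one-out scheme (drop one token, re-run the model through Interface.process, and record how far the top label's confidence falls), and score_image does the same per SLIC segment, whiting out one segment at a time and accumulating the confidence drop over that segment's mask. The standalone sketch below restates the text case. It is a minimal illustration under stated assumptions, not code from this patch: leave_one_out_scores, predict_fn, and toy_predict are hypothetical names, the classifier is assumed to return a {label: confidence} dict, and the real patch additionally spreads each token's score across its characters before returning it.

# Minimal, self-contained sketch of the leave-one-out scoring idea used by
# Interface.score_text above. All names here (leave_one_out_scores,
# predict_fn, toy_predict) are illustrative, not part of the gradio API.

def leave_one_out_scores(text, predict_fn):
    """Score each token by how much removing it lowers the top label's confidence."""
    tokens = text.split()
    original = predict_fn(" ".join(tokens))          # {label: confidence}
    top_label = max(original, key=original.get)
    base_confidence = original[top_label]

    scores = []
    for i in range(len(tokens)):
        reduced = tokens[:i] + tokens[i + 1:]        # drop one token
        output = predict_fn(" ".join(reduced))
        scores.append(base_confidence - output.get(top_label, 0.0))
    return list(zip(tokens, scores))


if __name__ == "__main__":
    def toy_predict(text):
        # Toy classifier: confidence in "positive" rises with the word "good".
        positive = min(1.0, 0.2 + 0.4 * text.lower().split().count("good"))
        return {"positive": positive, "negative": 1.0 - positive}

    print(leave_one_out_scores("this movie is really good", toy_predict))

These signed deltas are the values that paintSaliency later maps onto the white-to-blue (negative) and white-to-brown (positive) color ramps when drawing the canvas overlay on the input image.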