Ali Abid 2020-09-16 16:43:37 -07:00
parent dd28a944cf
commit ce3d6c6e2d
12 changed files with 234 additions and 105 deletions

View File

@@ -14,6 +14,8 @@ from gradio import networking, strings, utils, processing_utils
 from distutils.version import StrictVersion
 from skimage.segmentation import slic
 from skimage.util import img_as_float
+from gradio import processing_utils
+import PIL
 import pkg_resources
 import requests
 import random
@@ -216,7 +218,7 @@ class Interface:
         durations = []
         for predict_fn in self.predict:
             start = time.time()
-            if self.capture_session and not (self.session is None):
+            if self.capture_session and self.session is not None:
                 graph, sess = self.session
                 with graph.as_default():
                     with sess.as_default():
@@ -430,13 +432,14 @@ class Interface:
         return tokens, leave_one_out_tokens

     def tokenize_image(self, image):
-        image = self.input_interfaces[0].preprocess(image)
+        image = np.array(processing_utils.decode_base64_to_image(image))
         segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
         leave_one_out_tokens = []
         for (i, segVal) in enumerate(np.unique(segments_slic)):
-            mask = np.copy(image)
-            mask[segments_slic == segVal] = 255
-            leave_one_out_tokens.append(mask)
+            mask = segments_slic == segVal
+            white_screen = np.copy(image)
+            white_screen[segments_slic == segVal] = 255
+            leave_one_out_tokens.append((mask, white_screen))
         return leave_one_out_tokens

     def score_text(self, tokens, leave_one_out_tokens, text):
@@ -445,14 +448,18 @@ class Interface:
         tokens = text.split()
         input_text = " ".join(tokens)
-        output = self.predict[0](input_text)
-        original_label = max(output, key=output.get)
+        original_output = self.process([input_text])
+        output = {result["label"]: result["confidence"]
+                  for result in original_output[0][0]['confidences']}
+        original_label = original_output[0][0]["label"]
         original_confidence = output[original_label]

         scores = []
         for idx, input_text in enumerate(leave_one_out_tokens):
             input_text = " ".join(input_text)
-            output = self.predict[0](input_text)
+            raw_output = self.process([input_text])
+            output = {result["label"]: result["confidence"]
+                      for result in raw_output[0][0]['confidences']}
             scores.append(original_confidence - output[original_label])

         scores_by_char = []
@@ -464,44 +471,45 @@ class Interface:
         return scores_by_char

     def score_image(self, leave_one_out_tokens, image):
-        original_output = self.process(image)
-        original_label = original_output[0][0]['confidences'][0][
-            'label']
-        original_confidence = original_output[0][0]['confidences'][0][
-            'confidence']
-        output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
-            image[0])), 255)
-        for input_image in leave_one_out_tokens:
+        original_output = self.process([image])
+        output = {result["label"]: result["confidence"]
+                  for result in original_output[0][0]['confidences']}
+        original_label = original_output[0][0]["label"]
+        original_confidence = output[original_label]
+
+        image_interface = self.input_interfaces[0]
+        shape = processing_utils.decode_base64_to_image(image).size
+        output_scores = np.full((shape[1], shape[0]), 0.0)
+        for mask, input_image in leave_one_out_tokens:
             input_image_base64 = processing_utils.encode_array_to_base64(
                 input_image)
-            input_image_arr = []
-            input_image_arr.append(input_image_base64)
-            output = self.process(input_image_arr)
-            np.set_printoptions(threshold=sys.maxsize)
-            if output[0][0]['confidences'][0]['label'] == original_label:
-                input_image[input_image == 255] = (original_confidence -
-                                                   output[0][0][
-                                                       'confidences'][0][
-                                                       'confidence']) * 100
-            mask = (output_scores == 255)
-            output_scores[mask] = input_image[mask]
-        return output_scores
+            raw_output = self.process([input_image_base64])
+            output = {result["label"]: result["confidence"]
+                      for result in raw_output[0][0]['confidences']}
+            score = original_confidence - output[original_label]
+            output_scores += score * mask
+        max_val = np.max(np.abs(output_scores))
+        if max_val > 0:
+            output_scores = output_scores / max_val
+        return output_scores.tolist()

-    def simple_explanation(self, input):
+    def simple_explanation(self, x):
         if isinstance(self.input_interfaces[0], Textbox):
-            tokens, leave_one_out_tokens = self.tokenize_text(input[0])
-            return [self.score_text(tokens, leave_one_out_tokens, input[0])]
+            tokens, leave_one_out_tokens = self.tokenize_text(x[0])
+            return [self.score_text(tokens, leave_one_out_tokens, x[0])]
         elif isinstance(self.input_interfaces[0], Image):
-            leave_one_out_tokens = self.tokenize_image(input[0])
-            return self.score_image(leave_one_out_tokens, input[0])
+            leave_one_out_tokens = self.tokenize_image(x[0])
+            return [self.score_image(leave_one_out_tokens, x[0])]
         else:
             print("Not valid input type")

-    def explain(self, input):
+    def explain(self, x):
         if self.explain_by == "default":
-            return self.simple_explanation(input)
+            return self.simple_explanation(x)
         else:
-            return self.explain_by(input)
+            preprocessed_x = [input_interface(x_i) for x_i, input_interface
+                              in zip(x, self.input_interfaces)]
+            return self.explain_by(*preprocessed_x)

 def reset_all():
     for io in Interface.get_instances():
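
The rewritten score_image is occlusion-based saliency: SLIC superpixels are whited out one at a time, and each segment is scored by how far the confidence of the originally predicted label falls. A minimal standalone sketch of the same idea, assuming only a classify(image) callable that returns a label-to-confidence dict (the self.process plumbing and base64 round-trips are omitted):

import numpy as np
from skimage.segmentation import slic

def occlusion_saliency(image, classify, n_segments=20):
    # image: HxWx3 uint8 array; classify: callable returning {label: confidence}.
    # Returns an HxW float array of scores normalized to [-1, 1].
    segments = slic(image, n_segments=n_segments, compactness=10, sigma=1)
    original = classify(image)
    label = max(original, key=original.get)
    base_conf = original[label]

    scores = np.zeros(image.shape[:2], dtype=float)
    for seg_val in np.unique(segments):
        mask = segments == seg_val
        occluded = np.copy(image)
        occluded[mask] = 255  # white out one superpixel
        conf = classify(occluded).get(label, 0.0)
        scores += (base_conf - conf) * mask

    max_val = np.max(np.abs(scores))
    return scores / max_val if max_val > 0 else scores

Positive cells mark segments whose removal hurt the predicted label; negative cells mark segments that were suppressing it, which is why the paintSaliency change below picks a different color ramp for each sign.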

View File

@@ -25,14 +25,10 @@
     flex-direction: column;
     border: none;
     opacity: 1;
+    transition: opacity 0.2s ease;
 }
-.saliency > div {
-    display: flex;
-    flex-grow: 1;
-}
-.saliency > div > div {
-    flex-grow: 1;
-    background-color: #e67e22;
+.saliency:hover {
+    opacity: 0.4;
 }
 .image_preview {
     width: 100%;

View File

@@ -94,6 +94,7 @@ var io_master_template = {
       data: JSON.stringify(post_data),
       success: function(data) {
         for (let [idx, interpretation] of data.entries()) {
+          console.log(idx)
           io.input_interfaces[idx].show_interpretation(interpretation);
         }
         io.target.find(".loading_in_progress").hide();

View File

@@ -29,6 +29,9 @@ const image_input = {
       <div class="image_preview_holder">
         <img class="image_preview" />
       </div>
+      <div class="saliency_holder hide">
+        <canvas class="saliency"></canvas>
+      </div>
     </div>
   </div>
   <input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
@@ -180,6 +183,19 @@ const image_input = {
        this.cropper.destroy();
      }
    }
+    this.target.find(".saliency_holder").addClass("hide");
+  },
+  show_interpretation: function(data) {
+    if (this.target.find(".image_preview").attr("src")) {
+      var img = this.target.find(".image_preview")[0];
+      var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
+      var width = size.width;
+      var height = size.height;
+      this.target.find(".saliency_holder").removeClass("hide").html(`
+        <canvas class="saliency" width=${width} height=${height}></canvas>`);
+      var ctx = this.target.find(".saliency")[0].getContext('2d');
+      paintSaliency(data, ctx, width, height);
+    }
   },
   state: "NO_IMAGE",
   image_data: null,

View File

@@ -54,22 +54,19 @@ function toStringIfObject(input) {
   return input;
 }

-function paintSaliency(data, width, height, ctx) {
+function paintSaliency(data, ctx, width, height) {
   var cell_width = width / data[0].length
   var cell_height = height / data.length
   var r = 0
   data.forEach(function(row) {
     var c = 0
     row.forEach(function(cell) {
-      if (cell < 0.25) {
-        ctx.fillStyle = "white";
-      } else if (cell < 0.5) {
-        ctx.fillStyle = "#add8ed";
-      } else if (cell < 0.75) {
-        ctx.fillStyle = "#5aa7d3";
+      if (cell < 0) {
+        var color = [7,47,95];
       } else {
-        ctx.fillStyle = "#072F5F";
+        var color = [112,62,8];
       }
+      ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
       ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
       c++;
     })
@@ -77,6 +74,29 @@ function paintSaliency(data, width, height, ctx) {
   })
 }

+function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height) {
+  var doRatio = width / height;
+  var cRatio = containerWidth / containerHeight;
+  var targetWidth = 0;
+  var targetHeight = 0;
+  var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
+
+  if (test) {
+    targetWidth = containerWidth;
+    targetHeight = targetWidth / doRatio;
+  } else {
+    targetHeight = containerHeight;
+    targetWidth = targetHeight * doRatio;
+  }
+
+  return {
+    width: targetWidth,
+    height: targetHeight,
+    x: (containerWidth - targetWidth) / 2,
+    y: (containerHeight - targetHeight) / 2
+  };
+}
+
 // val should be in the range [0.0, 1.0]
 // rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
 function interpolate(val, rgb1, rgb2) {
@@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
   return rgb;
 }

-// quick helper function to convert the array into something we can use for css
 function colorToString(rgb) {
   return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
 }
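
getObjectFitSize reproduces the CSS object-fit: contain arithmetic so the saliency canvas can be sized to the letterboxed preview image rather than to its container. The same computation in Python, as a cross-check, together with the linear blend that interpolate performs per the comments above (names here are mine, not part of the commit):

def object_fit_size(contain, container_w, container_h, w, h):
    # Drawn size and offsets of a w x h image inside a container,
    # matching CSS object-fit: contain (or cover when contain=False).
    do_ratio = w / h
    c_ratio = container_w / container_h
    if (do_ratio > c_ratio) if contain else (do_ratio < c_ratio):
        target_w, target_h = container_w, container_w / do_ratio
    else:
        target_w, target_h = container_h * do_ratio, container_h
    return {"width": target_w, "height": target_h,
            "x": (container_w - target_w) / 2,
            "y": (container_h - target_h) / 2}

def interpolate(val, rgb1, rgb2):
    # Linear blend: val in [0.0, 1.0], rgb1/rgb2 are 3-lists in [0, 255].
    return [round(a + (b - a) * val) for a, b in zip(rgb1, rgb2)]

# Halfway toward the blue ramp used for negative saliency cells:
print(interpolate(0.5, [255, 255, 255], [7, 47, 95]))  # [131, 151, 175]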

demo/image_classifier.py (new file, 47 lines)
View File

@@ -0,0 +1,47 @@
import gradio as gr
import tensorflow as tf
# The two imports below come from the keras-vis package and are only
# needed by the (currently unused) Grad-CAM explainer:
# from vis.utils import utils
# from vis.visualization import visualize_cam
import numpy as np
from PIL import Image
import requests
from urllib.request import urlretrieve

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

mobile_net = tf.keras.applications.MobileNetV2()

def image_classifier(im):
    arr = np.expand_dims(im, axis=0)
    arr = tf.keras.applications.mobilenet.preprocess_input(arr)
    prediction = mobile_net.predict(arr).flatten()
    return {labels[i]: float(prediction[i]) for i in range(1000)}

def image_explain(im):
    # Grad-CAM explainer: requires the keras-vis imports above and is not
    # called while the Interface passes explain_by="default".
    model = mobile_net
    model.layers[-1].activation = tf.keras.activations.linear
    model = utils.apply_modifications(model)
    penultimate_layer_idx = 2
    arr = tf.keras.applications.mobilenet.preprocess_input(
        np.expand_dims(im, axis=0))
    class_idx = int(np.argmax(mobile_net.predict(arr)))
    grad_top1 = visualize_cam(model, -1, class_idx, im,
                              penultimate_layer_idx=penultimate_layer_idx,
                              backprop_modifier=None,
                              grad_modifier=None)
    print(grad_top1)
    return grad_top1

imagein = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)

gr.Interface(image_classifier, imagein, label,
             capture_session=True,
             explain_by="default",
             examples=[
                 ["images/cheetah1.jpg"],
                 ["images/lion.jpg"]
             ]).launch()
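
The dict returned by image_classifier is exactly what gr.outputs.Label consumes: class names mapped to confidences, of which the top num_top_classes are displayed. A quick smoke test of the function outside Gradio (the random input is a hypothetical stand-in, so the resulting labels are meaningless):

import numpy as np

fake_image = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)
confidences = image_classifier(fake_image)
top3 = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:3]
print(top3)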

View File

@@ -5,7 +5,9 @@ nltk.download('vader_lexicon')
 sid = SentimentIntensityAnalyzer()

 def sentiment_analysis(text):
-    return sid.polarity_scores(text)
+    scores = sid.polarity_scores(text)
+    del scores["compound"]
+    return scores

 io = gr.Interface(sentiment_analysis, "textbox", "label", explain_by="default")
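
VADER's polarity_scores returns four fields, and "compound" is an aggregate score in [-1, 1] rather than a per-class confidence, so the demo now drops it before the dict reaches the label output. For example (scores approximate):

import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer

nltk.download('vader_lexicon')
sid = SentimentIntensityAnalyzer()
print(sid.polarity_scores("Gradio is really neat"))
# {'neg': 0.0, 'neu': 0.556, 'pos': 0.444, 'compound': 0.4754}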
