This commit is contained in:
Ali Abid 2020-09-16 16:43:37 -07:00
parent 3b73a5522a
commit b5b5c03ec6
12 changed files with 234 additions and 105 deletions

View File

@@ -14,6 +14,8 @@ from gradio import networking, strings, utils, processing_utils
from distutils.version import StrictVersion
from skimage.segmentation import slic
from skimage.util import img_as_float
from gradio import processing_utils
import PIL
import pkg_resources
import requests
import random
@@ -216,7 +218,7 @@ class Interface:
durations = []
for predict_fn in self.predict:
start = time.time()
if self.capture_session and not (self.session is None):
if self.capture_session and self.session is not None:
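# self.session holds the TensorFlow graph/session captured at launch
# (capture_session=True), so the prediction runs inside that context.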
graph, sess = self.session
with graph.as_default():
with sess.as_default():
@@ -430,13 +432,14 @@ class Interface:
return tokens, leave_one_out_tokens
def tokenize_image(self, image):
image = self.input_interfaces[0].preprocess(image)
image = np.array(processing_utils.decode_base64_to_image(image))
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
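# SLIC partitions the image into ~20 superpixels; each one becomes an
# interpretable token below: "mask" marks the segment's pixels, and
# "white_screen" is a copy of the image with that segment blanked to white.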
leave_one_out_tokens = []
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = np.copy(image)
mask[segments_slic == segVal] = 255
leave_one_out_tokens.append(mask)
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = 255
leave_one_out_tokens.append((mask, white_screen))
return leave_one_out_tokens
def score_text(self, tokens, leave_one_out_tokens, text):
@@ -445,14 +448,18 @@ class Interface:
tokens = text.split()
input_text = " ".join(tokens)
output = self.predict[0](input_text)
original_label = max(output, key=output.get)
original_output = self.process([input_text])
output = {result["label"] : result["confidence"]
for result in original_output[0][0]['confidences']}
original_label = original_output[0][0]["label"]
original_confidence = output[original_label]
scores = []
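# Leave-one-out scoring: re-run the model with each token removed and use
# the drop in the original label's confidence as that token's importance.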
for idx, input_text in enumerate(leave_one_out_tokens):
input_text = " ".join(input_text)
output = self.predict[0](input_text)
raw_output = self.process([input_text])
output = {result["label"] : result["confidence"]
for result in raw_output[0][0]['confidences']}
scores.append(original_confidence - output[original_label])
scores_by_char = []
@@ -464,44 +471,45 @@ class Interface:
return scores_by_char
def score_image(self, leave_one_out_tokens, image):
original_output = self.process(image)
original_label = original_output[0][0]['confidences'][0][
'label']
original_confidence = original_output[0][0]['confidences'][0][
'confidence']
output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
image[0])), 255)
for input_image in leave_one_out_tokens:
original_output = self.process([image])
output = {result["label"] : result["confidence"]
for result in original_output[0][0]['confidences']}
original_label = original_output[0][0]["label"]
original_confidence = output[original_label]
image_interface = self.input_interfaces[0]
shape = processing_utils.decode_base64_to_image(image).size
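# PIL's .size is (width, height) while numpy expects (rows, cols), hence
# the swapped axes below.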
output_scores = np.full((shape[1], shape[0]), 0.0)
for mask, input_image in leave_one_out_tokens:
input_image_base64 = processing_utils.encode_array_to_base64(
input_image)
input_image_arr = []
input_image_arr.append(input_image_base64)
output = self.process(input_image_arr)
np.set_printoptions(threshold=sys.maxsize)
if output[0][0]['confidences'][0]['label'] == original_label:
input_image[input_image == 255] = (original_confidence -
output[0][0][
'confidences'][0][
'confidence']) * 100
mask = (output_scores == 255)
output_scores[mask] = input_image[mask]
return output_scores
raw_output = self.process([input_image_base64])
output = {result["label"] : result["confidence"]
for result in raw_output[0][0]['confidences']}
score = original_confidence - output[original_label]
output_scores += score * mask
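# Normalize by the largest absolute score so the grid lands in [-1, 1]
# for the front-end colormap.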
max_val = np.max(np.abs(output_scores))
if max_val > 0:
output_scores = output_scores / max_val
return output_scores.tolist()
def simple_explanation(self, input):
def simple_explanation(self, x):
if isinstance(self.input_interfaces[0], Textbox):
tokens, leave_one_out_tokens = self.tokenize_text(input[0])
return [self.score_text(tokens, leave_one_out_tokens, input[0])]
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
elif isinstance(self.input_interfaces[0], Image):
leave_one_out_tokens = self.tokenize_image(input[0])
return self.score_image(leave_one_out_tokens, input[0])
leave_one_out_tokens = self.tokenize_image(x[0])
return [self.score_image(leave_one_out_tokens, x[0])]
else:
print("Not valid input type")
def explain(self, input):
def explain(self, x):
if self.explain_by == "default":
return self.simple_explanation(input)
return self.simple_explanation(x)
else:
return self.explain_by(input)
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
return self.explain_by(*preprocessed_x)
def reset_all():
for io in Interface.get_instances():

View File

@@ -25,14 +25,10 @@
flex-direction: column;
border: none;
opacity: 1;
transition: opacity 0.2s ease;
}
.saliency > div {
display: flex;
flex-grow: 1;
}
.saliency > div > div {
flex-grow: 1;
background-color: #e67e22;
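/* fading the overlay on hover lets the underlying image show through */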
.saliency:hover {
opacity: 0.4;
}
.image_preview {
width: 100%;

View File

@@ -94,6 +94,7 @@ var io_master_template = {
data: JSON.stringify(post_data),
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
console.log(idx)
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();

View File

@@ -29,6 +29,9 @@ const image_input = {
<div class="image_preview_holder">
<img class="image_preview" />
</div>
<div class="saliency_holder hide">
<canvas class="saliency"></canvas>
</div>
</div>
</div>
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
@@ -180,6 +183,19 @@ const image_input = {
this.cropper.destroy();
}
}
this.target.find(".saliency_holder").addClass("hide");
},
show_interpretation: function(data) {
if (this.target.find(".image_preview").attr("src")) {
var img = this.target.find(".image_preview")[0];
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
var width = size.width;
var height = size.height;
this.target.find(".saliency_holder").removeClass("hide").html(`
<canvas class="saliency" width=${width} height=${height}></canvas>`);
var ctx = this.target.find(".saliency")[0].getContext('2d');
paintSaliency(data, ctx, width, height);
}
},
state: "NO_IMAGE",
image_data: null,

View File

@@ -54,22 +54,19 @@ function toStringIfObject(input) {
return input;
}
function paintSaliency(data, width, height, ctx) {
function paintSaliency(data, ctx, width, height) {
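// "data" is a 2-D grid of normalized scores in [-1, 1]; each cell paints
// one rectangle of the saliency overlay.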
var cell_width = width / data[0].length
var cell_height = height / data.length
var r = 0
data.forEach(function(row) {
var c = 0
row.forEach(function(cell) {
if (cell < 0.25) {
ctx.fillStyle = "white";
} else if (cell < 0.5) {
ctx.fillStyle = "#add8ed";
} else if (cell < 0.75) {
ctx.fillStyle = "#5aa7d3";
if (cell < 0) {
var color = [7,47,95];
} else {
ctx.fillStyle = "#072F5F";
var color = [112,62,8];
}
ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
c++;
})
@@ -77,6 +74,29 @@ function paintSaliency(data, width, height, ctx) {
})
}
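// Mirrors the CSS object-fit sizing math: returns the drawn width/height
// and the x/y offsets of media letterboxed inside its container.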
function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
var doRatio = width / height;
var cRatio = containerWidth / containerHeight;
var targetWidth = 0;
var targetHeight = 0;
var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
if (test) {
targetWidth = containerWidth;
targetHeight = targetWidth / doRatio;
} else {
targetHeight = containerHeight;
targetWidth = targetHeight * doRatio;
}
return {
width: targetWidth,
height: targetHeight,
x: (containerWidth - targetWidth) / 2,
y: (containerHeight - targetHeight) / 2
};
}
// val should be in the range [0.0, 1.0]
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
function interpolate(val, rgb1, rgb2) {
@@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
return rgb;
}
// quick helper function to convert the array into something we can use for css
function colorToString(rgb) {
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
}

demo/image_classifier.py (new file, 47 additions)
View File

@@ -0,0 +1,47 @@
import gradio as gr
import tensorflow as tf
# from vis.utils import utils
# from vis.visualization import visualize_cam
import numpy as np
from PIL import Image
import requests
from urllib.request import urlretrieve
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
mobile_net = tf.keras.applications.MobileNetV2()
def image_classifier(im):
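# MobileNetV2 expects a batch of 224x224 RGB images; preprocess_input
# scales pixel values to [-1, 1].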
arr = np.expand_dims(im, axis=0)
arr = tf.keras.applications.mobilenet.preprocess_input(arr)
prediction = mobile_net.predict(arr).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
def image_explain(im):
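# NOTE: this sketch depends on the keras-vis helpers commented out above
# (utils, visualize_cam) and on names defined elsewhere (model,
# class_idxs_sorted, layer_idx); it is not wired into the Interface below.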
model.layers[-1].activation = tf.keras.activations.linear
model = utils.apply_modifications(model)
penultimate_layer_idx = 2
class_idx = class_idxs_sorted[0]
seed_input = im
grad_top1 = visualize_cam(model, layer_idx, class_idx, seed_input,
penultimate_layer_idx = penultimate_layer_idx,
backprop_modifier = None,
grad_modifier = None)
print(grad_top1)
return grad_top1
imagein = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(image_classifier, imagein, label,
capture_session=True,
explain_by="default",
examples=[
["images/cheetah1.jpg"],
["images/lion.jpg"]
]).launch()

View File

@@ -5,7 +5,9 @@ nltk.download('vader_lexicon')
sid = SentimentIntensityAnalyzer()
def sentiment_analysis(text):
return sid.polarity_scores(text)
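# VADER's "compound" value is an aggregate score in [-1, 1], not a class
# confidence, so it is dropped before sending scores to the Label output.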
scores = sid.polarity_scores(text)
del scores["compound"]
return scores
io = gr.Interface(sentiment_analysis, "textbox", "label", explain_by="default")
