Merge pull request #52 from gradio-app/dawood/interpretation

Text & Image Interpretation
This commit is contained in:
aliabid94 2020-09-21 14:02:51 -07:00 committed by GitHub
commit f03c01f9a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 678 additions and 122 deletions

View File

@ -0,0 +1,59 @@
from gradio.inputs import Textbox
from gradio.inputs import Image
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import numpy as np
def tokenize_text(text):
leave_one_out_tokens = []
tokens = text.split()
leave_one_out_tokens.append(tokens)
for idx, _ in enumerate(tokens):
new_token_array = tokens.copy()
del new_token_array[idx]
leave_one_out_tokens.append(new_token_array)
return leave_one_out_tokens
def tokenize_image(image):
img = img_as_float(image[::2, ::2])
segments_slic = slic(img, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = np.copy(img)
mask[segments_slic == segVal] = 255
leave_one_out_tokens.append(mask)
return leave_one_out_tokens
def score(outputs):
print(outputs)
def simple_explanation(interface, input_interfaces,
output_interfaces, input):
if isinstance(input_interfaces[0], Textbox):
leave_one_out_tokens = tokenize_text(input[0])
outputs = []
for input_text in leave_one_out_tokens:
input_text = " ".join(input_text)
print("Input Text: ", input_text)
output = interface.process(input_text)
outputs.extend(output)
print("Output: ", output)
score(outputs)
elif isinstance(input_interfaces[0], Image):
leave_one_out_tokens = tokenize_image(input[0])
outputs = []
for input_text in leave_one_out_tokens:
input_text = " ".join(input_text)
print("Input Text: ", input_text)
output = interface.process(input_text)
outputs.extend(output)
print("Output: ", output)
score(outputs)
else:
print("Not valid input type")

View File

@ -5,10 +5,10 @@ interface using the input and output types.
import tempfile import tempfile
import webbrowser import webbrowser
from gradio.inputs import InputComponent from gradio.inputs import InputComponent
from gradio.outputs import OutputComponent from gradio.outputs import OutputComponent
from gradio import networking, strings, utils from gradio import networking, strings, utils
import gradio.interpretation
import requests import requests
import random import random
import time import time
@ -43,8 +43,9 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None, def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True, live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None, capture_session=False, interpretation=None, title=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME, description=None, thumbnail=None, server_port=None,
server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True, allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True): flagging_dir="flagged", analytics_enabled=True):
@ -57,6 +58,7 @@ class Interface:
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
live (bool): whether the interface should automatically reload on change. live (bool): whether the interface should automatically reload on change.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x) capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
title (str): a title for the interface; if provided, appears above the input and output components. title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components. description (str): a description for the interface; if provided, appears above the input and output components.
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
@ -98,6 +100,7 @@ class Interface:
if not isinstance(fn, list): if not isinstance(fn, list):
fn = [fn] fn = [fn]
self.output_interfaces *= len(fn) self.output_interfaces *= len(fn)
self.predict = fn self.predict = fn
self.verbose = verbose self.verbose = verbose
@ -107,6 +110,7 @@ class Interface:
self.show_output = show_output self.show_output = show_output
self.flag_hash = random.getrandbits(32) self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session self.capture_session = capture_session
self.interpretation = interpretation
self.session = None self.session = None
self.server_name = server_name self.server_name = server_name
self.title = title self.title = title
@ -175,6 +179,7 @@ class Interface:
"thumbnail": self.thumbnail, "thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot, "allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging, "allow_flagging": self.allow_flagging,
"allow_interpretation": self.interpretation is not None
} }
try: try:
param_names = inspect.getfullargspec(self.predict[0])[0] param_names = inspect.getfullargspec(self.predict[0])[0]
@ -187,8 +192,8 @@ class Interface:
iface[1]["label"] = ret_name iface[1]["label"] = ret_name
except ValueError: except ValueError:
pass pass
processed_examples = []
if self.examples is not None: if self.examples is not None:
processed_examples = []
for example_set in self.examples: for example_set in self.examples:
processed_set = [] processed_set = []
for iface, example in zip(self.input_interfaces, example_set): for iface, example in zip(self.input_interfaces, example_set):
@ -197,19 +202,7 @@ class Interface:
config["examples"] = processed_examples config["examples"] = processed_examples
return config return config
def process(self, raw_input, predict_fn=None): def run_prediction(self, processed_input, return_duration=False):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
predictions = [] predictions = []
durations = [] durations = []
for predict_fn in self.predict: for predict_fn in self.predict:
@ -239,6 +232,27 @@ class Interface:
prediction = [prediction] prediction = [prediction]
durations.append(duration) durations.append(duration)
predictions.extend(prediction) predictions.extend(prediction)
if return_duration:
return predictions, durations
else:
return predictions
def process(self, raw_input, predict_fn=None):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
predictions, durations = self.run_prediction(processed_input, return_duration=True)
processed_output = [output_interface.postprocess( processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations return processed_output, durations
@ -396,7 +410,6 @@ class Interface:
return app, path_to_local_server, share_url return app, path_to_local_server, share_url
def reset_all(): def reset_all():
for io in Interface.get_instances(): for io in Interface.get_instances():
io.close() io.close()

View File

@ -0,0 +1,103 @@
from gradio.inputs import Image, Textbox
from gradio.outputs import Label
from gradio import processing_utils
from skimage.segmentation import slic
import numpy as np
expected_types = {
Image: "numpy",
Textbox: "str"
}
def default(separator=" ", n_segments=20):
"""
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
inputs and outputs, uses the first component.
"""
def tokenize_text(text):
leave_one_out_tokens = []
tokens = text.split(separator)
for idx, _ in enumerate(tokens):
new_token_array = tokens.copy()
del new_token_array[idx]
leave_one_out_tokens.append(new_token_array)
return leave_one_out_tokens
def tokenize_image(image):
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
replace_color = np.mean(image, axis=(0, 1))
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = replace_color
leave_one_out_tokens.append((mask, white_screen))
return leave_one_out_tokens
def score_text(interface, leave_one_out_tokens, text):
tokens = text.split(separator)
original_output = interface.run_prediction([text])
scores_by_words = []
for idx, input_text in enumerate(leave_one_out_tokens):
perturbed_text = separator.join(input_text)
perturbed_output = interface.run_prediction([perturbed_text])
score = quantify_difference_in_label(interface, original_output, perturbed_output)
scores_by_words.append(score)
scores_by_char = []
for idx, token in enumerate(tokens):
if idx != 0:
scores_by_char.append((" ", 0))
for char in token:
scores_by_char.append((char, scores_by_words[idx]))
return scores_by_char
def score_image(interface, leave_one_out_tokens, image):
output_scores = np.zeros((image.shape[0], image.shape[1]))
original_output = interface.run_prediction([image])
for mask, perturbed_image in leave_one_out_tokens:
perturbed_output = interface.run_prediction([perturbed_image])
score = quantify_difference_in_label(interface, original_output, perturbed_output)
output_scores += score * mask
max_val, min_val = np.max(output_scores), np.min(output_scores)
if max_val > 0:
output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist()
def quantify_difference_in_label(interface, original_output, perturbed_output):
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
original_confidence = original_output[0][original_label]
perturbed_confidence = perturbed_output[0][original_label]
score = original_confidence - perturbed_confidence
else:
try: # try computing numerical difference
score = float(original_label) - float(perturbed_label)
except ValueError: # otherwise, look at strict difference in label
score = int(not(perturbed_label == original_label))
return score
def default_interpretation(interface, x):
if isinstance(interface.input_interfaces[0], Textbox) \
and isinstance(interface.output_interfaces[0], Label):
leave_one_out_tokens = tokenize_text(x[0])
return [score_text(interface, leave_one_out_tokens, x[0])]
if isinstance(interface.input_interfaces[0], Image) \
and isinstance(interface.output_interfaces[0], Label):
leave_one_out_tokens = tokenize_image(x[0])
return [score_image(interface, leave_one_out_tokens, x[0])]
else:
print("Not valid input or output types for 'default' interpretation")
return default_interpretation

View File

@ -9,7 +9,7 @@ from flask import Flask, request, jsonify, abort, send_file, render_template
from multiprocessing import Process from multiprocessing import Process
import pkg_resources import pkg_resources
from distutils import dir_util from distutils import dir_util
from gradio import inputs, outputs import gradio as gr
import time import time
import json import json
from gradio.tunneling import create_tunnel from gradio.tunneling import create_tunnel
@ -18,7 +18,7 @@ from shutil import copyfile
import requests import requests
import sys import sys
import csv import csv
import copy
INITIAL_PORT_VALUE = int(os.getenv( INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc. 'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -72,17 +72,19 @@ def get_first_available_port(initial, final):
@app.route("/", methods=["GET"]) @app.route("/", methods=["GET"])
def gradio(): def main():
return render_template("index.html", return render_template("index.html",
title=app.app_globals["title"], title=app.app_globals["title"],
description=app.app_globals["description"], description=app.app_globals["description"],
thumbnail=app.app_globals["thumbnail"], thumbnail=app.app_globals["thumbnail"],
) )
@app.route("/config/", methods=["GET"]) @app.route("/config/", methods=["GET"])
def config(): def config():
return jsonify(app.app_globals["config"]) return jsonify(app.app_globals["config"])
@app.route("/enable_sharing/<path:path>", methods=["GET"]) @app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path): def enable_sharing(path):
if path == "None": if path == "None":
@ -90,6 +92,7 @@ def enable_sharing(path):
app.app_globals["config"]["share_url"] = path app.app_globals["config"]["share_url"] = path
return jsonify(success=True) return jsonify(success=True)
@app.route("/api/predict/", methods=["POST"]) @app.route("/api/predict/", methods=["POST"])
def predict(): def predict():
raw_input = request.json["data"] raw_input = request.json["data"]
@ -97,6 +100,7 @@ def predict():
output = {"data": prediction, "durations": durations} output = {"data": prediction, "durations": durations}
return jsonify(output) return jsonify(output)
@app.route("/api/flag/", methods=["POST"]) @app.route("/api/flag/", methods=["POST"])
def flag(): def flag():
os.makedirs(app.interface.flagging_dir, exist_ok=True) os.makedirs(app.interface.flagging_dir, exist_ok=True)
@ -130,6 +134,25 @@ def flag():
) )
return jsonify(success=True) return jsonify(success=True)
@app.route("/api/interpret/", methods=["POST"])
def interpret():
raw_input = request.json["data"]
if app.interface.interpretation == "default":
interpreter = gr.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
processed_input.append(input_interface.preprocess(x))
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
interpreter = app.interface.interpretation
interpretation = interpreter(app.interface, processed_input)
return jsonify(interpretation)
@app.route("/file/<path:path>", methods=["GET"]) @app.route("/file/<path:path>", methods=["GET"])
def file(path): def file(path):
return send_file(os.path.join(os.getcwd(), path)) return send_file(os.path.join(os.getcwd(), path))

View File

@ -84,9 +84,6 @@ input.submit {
input.submit:hover { input.submit:hover {
background-color: #f39c12; background-color: #f39c12;
} }
.flag {
visibility: hidden;
}
.flagged { .flagged {
background-color: pink !important; background-color: pink !important;
} }
@ -111,9 +108,6 @@ input.submit:hover {
.invisible { .invisible {
display: none !important; display: none !important;
} }
.screenshot {
visibility: hidden;
}
.screenshot_logo { .screenshot_logo {
display: none; display: none;
flex-grow: 1; flex-grow: 1;

View File

@ -25,14 +25,10 @@
flex-direction: column; flex-direction: column;
border: none; border: none;
opacity: 1; opacity: 1;
transition: opacity 0.2s ease;
} }
.saliency > div { .saliency:hover {
display: flex; opacity: 0.4;
flex-grow: 1;
}
.saliency > div > div {
flex-grow: 1;
background-color: #e67e22;
} }
.image_preview { .image_preview {
width: 100%; width: 100%;

View File

@ -82,9 +82,26 @@ var io_master_template = {
data: JSON.stringify(post_data), data: JSON.stringify(post_data),
dataType: 'json', dataType: 'json',
contentType: 'application/json; charset=utf-8', contentType: 'application/json; charset=utf-8',
success: function(output){ });
console.log("Flagging successful") },
}, interpret: function() {
var io = this;
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
var post_data = {
'data': this.last_input
}
$.ajax({type: "POST",
url: "/api/interpret/",
data: JSON.stringify(post_data),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();
}
}); });
} }
}; };

View File

@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null; io_master.last_output = null;
}); });
if (config["allow_screenshot"]) { if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
target.find(".screenshot").css("visibility", "visible"); target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
} } else {
if (config["allow_flagging"]) { if (!config["allow_screenshot"]) {
target.find(".flag").css("visibility", "visible"); target.find(".screenshot").hide();
} }
if (config["allow_interpretation"]) { if (!config["allow_flagging"]) {
target.find(".interpret").css("visibility", "visible"); target.find(".flag").hide();
}
if (!config["allow_interpretation"]) {
target.find(".interpret").hide();
}
} }
if (config["examples"]) { if (config["examples"]) {
target.find(".examples").removeClass("invisible"); target.find(".examples").removeClass("invisible");
@ -231,12 +235,15 @@ function gradio(config, fn, target, example_file_path) {
} }
target.find(".flag").click(function() { target.find(".flag").click(function() {
if (io_master.last_output) { if (io_master.last_output) {
target.find(".flag").addClass("flagged"); target.find(".flag").addClass("flagged");
target.find(".flag").val("FLAGGED"); target.find(".flag").val("FLAGGED");
io_master.flag(); io_master.flag();
}
// io_master.flag($(".flag_message").val()); })
target.find(".interpret").click(function() {
if (io_master.last_output) {
io_master.interpret();
} }
}) })

View File

@ -29,6 +29,9 @@ const image_input = {
<div class="image_preview_holder"> <div class="image_preview_holder">
<img class="image_preview" /> <img class="image_preview" />
</div> </div>
<div class="saliency_holder hide">
<canvas class="saliency"></canvas>
</div>
</div> </div>
</div> </div>
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" /> <input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
@ -180,6 +183,19 @@ const image_input = {
this.cropper.destroy(); this.cropper.destroy();
} }
} }
this.target.find(".saliency_holder").addClass("hide");
},
show_interpretation: function(data) {
if (this.target.find(".image_preview").attr("src")) {
var img = this.target.find(".image_preview")[0];
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
var width = size.width;
var height = size.height;
this.target.find(".saliency_holder").removeClass("hide").html(`
<canvas class="saliency" width=${width} height=${height}></canvas>`);
var ctx = this.target.find(".saliency")[0].getContext('2d');
paintSaliency(data, ctx, width, height);
}
}, },
state: "NO_IMAGE", state: "NO_IMAGE",
image_data: null, image_data: null,

View File

@ -1,5 +1,8 @@
const textbox_input = { const textbox_input = {
html: `<textarea class="input_text"></textarea>`, html: `
<textarea class="input_text"></textarea>
<div class="output_text"></div>
`,
one_line_html: `<input type="text" class="input_text">`, one_line_html: `<input type="text" class="input_text">`,
init: function(opts) { init: function(opts) {
if (opts.lines > 1) { if (opts.lines > 1) {
@ -13,6 +16,7 @@ const textbox_input = {
if (opts.default) { if (opts.default) {
this.target.find(".input_text").val(opts.default) this.target.find(".input_text").val(opts.default)
} }
this.target.find(".output_text").hide();
}, },
submit: function() { submit: function() {
text = this.target.find(".input_text").val(); text = this.target.find(".input_text").val();
@ -20,6 +24,24 @@ const textbox_input = {
}, },
clear: function() { clear: function() {
this.target.find(".input_text").val(""); this.target.find(".input_text").val("");
this.target.find(".output_text").empty();
this.target.find(".input_text").show();
this.target.find(".output_text").hide();
},
show_interpretation: function(data) {
this.target.find(".input_text").hide();
this.target.find(".output_text").show();
let html = "";
for (let char_set of data) {
[char, value] = char_set;
if (value < 0) {
color = "8,241,255," + (-value * 2);
} else {
color = "230,126,34," + value * 2;
}
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
}
this.target.find(".output_text").html(html);
}, },
load_example: function(data) { load_example: function(data) {
this.target.find(".input_text").val(data); this.target.find(".input_text").val(data);

View File

@ -54,22 +54,19 @@ function toStringIfObject(input) {
return input; return input;
} }
function paintSaliency(data, width, height, ctx) { function paintSaliency(data, ctx, width, height) {
var cell_width = width / data[0].length var cell_width = width / data[0].length
var cell_height = height / data.length var cell_height = height / data.length
var r = 0 var r = 0
data.forEach(function(row) { data.forEach(function(row) {
var c = 0 var c = 0
row.forEach(function(cell) { row.forEach(function(cell) {
if (cell < 0.25) { if (cell < 0) {
ctx.fillStyle = "white"; var color = [7,47,95];
} else if (cell < 0.5) {
ctx.fillStyle = "#add8ed";
} else if (cell < 0.75) {
ctx.fillStyle = "#5aa7d3";
} else { } else {
ctx.fillStyle = "#072F5F"; var color = [112,62,8];
} }
ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height); ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
c++; c++;
}) })
@ -77,6 +74,29 @@ function paintSaliency(data, width, height, ctx) {
}) })
} }
function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
var doRatio = width / height;
var cRatio = containerWidth / containerHeight;
var targetWidth = 0;
var targetHeight = 0;
var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
if (test) {
targetWidth = containerWidth;
targetHeight = targetWidth / doRatio;
} else {
targetHeight = containerHeight;
targetWidth = targetHeight * doRatio;
}
return {
width: targetWidth,
height: targetHeight,
x: (containerWidth - targetWidth) / 2,
y: (containerHeight - targetHeight) / 2
};
}
// val should be in the range [0.0, 1.0] // val should be in the range [0.0, 1.0]
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255] // rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
function interpolate(val, rgb1, rgb2) { function interpolate(val, rgb1, rgb2) {
@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
return rgb; return rgb;
} }
// quick helper function to convert the array into something we can use for css
function colorToString(rgb) { function colorToString(rgb) {
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")"; return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
} }

48
demo/image_classifier.py Normal file
View File

@ -0,0 +1,48 @@
import gradio as gr
import tensorflow as tf
# from vis.utils import utils
# from vis.visualization import visualize_cam
import numpy as np
from PIL import Image
import requests
from urllib.request import urlretrieve
# # Download human-readable labels for ImageNet.
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")
labels = range(1000) # comment this later
mobile_net = tf.keras.applications.MobileNetV2()
def image_classifier(im):
arr = np.expand_dims(im, axis=0)
arr = tf.keras.applications.mobilenet.preprocess_input(arr)
prediction = mobile_net.predict(arr).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
def image_explain(im):
model.layers[-1].activation = tf.keras.activations.linear
model = utils.apply_modifications(model)
penultimate_layer_idx = 2
class_idx = class_idxs_sorted[0]
seed_input = img
grad_top1 = visualize_cam(model, layer_idx, class_idx, seed_input,
penultimate_layer_idx = penultimate_layer_idx,#None,
backprop_modifier = None,
grad_modifier = None)
print(grad_top_1)
return grad_top1
image = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(image_classifier, image, label,
capture_session=True,
interpretation="default",
examples=[
["images/cheetah1.jpg"],
["images/lion.jpg"]
]).launch();

13
demo/longest_word.py Normal file
View File

@ -0,0 +1,13 @@
import gradio as gr
def longest_word(text):
words = text.split(" ")
lengths = [len(word) for word in words]
return max(lengths)
ex = "The quick brown fox jumped over the lazy dog."
io = gr.Interface(longest_word, "textbox", "label", interpretation="default", examples=[[ex]])
io.test_launch()
io.launch()

View File

@ -0,0 +1,15 @@
import gradio as gr
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
nltk.download('vader_lexicon')
sid = SentimentIntensityAnalyzer()
def sentiment_analysis(text):
scores = sid.polarity_scores(text)
del scores["compound"]
return scores
io = gr.Interface(sentiment_analysis, "textbox", "label", interpretation="default")
io.test_launch()
io.launch()

View File

@ -5,6 +5,7 @@ gradio/__init__.py
gradio/component.py gradio/component.py
gradio/inputs.py gradio/inputs.py
gradio/interface.py gradio/interface.py
gradio/interpretation.py
gradio/networking.py gradio/networking.py
gradio/notebook.py gradio/notebook.py
gradio/outputs.py gradio/outputs.py

View File

@ -5,10 +5,10 @@ interface using the input and output types.
import tempfile import tempfile
import webbrowser import webbrowser
from gradio.inputs import InputComponent from gradio.inputs import InputComponent
from gradio.outputs import OutputComponent from gradio.outputs import OutputComponent
from gradio import networking, strings, utils from gradio import networking, strings, utils
import gradio.interpretation
import requests import requests
import random import random
import time import time
@ -43,8 +43,9 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None, def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True, live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None, capture_session=False, interpretation=None, title=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME, description=None, thumbnail=None, server_port=None,
server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True, allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True): flagging_dir="flagged", analytics_enabled=True):
@ -57,6 +58,7 @@ class Interface:
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
live (bool): whether the interface should automatically reload on change. live (bool): whether the interface should automatically reload on change.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x) capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
title (str): a title for the interface; if provided, appears above the input and output components. title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components. description (str): a description for the interface; if provided, appears above the input and output components.
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
@ -98,6 +100,7 @@ class Interface:
if not isinstance(fn, list): if not isinstance(fn, list):
fn = [fn] fn = [fn]
self.output_interfaces *= len(fn) self.output_interfaces *= len(fn)
self.predict = fn self.predict = fn
self.verbose = verbose self.verbose = verbose
@ -107,6 +110,7 @@ class Interface:
self.show_output = show_output self.show_output = show_output
self.flag_hash = random.getrandbits(32) self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session self.capture_session = capture_session
self.interpretation = interpretation
self.session = None self.session = None
self.server_name = server_name self.server_name = server_name
self.title = title self.title = title
@ -175,6 +179,7 @@ class Interface:
"thumbnail": self.thumbnail, "thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot, "allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging, "allow_flagging": self.allow_flagging,
"allow_interpretation": self.interpretation is not None
} }
try: try:
param_names = inspect.getfullargspec(self.predict[0])[0] param_names = inspect.getfullargspec(self.predict[0])[0]
@ -187,8 +192,8 @@ class Interface:
iface[1]["label"] = ret_name iface[1]["label"] = ret_name
except ValueError: except ValueError:
pass pass
processed_examples = []
if self.examples is not None: if self.examples is not None:
processed_examples = []
for example_set in self.examples: for example_set in self.examples:
processed_set = [] processed_set = []
for iface, example in zip(self.input_interfaces, example_set): for iface, example in zip(self.input_interfaces, example_set):
@ -197,19 +202,7 @@ class Interface:
config["examples"] = processed_examples config["examples"] = processed_examples
return config return config
def process(self, raw_input, predict_fn=None): def run_prediction(self, processed_input, return_duration=False):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
predictions = [] predictions = []
durations = [] durations = []
for predict_fn in self.predict: for predict_fn in self.predict:
@ -239,6 +232,27 @@ class Interface:
prediction = [prediction] prediction = [prediction]
durations.append(duration) durations.append(duration)
predictions.extend(prediction) predictions.extend(prediction)
if return_duration:
return predictions, durations
else:
return predictions
def process(self, raw_input, predict_fn=None):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
predictions, durations = self.run_prediction(processed_input, return_duration=True)
processed_output = [output_interface.postprocess( processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations return processed_output, durations
@ -396,7 +410,6 @@ class Interface:
return app, path_to_local_server, share_url return app, path_to_local_server, share_url
def reset_all(): def reset_all():
for io in Interface.get_instances(): for io in Interface.get_instances():
io.close() io.close()

103
gradio/interpretation.py Normal file
View File

@ -0,0 +1,103 @@
from gradio.inputs import Image, Textbox
from gradio.outputs import Label
from gradio import processing_utils
from skimage.segmentation import slic
import numpy as np
expected_types = {
Image: "numpy",
Textbox: "str"
}
def default(separator=" ", n_segments=20):
"""
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
inputs and outputs, uses the first component.
"""
def tokenize_text(text):
leave_one_out_tokens = []
tokens = text.split(separator)
for idx, _ in enumerate(tokens):
new_token_array = tokens.copy()
del new_token_array[idx]
leave_one_out_tokens.append(new_token_array)
return leave_one_out_tokens
def tokenize_image(image):
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
replace_color = np.mean(image, axis=(0, 1))
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = replace_color
leave_one_out_tokens.append((mask, white_screen))
return leave_one_out_tokens
def score_text(interface, leave_one_out_tokens, text):
tokens = text.split(separator)
original_output = interface.run_prediction([text])
scores_by_words = []
for idx, input_text in enumerate(leave_one_out_tokens):
perturbed_text = separator.join(input_text)
perturbed_output = interface.run_prediction([perturbed_text])
score = quantify_difference_in_label(interface, original_output, perturbed_output)
scores_by_words.append(score)
scores_by_char = []
for idx, token in enumerate(tokens):
if idx != 0:
scores_by_char.append((" ", 0))
for char in token:
scores_by_char.append((char, scores_by_words[idx]))
return scores_by_char
def score_image(interface, leave_one_out_tokens, image):
output_scores = np.zeros((image.shape[0], image.shape[1]))
original_output = interface.run_prediction([image])
for mask, perturbed_image in leave_one_out_tokens:
perturbed_output = interface.run_prediction([perturbed_image])
score = quantify_difference_in_label(interface, original_output, perturbed_output)
output_scores += score * mask
max_val, min_val = np.max(output_scores), np.min(output_scores)
if max_val > 0:
output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist()
def quantify_difference_in_label(interface, original_output, perturbed_output):
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
original_confidence = original_output[0][original_label]
perturbed_confidence = perturbed_output[0][original_label]
score = original_confidence - perturbed_confidence
else:
try: # try computing numerical difference
score = float(original_label) - float(perturbed_label)
except ValueError: # otherwise, look at strict difference in label
score = int(not(perturbed_label == original_label))
return score
def default_interpretation(interface, x):
if isinstance(interface.input_interfaces[0], Textbox) \
and isinstance(interface.output_interfaces[0], Label):
leave_one_out_tokens = tokenize_text(x[0])
return [score_text(interface, leave_one_out_tokens, x[0])]
if isinstance(interface.input_interfaces[0], Image) \
and isinstance(interface.output_interfaces[0], Label):
leave_one_out_tokens = tokenize_image(x[0])
return [score_image(interface, leave_one_out_tokens, x[0])]
else:
print("Not valid input or output types for 'default' interpretation")
return default_interpretation

View File

@ -9,7 +9,7 @@ from flask import Flask, request, jsonify, abort, send_file, render_template
from multiprocessing import Process from multiprocessing import Process
import pkg_resources import pkg_resources
from distutils import dir_util from distutils import dir_util
from gradio import inputs, outputs import gradio as gr
import time import time
import json import json
from gradio.tunneling import create_tunnel from gradio.tunneling import create_tunnel
@ -18,7 +18,7 @@ from shutil import copyfile
import requests import requests
import sys import sys
import csv import csv
import copy
INITIAL_PORT_VALUE = int(os.getenv( INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc. 'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -72,17 +72,19 @@ def get_first_available_port(initial, final):
@app.route("/", methods=["GET"]) @app.route("/", methods=["GET"])
def gradio(): def main():
return render_template("index.html", return render_template("index.html",
title=app.app_globals["title"], title=app.app_globals["title"],
description=app.app_globals["description"], description=app.app_globals["description"],
thumbnail=app.app_globals["thumbnail"], thumbnail=app.app_globals["thumbnail"],
) )
@app.route("/config/", methods=["GET"]) @app.route("/config/", methods=["GET"])
def config(): def config():
return jsonify(app.app_globals["config"]) return jsonify(app.app_globals["config"])
@app.route("/enable_sharing/<path:path>", methods=["GET"]) @app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path): def enable_sharing(path):
if path == "None": if path == "None":
@ -90,6 +92,7 @@ def enable_sharing(path):
app.app_globals["config"]["share_url"] = path app.app_globals["config"]["share_url"] = path
return jsonify(success=True) return jsonify(success=True)
@app.route("/api/predict/", methods=["POST"]) @app.route("/api/predict/", methods=["POST"])
def predict(): def predict():
raw_input = request.json["data"] raw_input = request.json["data"]
@ -97,6 +100,7 @@ def predict():
output = {"data": prediction, "durations": durations} output = {"data": prediction, "durations": durations}
return jsonify(output) return jsonify(output)
@app.route("/api/flag/", methods=["POST"]) @app.route("/api/flag/", methods=["POST"])
def flag(): def flag():
os.makedirs(app.interface.flagging_dir, exist_ok=True) os.makedirs(app.interface.flagging_dir, exist_ok=True)
@ -130,6 +134,25 @@ def flag():
) )
return jsonify(success=True) return jsonify(success=True)
@app.route("/api/interpret/", methods=["POST"])
def interpret():
raw_input = request.json["data"]
if app.interface.interpretation == "default":
interpreter = gr.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
processed_input.append(input_interface.preprocess(x))
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
interpreter = app.interface.interpretation
interpretation = interpreter(app.interface, processed_input)
return jsonify(interpretation)
@app.route("/file/<path:path>", methods=["GET"]) @app.route("/file/<path:path>", methods=["GET"])
def file(path): def file(path):
return send_file(os.path.join(os.getcwd(), path)) return send_file(os.path.join(os.getcwd(), path))

View File

@ -84,9 +84,6 @@ input.submit {
input.submit:hover { input.submit:hover {
background-color: #f39c12; background-color: #f39c12;
} }
.flag {
visibility: hidden;
}
.flagged { .flagged {
background-color: pink !important; background-color: pink !important;
} }
@ -111,9 +108,6 @@ input.submit:hover {
.invisible { .invisible {
display: none !important; display: none !important;
} }
.screenshot {
visibility: hidden;
}
.screenshot_logo { .screenshot_logo {
display: none; display: none;
flex-grow: 1; flex-grow: 1;

View File

@ -25,14 +25,10 @@
flex-direction: column; flex-direction: column;
border: none; border: none;
opacity: 1; opacity: 1;
transition: opacity 0.2s ease;
} }
.saliency > div { .saliency:hover {
display: flex; opacity: 0.4;
flex-grow: 1;
}
.saliency > div > div {
flex-grow: 1;
background-color: #e67e22;
} }
.image_preview { .image_preview {
width: 100%; width: 100%;

View File

@ -82,9 +82,26 @@ var io_master_template = {
data: JSON.stringify(post_data), data: JSON.stringify(post_data),
dataType: 'json', dataType: 'json',
contentType: 'application/json; charset=utf-8', contentType: 'application/json; charset=utf-8',
success: function(output){ });
console.log("Flagging successful") },
}, interpret: function() {
var io = this;
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
var post_data = {
'data': this.last_input
}
$.ajax({type: "POST",
url: "/api/interpret/",
data: JSON.stringify(post_data),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();
}
}); });
} }
}; };

View File

@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null; io_master.last_output = null;
}); });
if (config["allow_screenshot"]) { if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
target.find(".screenshot").css("visibility", "visible"); target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
} } else {
if (config["allow_flagging"]) { if (!config["allow_screenshot"]) {
target.find(".flag").css("visibility", "visible"); target.find(".screenshot").hide();
} }
if (config["allow_interpretation"]) { if (!config["allow_flagging"]) {
target.find(".interpret").css("visibility", "visible"); target.find(".flag").hide();
}
if (!config["allow_interpretation"]) {
target.find(".interpret").hide();
}
} }
if (config["examples"]) { if (config["examples"]) {
target.find(".examples").removeClass("invisible"); target.find(".examples").removeClass("invisible");
@ -231,12 +235,15 @@ function gradio(config, fn, target, example_file_path) {
} }
target.find(".flag").click(function() { target.find(".flag").click(function() {
if (io_master.last_output) { if (io_master.last_output) {
target.find(".flag").addClass("flagged"); target.find(".flag").addClass("flagged");
target.find(".flag").val("FLAGGED"); target.find(".flag").val("FLAGGED");
io_master.flag(); io_master.flag();
}
// io_master.flag($(".flag_message").val()); })
target.find(".interpret").click(function() {
if (io_master.last_output) {
io_master.interpret();
} }
}) })

View File

@ -29,6 +29,9 @@ const image_input = {
<div class="image_preview_holder"> <div class="image_preview_holder">
<img class="image_preview" /> <img class="image_preview" />
</div> </div>
<div class="saliency_holder hide">
<canvas class="saliency"></canvas>
</div>
</div> </div>
</div> </div>
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" /> <input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
@ -180,6 +183,19 @@ const image_input = {
this.cropper.destroy(); this.cropper.destroy();
} }
} }
this.target.find(".saliency_holder").addClass("hide");
},
show_interpretation: function(data) {
if (this.target.find(".image_preview").attr("src")) {
var img = this.target.find(".image_preview")[0];
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
var width = size.width;
var height = size.height;
this.target.find(".saliency_holder").removeClass("hide").html(`
<canvas class="saliency" width=${width} height=${height}></canvas>`);
var ctx = this.target.find(".saliency")[0].getContext('2d');
paintSaliency(data, ctx, width, height);
}
}, },
state: "NO_IMAGE", state: "NO_IMAGE",
image_data: null, image_data: null,

View File

@ -1,5 +1,8 @@
const textbox_input = { const textbox_input = {
html: `<textarea class="input_text"></textarea>`, html: `
<textarea class="input_text"></textarea>
<div class="output_text"></div>
`,
one_line_html: `<input type="text" class="input_text">`, one_line_html: `<input type="text" class="input_text">`,
init: function(opts) { init: function(opts) {
if (opts.lines > 1) { if (opts.lines > 1) {
@ -13,6 +16,7 @@ const textbox_input = {
if (opts.default) { if (opts.default) {
this.target.find(".input_text").val(opts.default) this.target.find(".input_text").val(opts.default)
} }
this.target.find(".output_text").hide();
}, },
submit: function() { submit: function() {
text = this.target.find(".input_text").val(); text = this.target.find(".input_text").val();
@ -20,6 +24,24 @@ const textbox_input = {
}, },
clear: function() { clear: function() {
this.target.find(".input_text").val(""); this.target.find(".input_text").val("");
this.target.find(".output_text").empty();
this.target.find(".input_text").show();
this.target.find(".output_text").hide();
},
show_interpretation: function(data) {
this.target.find(".input_text").hide();
this.target.find(".output_text").show();
let html = "";
for (let char_set of data) {
[char, value] = char_set;
if (value < 0) {
color = "8,241,255," + (-value * 2);
} else {
color = "230,126,34," + value * 2;
}
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
}
this.target.find(".output_text").html(html);
}, },
load_example: function(data) { load_example: function(data) {
this.target.find(".input_text").val(data); this.target.find(".input_text").val(data);

View File

@ -54,22 +54,19 @@ function toStringIfObject(input) {
return input; return input;
} }
function paintSaliency(data, width, height, ctx) { function paintSaliency(data, ctx, width, height) {
var cell_width = width / data[0].length var cell_width = width / data[0].length
var cell_height = height / data.length var cell_height = height / data.length
var r = 0 var r = 0
data.forEach(function(row) { data.forEach(function(row) {
var c = 0 var c = 0
row.forEach(function(cell) { row.forEach(function(cell) {
if (cell < 0.25) { if (cell < 0) {
ctx.fillStyle = "white"; var color = [7,47,95];
} else if (cell < 0.5) {
ctx.fillStyle = "#add8ed";
} else if (cell < 0.75) {
ctx.fillStyle = "#5aa7d3";
} else { } else {
ctx.fillStyle = "#072F5F"; var color = [112,62,8];
} }
ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height); ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
c++; c++;
}) })
@ -77,6 +74,29 @@ function paintSaliency(data, width, height, ctx) {
}) })
} }
function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
var doRatio = width / height;
var cRatio = containerWidth / containerHeight;
var targetWidth = 0;
var targetHeight = 0;
var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
if (test) {
targetWidth = containerWidth;
targetHeight = targetWidth / doRatio;
} else {
targetHeight = containerHeight;
targetWidth = targetHeight * doRatio;
}
return {
width: targetWidth,
height: targetHeight,
x: (containerWidth - targetWidth) / 2,
y: (containerHeight - targetHeight) / 2
};
}
// val should be in the range [0.0, 1.0] // val should be in the range [0.0, 1.0]
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255] // rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
function interpolate(val, rgb1, rgb2) { function interpolate(val, rgb1, rgb2) {
@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
return rgb; return rgb;
} }
// quick helper function to convert the array into something we can use for css
function colorToString(rgb) { function colorToString(rgb) {
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")"; return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
} }