Mirror of https://github.com/gradio-app/gradio.git (synced 2025-01-06 10:25:17 +08:00)

Merge pull request #52 from gradio-app/dawood/interpretation
Text & Image Interpretation

Commit f03c01f9a8
build/lib/gradio/explain.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from gradio.inputs import Textbox
from gradio.inputs import Image

from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import numpy as np


def tokenize_text(text):
    leave_one_out_tokens = []
    tokens = text.split()
    leave_one_out_tokens.append(tokens)
    for idx, _ in enumerate(tokens):
        new_token_array = tokens.copy()
        del new_token_array[idx]
        leave_one_out_tokens.append(new_token_array)
    return leave_one_out_tokens

def tokenize_image(image):
    img = img_as_float(image[::2, ::2])
    segments_slic = slic(img, n_segments=20, compactness=10, sigma=1)
    leave_one_out_tokens = []
    for (i, segVal) in enumerate(np.unique(segments_slic)):
        mask = np.copy(img)
        mask[segments_slic == segVal] = 255
        leave_one_out_tokens.append(mask)
    return leave_one_out_tokens

def score(outputs):
    print(outputs)

def simple_explanation(interface, input_interfaces,
                       output_interfaces, input):
    if isinstance(input_interfaces[0], Textbox):
        leave_one_out_tokens = tokenize_text(input[0])
        outputs = []
        for input_text in leave_one_out_tokens:
            input_text = " ".join(input_text)
            print("Input Text: ", input_text)
            output = interface.process(input_text)
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)

    elif isinstance(input_interfaces[0], Image):
        leave_one_out_tokens = tokenize_image(input[0])
        outputs = []
        for masked_image in leave_one_out_tokens:
            output = interface.process([masked_image])
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    else:
        print("Not valid input type")
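
A quick sketch of what the leave-one-out tokenizer above returns: the full token list first, then one variant per dropped token.

    print(tokenize_text("a b c"))
    # [['a', 'b', 'c'], ['b', 'c'], ['a', 'c'], ['a', 'b']]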
@@ -5,10 +5,10 @@ interface using the input and output types.

 import tempfile
 import webbrowser

 from gradio.inputs import InputComponent
 from gradio.outputs import OutputComponent
 from gradio import networking, strings, utils
+import gradio.interpretation
 import requests
 import random
 import time

@@ -43,8 +43,9 @@ class Interface:

     def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
                  live=False, show_input=True, show_output=True,
-                 capture_session=False, title=None, description=None,
-                 thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
+                 capture_session=False, interpretation=None, title=None,
+                 description=None, thumbnail=None, server_port=None,
+                 server_name=networking.LOCALHOST_NAME,
                  allow_screenshot=True, allow_flagging=True,
                  flagging_dir="flagged", analytics_enabled=True):

@@ -57,6 +58,7 @@ class Interface:
         examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
         live (bool): whether the interface should automatically reload on change.
         capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
+        interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
         title (str): a title for the interface; if provided, appears above the input and output components.
         description (str): a description for the interface; if provided, appears above the input and output components.
         thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
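
The docstring above notes that interpretation may also be a callable. Judging from the /api/interpret/ route later in this diff, which invokes interpreter(app.interface, processed_input) and returns the result as JSON, a custom interpreter could look like the following sketch. The function name and the random scores are illustrative only, not part of this commit.

    import random

    def random_interpretation(interface, processed_input):
        # one score list per input component; for a Textbox the front end
        # expects (character, score) pairs, as score_text does below
        scores = [(char, random.uniform(-1, 1)) for char in processed_input[0]]
        return [scores]

    # gr.Interface(fn, "textbox", "label", interpretation=random_interpretation)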
@@ -98,6 +100,7 @@ class Interface:
         if not isinstance(fn, list):
             fn = [fn]


         self.output_interfaces *= len(fn)
         self.predict = fn
         self.verbose = verbose

@@ -107,6 +110,7 @@ class Interface:
         self.show_output = show_output
         self.flag_hash = random.getrandbits(32)
         self.capture_session = capture_session
+        self.interpretation = interpretation
         self.session = None
         self.server_name = server_name
         self.title = title

@@ -175,6 +179,7 @@ class Interface:
             "thumbnail": self.thumbnail,
             "allow_screenshot": self.allow_screenshot,
             "allow_flagging": self.allow_flagging,
+            "allow_interpretation": self.interpretation is not None
         }
         try:
             param_names = inspect.getfullargspec(self.predict[0])[0]

@@ -187,8 +192,8 @@ class Interface:
             iface[1]["label"] = ret_name
         except ValueError:
             pass
+        processed_examples = []
         if self.examples is not None:
-            processed_examples = []
             for example_set in self.examples:
                 processed_set = []
                 for iface, example in zip(self.input_interfaces, example_set):

@@ -197,19 +202,7 @@ class Interface:
         config["examples"] = processed_examples
         return config

-    def process(self, raw_input, predict_fn=None):
-        """
-        :param raw_input: a list of raw inputs to process and apply the
-        prediction(s) on.
-        :param predict_fn: which function to process. If not provided, all of the model functions are used.
-        :return:
-        processed output: a list of processed outputs to return as the
-        prediction(s).
-        duration: a list of time deltas measuring inference time for each
-        prediction fn.
-        """
-        processed_input = [input_interface.preprocess(raw_input[i])
-                           for i, input_interface in enumerate(self.input_interfaces)]
+    def run_prediction(self, processed_input, return_duration=False):
         predictions = []
         durations = []
         for predict_fn in self.predict:

@@ -239,6 +232,27 @@ class Interface:
                 prediction = [prediction]
             durations.append(duration)
             predictions.extend(prediction)

+        if return_duration:
+            return predictions, durations
+        else:
+            return predictions
+
+
+    def process(self, raw_input, predict_fn=None):
+        """
+        :param raw_input: a list of raw inputs to process and apply the
+        prediction(s) on.
+        :param predict_fn: which function to process. If not provided, all of the model functions are used.
+        :return:
+        processed output: a list of processed outputs to return as the
+        prediction(s).
+        duration: a list of time deltas measuring inference time for each
+        prediction fn.
+        """
+        processed_input = [input_interface.preprocess(raw_input[i])
+                           for i, input_interface in enumerate(self.input_interfaces)]
+        predictions, durations = self.run_prediction(processed_input, return_duration=True)
         processed_output = [output_interface.postprocess(
             predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
         return processed_output, durations

@@ -396,7 +410,6 @@ class Interface:

         return app, path_to_local_server, share_url


 def reset_all():
     for io in Interface.get_instances():
         io.close()
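
Splitting run_prediction out of process is what lets interpretation re-run the model on many perturbed inputs without repeating component pre- and post-processing each time. A hedged usage sketch, assuming a simple one-textbox interface:

    io = gr.Interface(lambda s: len(s.split()), "textbox", "label")
    predictions = io.run_prediction(["hello there world"])
    predictions, durations = io.run_prediction(["hello there world"], return_duration=True)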
build/lib/gradio/interpretation.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from gradio.inputs import Image, Textbox
from gradio.outputs import Label
from gradio import processing_utils
from skimage.segmentation import slic
import numpy as np

expected_types = {
    Image: "numpy",
    Textbox: "str"
}

def default(separator=" ", n_segments=20):
    """
    Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
    the following inputs: Image, Text, and the following outputs: Label. In case of multiple
    inputs and outputs, uses the first component.
    """
    def tokenize_text(text):
        leave_one_out_tokens = []
        tokens = text.split(separator)
        for idx, _ in enumerate(tokens):
            new_token_array = tokens.copy()
            del new_token_array[idx]
            leave_one_out_tokens.append(new_token_array)
        return leave_one_out_tokens

    def tokenize_image(image):
        segments_slic = slic(image, n_segments=n_segments, compactness=10, sigma=1)
        leave_one_out_tokens = []
        replace_color = np.mean(image, axis=(0, 1))
        for (i, segVal) in enumerate(np.unique(segments_slic)):
            mask = segments_slic == segVal
            white_screen = np.copy(image)
            white_screen[segments_slic == segVal] = replace_color
            leave_one_out_tokens.append((mask, white_screen))
        return leave_one_out_tokens

    def score_text(interface, leave_one_out_tokens, text):
        tokens = text.split(separator)
        original_output = interface.run_prediction([text])

        scores_by_words = []
        for idx, input_text in enumerate(leave_one_out_tokens):
            perturbed_text = separator.join(input_text)
            perturbed_output = interface.run_prediction([perturbed_text])
            score = quantify_difference_in_label(interface, original_output, perturbed_output)
            scores_by_words.append(score)

        scores_by_char = []
        for idx, token in enumerate(tokens):
            if idx != 0:
                scores_by_char.append((" ", 0))
            for char in token:
                scores_by_char.append((char, scores_by_words[idx]))

        return scores_by_char

    def score_image(interface, leave_one_out_tokens, image):
        output_scores = np.zeros((image.shape[0], image.shape[1]))
        original_output = interface.run_prediction([image])

        for mask, perturbed_image in leave_one_out_tokens:
            perturbed_output = interface.run_prediction([perturbed_image])
            score = quantify_difference_in_label(interface, original_output, perturbed_output)
            output_scores += score * mask

        max_val, min_val = np.max(output_scores), np.min(output_scores)
        if max_val > 0:
            output_scores = (output_scores - min_val) / (max_val - min_val)
        return output_scores.tolist()

    def quantify_difference_in_label(interface, original_output, perturbed_output):
        post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
        post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
        original_label = post_original_output["label"]
        perturbed_label = post_perturbed_output["label"]

        # Handle different return types of Label interface
        if "confidences" in post_original_output:
            original_confidence = original_output[0][original_label]
            perturbed_confidence = perturbed_output[0][original_label]
            score = original_confidence - perturbed_confidence
        else:
            try:  # try computing numerical difference
                score = float(original_label) - float(perturbed_label)
            except ValueError:  # otherwise, look at strict difference in label
                score = int(not(perturbed_label == original_label))
        return score

    def default_interpretation(interface, x):
        if isinstance(interface.input_interfaces[0], Textbox) \
                and isinstance(interface.output_interfaces[0], Label):
            leave_one_out_tokens = tokenize_text(x[0])
            return [score_text(interface, leave_one_out_tokens, x[0])]
        if isinstance(interface.input_interfaces[0], Image) \
                and isinstance(interface.output_interfaces[0], Label):
            leave_one_out_tokens = tokenize_image(x[0])
            return [score_image(interface, leave_one_out_tokens, x[0])]
        else:
            print("Not valid input or output types for 'default' interpretation")

    return default_interpretation
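
A worked sketch of the scoring above, with hypothetical numbers: suppose a sentiment model's Label output gives the top label "positive" a confidence of 0.90 on the full text, and removing the word "great" drops it to 0.55. quantify_difference_in_label then returns 0.90 - 0.55 = 0.35, and score_text assigns that score to every character of the dropped word:

    # confidences are illustrative, not from a real run
    score = 0.90 - 0.55   # contribution of "great": 0.35

    # score_text expands word scores to characters, e.g. for the text "so great":
    # [('s', 0.02), ('o', 0.02), (' ', 0), ('g', 0.35), ('r', 0.35),
    #  ('e', 0.35), ('a', 0.35), ('t', 0.35)]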
@@ -9,7 +9,7 @@ from flask import Flask, request, jsonify, abort, send_file, render_template
 from multiprocessing import Process
 import pkg_resources
 from distutils import dir_util
-from gradio import inputs, outputs
+import gradio as gr
 import time
 import json
 from gradio.tunneling import create_tunnel

@@ -18,7 +18,7 @@ from shutil import copyfile
 import requests
 import sys
 import csv

+import copy

 INITIAL_PORT_VALUE = int(os.getenv(
     'GRADIO_SERVER_PORT', "7860"))  # The http server will try to open on port 7860. If not available, 7861, 7862, etc.

@@ -72,17 +72,19 @@ def get_first_available_port(initial, final):


 @app.route("/", methods=["GET"])
-def gradio():
+def main():
     return render_template("index.html",
                            title=app.app_globals["title"],
                            description=app.app_globals["description"],
                            thumbnail=app.app_globals["thumbnail"],
                            )


 @app.route("/config/", methods=["GET"])
 def config():
     return jsonify(app.app_globals["config"])


 @app.route("/enable_sharing/<path:path>", methods=["GET"])
 def enable_sharing(path):
     if path == "None":

@@ -90,6 +92,7 @@ def enable_sharing(path):
     app.app_globals["config"]["share_url"] = path
     return jsonify(success=True)


 @app.route("/api/predict/", methods=["POST"])
 def predict():
     raw_input = request.json["data"]

@@ -97,6 +100,7 @@ def predict():
     output = {"data": prediction, "durations": durations}
     return jsonify(output)


 @app.route("/api/flag/", methods=["POST"])
 def flag():
     os.makedirs(app.interface.flagging_dir, exist_ok=True)

@@ -130,6 +134,25 @@ def flag():
         )
     return jsonify(success=True)


+@app.route("/api/interpret/", methods=["POST"])
+def interpret():
+    raw_input = request.json["data"]
+    if app.interface.interpretation == "default":
+        interpreter = gr.interpretation.default()
+        processed_input = []
+        for i, x in enumerate(raw_input):
+            input_interface = copy.deepcopy(app.interface.input_interfaces[i])
+            input_interface.type = gr.interpretation.expected_types[type(input_interface)]
+            processed_input.append(input_interface.preprocess(x))
+    else:
+        processed_input = [input_interface.preprocess(raw_input[i])
+                           for i, input_interface in enumerate(app.interface.input_interfaces)]
+        interpreter = app.interface.interpretation
+    interpretation = interpreter(app.interface, processed_input)
+    return jsonify(interpretation)
+
+
 @app.route("/file/<path:path>", methods=["GET"])
 def file(path):
     return send_file(os.path.join(os.getcwd(), path))
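
The new route can be exercised directly over HTTP. A hedged sketch of a client call, assuming a one-textbox interface with interpretation enabled and the server on the default port 7860:

    import requests

    resp = requests.post(
        "http://localhost:7860/api/interpret/",
        json={"data": ["this movie was great"]},  # one raw value per input component
    )
    # one score list per input component; for a Textbox, (character, score) pairs
    print(resp.json())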
@@ -84,9 +84,6 @@ input.submit {
 input.submit:hover {
   background-color: #f39c12;
 }
-.flag {
-  visibility: hidden;
-}
 .flagged {
   background-color: pink !important;
 }

@@ -111,9 +108,6 @@ input.submit:hover {
 .invisible {
   display: none !important;
 }
-.screenshot {
-  visibility: hidden;
-}
 .screenshot_logo {
   display: none;
   flex-grow: 1;
@@ -25,14 +25,10 @@
   flex-direction: column;
   border: none;
+  opacity: 1;
+  transition: opacity 0.2s ease;
 }
-.saliency > div {
-  display: flex;
-  flex-grow: 1;
-}
-.saliency > div > div {
-  flex-grow: 1;
-  background-color: #e67e22;
+.saliency:hover {
+  opacity: 0.4;
 }
 .image_preview {
   width: 100%;
@@ -82,9 +82,26 @@ var io_master_template = {
       data: JSON.stringify(post_data),
       dataType: 'json',
       contentType: 'application/json; charset=utf-8',
       success: function(output){
         console.log("Flagging successful")
       },
     });
   },
+  interpret: function() {
+    var io = this;
+    this.target.find(".loading").removeClass("invisible");
+    this.target.find(".loading_in_progress").show();
+    var post_data = {
+      'data': this.last_input
+    }
+    $.ajax({type: "POST",
+      url: "/api/interpret/",
+      data: JSON.stringify(post_data),
+      dataType: 'json',
+      contentType: 'application/json; charset=utf-8',
+      success: function(data) {
+        for (let [idx, interpretation] of data.entries()) {
+          io.input_interfaces[idx].show_interpretation(interpretation);
+        }
+        io.target.find(".loading_in_progress").hide();
+      }
+    });
+  }
 };
@@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
     io_master.last_output = null;
   });

-  if (config["allow_screenshot"]) {
-    target.find(".screenshot").css("visibility", "visible");
-  }
-  if (config["allow_flagging"]) {
-    target.find(".flag").css("visibility", "visible");
-  }
-  if (config["allow_interpretation"]) {
-    target.find(".interpret").css("visibility", "visible");
+  if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
+    target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
+  } else {
+    if (!config["allow_screenshot"]) {
+      target.find(".screenshot").hide();
+    }
+    if (!config["allow_flagging"]) {
+      target.find(".flag").hide();
+    }
+    if (!config["allow_interpretation"]) {
+      target.find(".interpret").hide();
+    }
   }
   if (config["examples"]) {
     target.find(".examples").removeClass("invisible");

@@ -231,12 +235,15 @@ function gradio(config, fn, target, example_file_path) {
   }

   target.find(".flag").click(function() {
-    if (io_master.last_output) {
-      target.find(".flag").addClass("flagged");
-      target.find(".flag").val("FLAGGED");
-      io_master.flag();
-
-      // io_master.flag($(".flag_message").val());
+    if (io_master.last_output) {
+      target.find(".flag").addClass("flagged");
+      target.find(".flag").val("FLAGGED");
+      io_master.flag();
     }
   })
+  target.find(".interpret").click(function() {
+    if (io_master.last_output) {
+      io_master.interpret();
+    }
+  })

@@ -29,6 +29,9 @@ const image_input = {
       <div class="image_preview_holder">
         <img class="image_preview" />
       </div>
+      <div class="saliency_holder hide">
+        <canvas class="saliency"></canvas>
+      </div>
     </div>
   </div>
   <input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />

@@ -180,6 +183,19 @@ const image_input = {
         this.cropper.destroy();
       }
     }
+    this.target.find(".saliency_holder").addClass("hide");
   },
+  show_interpretation: function(data) {
+    if (this.target.find(".image_preview").attr("src")) {
+      var img = this.target.find(".image_preview")[0];
+      var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
+      var width = size.width;
+      var height = size.height;
+      this.target.find(".saliency_holder").removeClass("hide").html(`
+        <canvas class="saliency" width=${width} height=${height}></canvas>`);
+      var ctx = this.target.find(".saliency")[0].getContext('2d');
+      paintSaliency(data, ctx, width, height);
+    }
+  },
   state: "NO_IMAGE",
   image_data: null,
@@ -1,5 +1,8 @@
 const textbox_input = {
-  html: `<textarea class="input_text"></textarea>`,
+  html: `
+    <textarea class="input_text"></textarea>
+    <div class="output_text"></div>
+  `,
   one_line_html: `<input type="text" class="input_text">`,
   init: function(opts) {
     if (opts.lines > 1) {

@@ -13,6 +16,7 @@ const textbox_input = {
     if (opts.default) {
       this.target.find(".input_text").val(opts.default)
     }
+    this.target.find(".output_text").hide();
   },
   submit: function() {
     text = this.target.find(".input_text").val();

@@ -20,6 +24,24 @@ const textbox_input = {
   },
   clear: function() {
     this.target.find(".input_text").val("");
+    this.target.find(".output_text").empty();
+    this.target.find(".input_text").show();
+    this.target.find(".output_text").hide();
   },
+  show_interpretation: function(data) {
+    this.target.find(".input_text").hide();
+    this.target.find(".output_text").show();
+    let html = "";
+    for (let char_set of data) {
+      [char, value] = char_set;
+      if (value < 0) {
+        color = "8,241,255," + (-value * 2);
+      } else {
+        color = "230,126,34," + value * 2;
+      }
+      html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
+    }
+    this.target.find(".output_text").html(html);
+  },
   load_example: function(data) {
     this.target.find(".input_text").val(data);
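
The loop above colors each character by its score: negative values get a translucent cyan-blue background, positive ones orange, with opacity scaled by twice the magnitude. A hedged Python rendering of the same logic, handy for eyeballing scores outside the browser:

    def render_interpretation_html(char_scores):
        spans = []
        for char, value in char_scores:
            if value < 0:
                color = "8,241,255," + str(-value * 2)
            else:
                color = "230,126,34," + str(value * 2)
            spans.append('<span title="%s" style="background-color: rgba(%s)">%s</span>'
                         % (value, color, char))
        return "".join(spans)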
@@ -54,22 +54,19 @@ function toStringIfObject(input) {
   return input;
 }

-function paintSaliency(data, width, height, ctx) {
+function paintSaliency(data, ctx, width, height) {
   var cell_width = width / data[0].length
   var cell_height = height / data.length
   var r = 0
   data.forEach(function(row) {
     var c = 0
     row.forEach(function(cell) {
-      if (cell < 0.25) {
-        ctx.fillStyle = "white";
-      } else if (cell < 0.5) {
-        ctx.fillStyle = "#add8ed";
-      } else if (cell < 0.75) {
-        ctx.fillStyle = "#5aa7d3";
+      if (cell < 0) {
+        var color = [7,47,95];
       } else {
-        ctx.fillStyle = "#072F5F";
+        var color = [112,62,8];
       }
+      ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
       ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
       c++;
     })

@@ -77,6 +74,29 @@ function paintSaliency(data, ctx, width, height) {
   })
 }

+function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
+  var doRatio = width / height;
+  var cRatio = containerWidth / containerHeight;
+  var targetWidth = 0;
+  var targetHeight = 0;
+  var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
+
+  if (test) {
+    targetWidth = containerWidth;
+    targetHeight = targetWidth / doRatio;
+  } else {
+    targetHeight = containerHeight;
+    targetWidth = targetHeight * doRatio;
+  }
+
+  return {
+    width: targetWidth,
+    height: targetHeight,
+    x: (containerWidth - targetWidth) / 2,
+    y: (containerHeight - targetHeight) / 2
+  };
+}
+
 // val should be in the range [0.0, 1.0]
 // rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
 function interpolate(val, rgb1, rgb2) {

@@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
   return rgb;
 }

 // quick helper function to convert the array into something we can use for css
 function colorToString(rgb) {
   return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
 }
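
paintSaliency now blends each cell's color from white toward a dark blue for negative scores and a dark orange-brown for positive ones, instead of the old fixed four-bucket palette. A hedged Python sketch of the same linear interpolation, using the RGB endpoints above:

    def interpolate(val, rgb1, rgb2):
        # val in [0.0, 1.0]: element-wise blend between two RGB triples
        return [round(c1 + (c2 - c1) * val) for c1, c2 in zip(rgb1, rgb2)]

    def color_to_string(rgb):
        return "rgb(%d, %d, %d)" % tuple(rgb)

    print(color_to_string(interpolate(0.5, [255, 255, 255], [112, 62, 8])))
    # rgb(184, 158, 132), halfway between white and the positive endpoint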
demo/image_classifier.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import gradio as gr
import tensorflow as tf
# from vis.utils import utils
# from vis.visualization import visualize_cam

import numpy as np
from PIL import Image
import requests
from urllib.request import urlretrieve

# # Download human-readable labels for ImageNet.
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")
labels = range(1000)  # comment this later

mobile_net = tf.keras.applications.MobileNetV2()


def image_classifier(im):
    arr = np.expand_dims(im, axis=0)
    arr = tf.keras.applications.mobilenet.preprocess_input(arr)
    prediction = mobile_net.predict(arr).flatten()
    return {labels[i]: float(prediction[i]) for i in range(1000)}


# Unused Grad-CAM sketch: it depends on the keras-vis imports commented out
# above, and on `model`, `layer_idx`, `class_idxs_sorted`, and `img` being
# defined elsewhere.
def image_explain(im):
    model.layers[-1].activation = tf.keras.activations.linear
    model = utils.apply_modifications(model)
    penultimate_layer_idx = 2
    class_idx = class_idxs_sorted[0]
    seed_input = img
    grad_top1 = visualize_cam(model, layer_idx, class_idx, seed_input,
                              penultimate_layer_idx=penultimate_layer_idx,  # None,
                              backprop_modifier=None,
                              grad_modifier=None)
    print(grad_top1)
    return grad_top1


image = gr.inputs.Image(shape=(224, 224))
label = gr.outputs.Label(num_top_classes=3)

gr.Interface(image_classifier, image, label,
             capture_session=True,
             interpretation="default",
             examples=[
                 ["images/cheetah1.jpg"],
                 ["images/lion.jpg"]
             ]).launch()
demo/longest_word.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import gradio as gr

def longest_word(text):
    words = text.split(" ")
    lengths = [len(word) for word in words]
    return max(lengths)

ex = "The quick brown fox jumped over the lazy dog."

io = gr.Interface(longest_word, "textbox", "label", interpretation="default", examples=[[ex]])

io.test_launch()
io.launch()
demo/sentiment_analysis.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import gradio as gr
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
nltk.download('vader_lexicon')
sid = SentimentIntensityAnalyzer()

def sentiment_analysis(text):
    scores = sid.polarity_scores(text)
    del scores["compound"]
    return scores

io = gr.Interface(sentiment_analysis, "textbox", "label", interpretation="default")

io.test_launch()
io.launch()
@@ -5,6 +5,7 @@ gradio/__init__.py
 gradio/component.py
 gradio/inputs.py
 gradio/interface.py
+gradio/interpretation.py
 gradio/networking.py
 gradio/notebook.py
 gradio/outputs.py
The remaining hunks in this commit apply the same changes, verbatim, to the source-tree copies of the files shown above, including gradio/interpretation.py (new file, 103 lines, identical to build/lib/gradio/interpretation.py).