mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
Merge pull request #52 from gradio-app/dawood/interpretation
Text & Image Interpretation
This commit is contained in:
commit
f03c01f9a8
59
build/lib/gradio/explain.py
Normal file
59
build/lib/gradio/explain.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from gradio.inputs import Textbox
|
||||||
|
from gradio.inputs import Image
|
||||||
|
|
||||||
|
from skimage.color import rgb2gray
|
||||||
|
from skimage.filters import sobel
|
||||||
|
from skimage.segmentation import slic
|
||||||
|
from skimage.util import img_as_float
|
||||||
|
from skimage import io
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(text):
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
tokens = text.split()
|
||||||
|
leave_one_out_tokens.append(tokens)
|
||||||
|
for idx, _ in enumerate(tokens):
|
||||||
|
new_token_array = tokens.copy()
|
||||||
|
del new_token_array[idx]
|
||||||
|
leave_one_out_tokens.append(new_token_array)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def tokenize_image(image):
|
||||||
|
img = img_as_float(image[::2, ::2])
|
||||||
|
segments_slic = slic(img, n_segments=20, compactness=10, sigma=1)
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||||
|
mask = np.copy(img)
|
||||||
|
mask[segments_slic == segVal] = 255
|
||||||
|
leave_one_out_tokens.append(mask)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def score(outputs):
|
||||||
|
print(outputs)
|
||||||
|
|
||||||
|
def simple_explanation(interface, input_interfaces,
|
||||||
|
output_interfaces, input):
|
||||||
|
if isinstance(input_interfaces[0], Textbox):
|
||||||
|
leave_one_out_tokens = tokenize_text(input[0])
|
||||||
|
outputs = []
|
||||||
|
for input_text in leave_one_out_tokens:
|
||||||
|
input_text = " ".join(input_text)
|
||||||
|
print("Input Text: ", input_text)
|
||||||
|
output = interface.process(input_text)
|
||||||
|
outputs.extend(output)
|
||||||
|
print("Output: ", output)
|
||||||
|
score(outputs)
|
||||||
|
|
||||||
|
elif isinstance(input_interfaces[0], Image):
|
||||||
|
leave_one_out_tokens = tokenize_image(input[0])
|
||||||
|
outputs = []
|
||||||
|
for input_text in leave_one_out_tokens:
|
||||||
|
input_text = " ".join(input_text)
|
||||||
|
print("Input Text: ", input_text)
|
||||||
|
output = interface.process(input_text)
|
||||||
|
outputs.extend(output)
|
||||||
|
print("Output: ", output)
|
||||||
|
score(outputs)
|
||||||
|
else:
|
||||||
|
print("Not valid input type")
|
@ -5,10 +5,10 @@ interface using the input and output types.
|
|||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
from gradio.inputs import InputComponent
|
from gradio.inputs import InputComponent
|
||||||
from gradio.outputs import OutputComponent
|
from gradio.outputs import OutputComponent
|
||||||
from gradio import networking, strings, utils
|
from gradio import networking, strings, utils
|
||||||
|
import gradio.interpretation
|
||||||
import requests
|
import requests
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@ -43,8 +43,9 @@ class Interface:
|
|||||||
|
|
||||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||||
live=False, show_input=True, show_output=True,
|
live=False, show_input=True, show_output=True,
|
||||||
capture_session=False, title=None, description=None,
|
capture_session=False, interpretation=None, title=None,
|
||||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
description=None, thumbnail=None, server_port=None,
|
||||||
|
server_name=networking.LOCALHOST_NAME,
|
||||||
allow_screenshot=True, allow_flagging=True,
|
allow_screenshot=True, allow_flagging=True,
|
||||||
flagging_dir="flagged", analytics_enabled=True):
|
flagging_dir="flagged", analytics_enabled=True):
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ class Interface:
|
|||||||
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
|
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
|
||||||
live (bool): whether the interface should automatically reload on change.
|
live (bool): whether the interface should automatically reload on change.
|
||||||
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
||||||
|
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
|
||||||
title (str): a title for the interface; if provided, appears above the input and output components.
|
title (str): a title for the interface; if provided, appears above the input and output components.
|
||||||
description (str): a description for the interface; if provided, appears above the input and output components.
|
description (str): a description for the interface; if provided, appears above the input and output components.
|
||||||
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
|
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
|
||||||
@ -98,6 +100,7 @@ class Interface:
|
|||||||
if not isinstance(fn, list):
|
if not isinstance(fn, list):
|
||||||
fn = [fn]
|
fn = [fn]
|
||||||
|
|
||||||
|
|
||||||
self.output_interfaces *= len(fn)
|
self.output_interfaces *= len(fn)
|
||||||
self.predict = fn
|
self.predict = fn
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -107,6 +110,7 @@ class Interface:
|
|||||||
self.show_output = show_output
|
self.show_output = show_output
|
||||||
self.flag_hash = random.getrandbits(32)
|
self.flag_hash = random.getrandbits(32)
|
||||||
self.capture_session = capture_session
|
self.capture_session = capture_session
|
||||||
|
self.interpretation = interpretation
|
||||||
self.session = None
|
self.session = None
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
self.title = title
|
self.title = title
|
||||||
@ -175,6 +179,7 @@ class Interface:
|
|||||||
"thumbnail": self.thumbnail,
|
"thumbnail": self.thumbnail,
|
||||||
"allow_screenshot": self.allow_screenshot,
|
"allow_screenshot": self.allow_screenshot,
|
||||||
"allow_flagging": self.allow_flagging,
|
"allow_flagging": self.allow_flagging,
|
||||||
|
"allow_interpretation": self.interpretation is not None
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||||
@ -187,8 +192,8 @@ class Interface:
|
|||||||
iface[1]["label"] = ret_name
|
iface[1]["label"] = ret_name
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
processed_examples = []
|
|
||||||
if self.examples is not None:
|
if self.examples is not None:
|
||||||
|
processed_examples = []
|
||||||
for example_set in self.examples:
|
for example_set in self.examples:
|
||||||
processed_set = []
|
processed_set = []
|
||||||
for iface, example in zip(self.input_interfaces, example_set):
|
for iface, example in zip(self.input_interfaces, example_set):
|
||||||
@ -197,19 +202,7 @@ class Interface:
|
|||||||
config["examples"] = processed_examples
|
config["examples"] = processed_examples
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def process(self, raw_input, predict_fn=None):
|
def run_prediction(self, processed_input, return_duration=False):
|
||||||
"""
|
|
||||||
:param raw_input: a list of raw inputs to process and apply the
|
|
||||||
prediction(s) on.
|
|
||||||
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
|
||||||
:return:
|
|
||||||
processed output: a list of processed outputs to return as the
|
|
||||||
prediction(s).
|
|
||||||
duration: a list of time deltas measuring inference time for each
|
|
||||||
prediction fn.
|
|
||||||
"""
|
|
||||||
processed_input = [input_interface.preprocess(raw_input[i])
|
|
||||||
for i, input_interface in enumerate(self.input_interfaces)]
|
|
||||||
predictions = []
|
predictions = []
|
||||||
durations = []
|
durations = []
|
||||||
for predict_fn in self.predict:
|
for predict_fn in self.predict:
|
||||||
@ -239,6 +232,27 @@ class Interface:
|
|||||||
prediction = [prediction]
|
prediction = [prediction]
|
||||||
durations.append(duration)
|
durations.append(duration)
|
||||||
predictions.extend(prediction)
|
predictions.extend(prediction)
|
||||||
|
|
||||||
|
if return_duration:
|
||||||
|
return predictions, durations
|
||||||
|
else:
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def process(self, raw_input, predict_fn=None):
|
||||||
|
"""
|
||||||
|
:param raw_input: a list of raw inputs to process and apply the
|
||||||
|
prediction(s) on.
|
||||||
|
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
||||||
|
:return:
|
||||||
|
processed output: a list of processed outputs to return as the
|
||||||
|
prediction(s).
|
||||||
|
duration: a list of time deltas measuring inference time for each
|
||||||
|
prediction fn.
|
||||||
|
"""
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(self.input_interfaces)]
|
||||||
|
predictions, durations = self.run_prediction(processed_input, return_duration=True)
|
||||||
processed_output = [output_interface.postprocess(
|
processed_output = [output_interface.postprocess(
|
||||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||||
return processed_output, durations
|
return processed_output, durations
|
||||||
@ -396,7 +410,6 @@ class Interface:
|
|||||||
|
|
||||||
return app, path_to_local_server, share_url
|
return app, path_to_local_server, share_url
|
||||||
|
|
||||||
|
|
||||||
def reset_all():
|
def reset_all():
|
||||||
for io in Interface.get_instances():
|
for io in Interface.get_instances():
|
||||||
io.close()
|
io.close()
|
||||||
|
103
build/lib/gradio/interpretation.py
Normal file
103
build/lib/gradio/interpretation.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
from gradio.inputs import Image, Textbox
|
||||||
|
from gradio.outputs import Label
|
||||||
|
from gradio import processing_utils
|
||||||
|
from skimage.segmentation import slic
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
expected_types = {
|
||||||
|
Image: "numpy",
|
||||||
|
Textbox: "str"
|
||||||
|
}
|
||||||
|
|
||||||
|
def default(separator=" ", n_segments=20):
|
||||||
|
"""
|
||||||
|
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
|
||||||
|
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
|
||||||
|
inputs and outputs, uses the first component.
|
||||||
|
"""
|
||||||
|
def tokenize_text(text):
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
tokens = text.split(separator)
|
||||||
|
for idx, _ in enumerate(tokens):
|
||||||
|
new_token_array = tokens.copy()
|
||||||
|
del new_token_array[idx]
|
||||||
|
leave_one_out_tokens.append(new_token_array)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def tokenize_image(image):
|
||||||
|
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
replace_color = np.mean(image, axis=(0, 1))
|
||||||
|
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||||
|
mask = segments_slic == segVal
|
||||||
|
white_screen = np.copy(image)
|
||||||
|
white_screen[segments_slic == segVal] = replace_color
|
||||||
|
leave_one_out_tokens.append((mask, white_screen))
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def score_text(interface, leave_one_out_tokens, text):
|
||||||
|
tokens = text.split(separator)
|
||||||
|
original_output = interface.run_prediction([text])
|
||||||
|
|
||||||
|
scores_by_words = []
|
||||||
|
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||||
|
perturbed_text = separator.join(input_text)
|
||||||
|
perturbed_output = interface.run_prediction([perturbed_text])
|
||||||
|
score = quantify_difference_in_label(interface, original_output, perturbed_output)
|
||||||
|
scores_by_words.append(score)
|
||||||
|
|
||||||
|
scores_by_char = []
|
||||||
|
for idx, token in enumerate(tokens):
|
||||||
|
if idx != 0:
|
||||||
|
scores_by_char.append((" ", 0))
|
||||||
|
for char in token:
|
||||||
|
scores_by_char.append((char, scores_by_words[idx]))
|
||||||
|
|
||||||
|
return scores_by_char
|
||||||
|
|
||||||
|
def score_image(interface, leave_one_out_tokens, image):
|
||||||
|
output_scores = np.zeros((image.shape[0], image.shape[1]))
|
||||||
|
original_output = interface.run_prediction([image])
|
||||||
|
|
||||||
|
for mask, perturbed_image in leave_one_out_tokens:
|
||||||
|
perturbed_output = interface.run_prediction([perturbed_image])
|
||||||
|
score = quantify_difference_in_label(interface, original_output, perturbed_output)
|
||||||
|
output_scores += score * mask
|
||||||
|
|
||||||
|
max_val, min_val = np.max(output_scores), np.min(output_scores)
|
||||||
|
if max_val > 0:
|
||||||
|
output_scores = (output_scores - min_val) / (max_val - min_val)
|
||||||
|
return output_scores.tolist()
|
||||||
|
|
||||||
|
def quantify_difference_in_label(interface, original_output, perturbed_output):
|
||||||
|
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
|
||||||
|
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
|
||||||
|
original_label = post_original_output["label"]
|
||||||
|
perturbed_label = post_perturbed_output["label"]
|
||||||
|
|
||||||
|
# Handle different return types of Label interface
|
||||||
|
if "confidences" in post_original_output:
|
||||||
|
original_confidence = original_output[0][original_label]
|
||||||
|
perturbed_confidence = perturbed_output[0][original_label]
|
||||||
|
score = original_confidence - perturbed_confidence
|
||||||
|
else:
|
||||||
|
try: # try computing numerical difference
|
||||||
|
score = float(original_label) - float(perturbed_label)
|
||||||
|
except ValueError: # otherwise, look at strict difference in label
|
||||||
|
score = int(not(perturbed_label == original_label))
|
||||||
|
return score
|
||||||
|
|
||||||
|
def default_interpretation(interface, x):
|
||||||
|
if isinstance(interface.input_interfaces[0], Textbox) \
|
||||||
|
and isinstance(interface.output_interfaces[0], Label):
|
||||||
|
leave_one_out_tokens = tokenize_text(x[0])
|
||||||
|
return [score_text(interface, leave_one_out_tokens, x[0])]
|
||||||
|
if isinstance(interface.input_interfaces[0], Image) \
|
||||||
|
and isinstance(interface.output_interfaces[0], Label):
|
||||||
|
leave_one_out_tokens = tokenize_image(x[0])
|
||||||
|
return [score_image(interface, leave_one_out_tokens, x[0])]
|
||||||
|
else:
|
||||||
|
print("Not valid input or output types for 'default' interpretation")
|
||||||
|
|
||||||
|
return default_interpretation
|
||||||
|
|
@ -9,7 +9,7 @@ from flask import Flask, request, jsonify, abort, send_file, render_template
|
|||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from distutils import dir_util
|
from distutils import dir_util
|
||||||
from gradio import inputs, outputs
|
import gradio as gr
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
from gradio.tunneling import create_tunnel
|
from gradio.tunneling import create_tunnel
|
||||||
@ -18,7 +18,7 @@ from shutil import copyfile
|
|||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
|
import copy
|
||||||
|
|
||||||
INITIAL_PORT_VALUE = int(os.getenv(
|
INITIAL_PORT_VALUE = int(os.getenv(
|
||||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||||
@ -72,17 +72,19 @@ def get_first_available_port(initial, final):
|
|||||||
|
|
||||||
|
|
||||||
@app.route("/", methods=["GET"])
|
@app.route("/", methods=["GET"])
|
||||||
def gradio():
|
def main():
|
||||||
return render_template("index.html",
|
return render_template("index.html",
|
||||||
title=app.app_globals["title"],
|
title=app.app_globals["title"],
|
||||||
description=app.app_globals["description"],
|
description=app.app_globals["description"],
|
||||||
thumbnail=app.app_globals["thumbnail"],
|
thumbnail=app.app_globals["thumbnail"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/config/", methods=["GET"])
|
@app.route("/config/", methods=["GET"])
|
||||||
def config():
|
def config():
|
||||||
return jsonify(app.app_globals["config"])
|
return jsonify(app.app_globals["config"])
|
||||||
|
|
||||||
|
|
||||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||||
def enable_sharing(path):
|
def enable_sharing(path):
|
||||||
if path == "None":
|
if path == "None":
|
||||||
@ -90,6 +92,7 @@ def enable_sharing(path):
|
|||||||
app.app_globals["config"]["share_url"] = path
|
app.app_globals["config"]["share_url"] = path
|
||||||
return jsonify(success=True)
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/predict/", methods=["POST"])
|
@app.route("/api/predict/", methods=["POST"])
|
||||||
def predict():
|
def predict():
|
||||||
raw_input = request.json["data"]
|
raw_input = request.json["data"]
|
||||||
@ -97,6 +100,7 @@ def predict():
|
|||||||
output = {"data": prediction, "durations": durations}
|
output = {"data": prediction, "durations": durations}
|
||||||
return jsonify(output)
|
return jsonify(output)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/flag/", methods=["POST"])
|
@app.route("/api/flag/", methods=["POST"])
|
||||||
def flag():
|
def flag():
|
||||||
os.makedirs(app.interface.flagging_dir, exist_ok=True)
|
os.makedirs(app.interface.flagging_dir, exist_ok=True)
|
||||||
@ -130,6 +134,25 @@ def flag():
|
|||||||
)
|
)
|
||||||
return jsonify(success=True)
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/interpret/", methods=["POST"])
|
||||||
|
def interpret():
|
||||||
|
raw_input = request.json["data"]
|
||||||
|
if app.interface.interpretation == "default":
|
||||||
|
interpreter = gr.interpretation.default()
|
||||||
|
processed_input = []
|
||||||
|
for i, x in enumerate(raw_input):
|
||||||
|
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
||||||
|
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
|
||||||
|
processed_input.append(input_interface.preprocess(x))
|
||||||
|
else:
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||||
|
interpreter = app.interface.interpretation
|
||||||
|
interpretation = interpreter(app.interface, processed_input)
|
||||||
|
return jsonify(interpretation)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/file/<path:path>", methods=["GET"])
|
@app.route("/file/<path:path>", methods=["GET"])
|
||||||
def file(path):
|
def file(path):
|
||||||
return send_file(os.path.join(os.getcwd(), path))
|
return send_file(os.path.join(os.getcwd(), path))
|
||||||
|
@ -84,9 +84,6 @@ input.submit {
|
|||||||
input.submit:hover {
|
input.submit:hover {
|
||||||
background-color: #f39c12;
|
background-color: #f39c12;
|
||||||
}
|
}
|
||||||
.flag {
|
|
||||||
visibility: hidden;
|
|
||||||
}
|
|
||||||
.flagged {
|
.flagged {
|
||||||
background-color: pink !important;
|
background-color: pink !important;
|
||||||
}
|
}
|
||||||
@ -111,9 +108,6 @@ input.submit:hover {
|
|||||||
.invisible {
|
.invisible {
|
||||||
display: none !important;
|
display: none !important;
|
||||||
}
|
}
|
||||||
.screenshot {
|
|
||||||
visibility: hidden;
|
|
||||||
}
|
|
||||||
.screenshot_logo {
|
.screenshot_logo {
|
||||||
display: none;
|
display: none;
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
|
@ -25,14 +25,10 @@
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
border: none;
|
border: none;
|
||||||
opacity: 1;
|
opacity: 1;
|
||||||
|
transition: opacity 0.2s ease;
|
||||||
}
|
}
|
||||||
.saliency > div {
|
.saliency:hover {
|
||||||
display: flex;
|
opacity: 0.4;
|
||||||
flex-grow: 1;
|
|
||||||
}
|
|
||||||
.saliency > div > div {
|
|
||||||
flex-grow: 1;
|
|
||||||
background-color: #e67e22;
|
|
||||||
}
|
}
|
||||||
.image_preview {
|
.image_preview {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
|
@ -82,9 +82,26 @@ var io_master_template = {
|
|||||||
data: JSON.stringify(post_data),
|
data: JSON.stringify(post_data),
|
||||||
dataType: 'json',
|
dataType: 'json',
|
||||||
contentType: 'application/json; charset=utf-8',
|
contentType: 'application/json; charset=utf-8',
|
||||||
success: function(output){
|
});
|
||||||
console.log("Flagging successful")
|
},
|
||||||
},
|
interpret: function() {
|
||||||
|
var io = this;
|
||||||
|
this.target.find(".loading").removeClass("invisible");
|
||||||
|
this.target.find(".loading_in_progress").show();
|
||||||
|
var post_data = {
|
||||||
|
'data': this.last_input
|
||||||
|
}
|
||||||
|
$.ajax({type: "POST",
|
||||||
|
url: "/api/interpret/",
|
||||||
|
data: JSON.stringify(post_data),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json; charset=utf-8',
|
||||||
|
success: function(data) {
|
||||||
|
for (let [idx, interpretation] of data.entries()) {
|
||||||
|
io.input_interfaces[idx].show_interpretation(interpretation);
|
||||||
|
}
|
||||||
|
io.target.find(".loading_in_progress").hide();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
io_master.last_output = null;
|
io_master.last_output = null;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (config["allow_screenshot"]) {
|
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
|
||||||
target.find(".screenshot").css("visibility", "visible");
|
target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
|
||||||
}
|
} else {
|
||||||
if (config["allow_flagging"]) {
|
if (!config["allow_screenshot"]) {
|
||||||
target.find(".flag").css("visibility", "visible");
|
target.find(".screenshot").hide();
|
||||||
}
|
}
|
||||||
if (config["allow_interpretation"]) {
|
if (!config["allow_flagging"]) {
|
||||||
target.find(".interpret").css("visibility", "visible");
|
target.find(".flag").hide();
|
||||||
|
}
|
||||||
|
if (!config["allow_interpretation"]) {
|
||||||
|
target.find(".interpret").hide();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (config["examples"]) {
|
if (config["examples"]) {
|
||||||
target.find(".examples").removeClass("invisible");
|
target.find(".examples").removeClass("invisible");
|
||||||
@ -231,12 +235,15 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
target.find(".flag").click(function() {
|
target.find(".flag").click(function() {
|
||||||
if (io_master.last_output) {
|
if (io_master.last_output) {
|
||||||
target.find(".flag").addClass("flagged");
|
target.find(".flag").addClass("flagged");
|
||||||
target.find(".flag").val("FLAGGED");
|
target.find(".flag").val("FLAGGED");
|
||||||
io_master.flag();
|
io_master.flag();
|
||||||
|
}
|
||||||
// io_master.flag($(".flag_message").val());
|
})
|
||||||
|
target.find(".interpret").click(function() {
|
||||||
|
if (io_master.last_output) {
|
||||||
|
io_master.interpret();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -29,6 +29,9 @@ const image_input = {
|
|||||||
<div class="image_preview_holder">
|
<div class="image_preview_holder">
|
||||||
<img class="image_preview" />
|
<img class="image_preview" />
|
||||||
</div>
|
</div>
|
||||||
|
<div class="saliency_holder hide">
|
||||||
|
<canvas class="saliency"></canvas>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
|
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
|
||||||
@ -180,6 +183,19 @@ const image_input = {
|
|||||||
this.cropper.destroy();
|
this.cropper.destroy();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
this.target.find(".saliency_holder").addClass("hide");
|
||||||
|
},
|
||||||
|
show_interpretation: function(data) {
|
||||||
|
if (this.target.find(".image_preview").attr("src")) {
|
||||||
|
var img = this.target.find(".image_preview")[0];
|
||||||
|
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
|
||||||
|
var width = size.width;
|
||||||
|
var height = size.height;
|
||||||
|
this.target.find(".saliency_holder").removeClass("hide").html(`
|
||||||
|
<canvas class="saliency" width=${width} height=${height}></canvas>`);
|
||||||
|
var ctx = this.target.find(".saliency")[0].getContext('2d');
|
||||||
|
paintSaliency(data, ctx, width, height);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
state: "NO_IMAGE",
|
state: "NO_IMAGE",
|
||||||
image_data: null,
|
image_data: null,
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
const textbox_input = {
|
const textbox_input = {
|
||||||
html: `<textarea class="input_text"></textarea>`,
|
html: `
|
||||||
|
<textarea class="input_text"></textarea>
|
||||||
|
<div class="output_text"></div>
|
||||||
|
`,
|
||||||
one_line_html: `<input type="text" class="input_text">`,
|
one_line_html: `<input type="text" class="input_text">`,
|
||||||
init: function(opts) {
|
init: function(opts) {
|
||||||
if (opts.lines > 1) {
|
if (opts.lines > 1) {
|
||||||
@ -13,6 +16,7 @@ const textbox_input = {
|
|||||||
if (opts.default) {
|
if (opts.default) {
|
||||||
this.target.find(".input_text").val(opts.default)
|
this.target.find(".input_text").val(opts.default)
|
||||||
}
|
}
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
},
|
},
|
||||||
submit: function() {
|
submit: function() {
|
||||||
text = this.target.find(".input_text").val();
|
text = this.target.find(".input_text").val();
|
||||||
@ -20,6 +24,24 @@ const textbox_input = {
|
|||||||
},
|
},
|
||||||
clear: function() {
|
clear: function() {
|
||||||
this.target.find(".input_text").val("");
|
this.target.find(".input_text").val("");
|
||||||
|
this.target.find(".output_text").empty();
|
||||||
|
this.target.find(".input_text").show();
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
|
},
|
||||||
|
show_interpretation: function(data) {
|
||||||
|
this.target.find(".input_text").hide();
|
||||||
|
this.target.find(".output_text").show();
|
||||||
|
let html = "";
|
||||||
|
for (let char_set of data) {
|
||||||
|
[char, value] = char_set;
|
||||||
|
if (value < 0) {
|
||||||
|
color = "8,241,255," + (-value * 2);
|
||||||
|
} else {
|
||||||
|
color = "230,126,34," + value * 2;
|
||||||
|
}
|
||||||
|
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
|
||||||
|
}
|
||||||
|
this.target.find(".output_text").html(html);
|
||||||
},
|
},
|
||||||
load_example: function(data) {
|
load_example: function(data) {
|
||||||
this.target.find(".input_text").val(data);
|
this.target.find(".input_text").val(data);
|
||||||
|
@ -54,22 +54,19 @@ function toStringIfObject(input) {
|
|||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
function paintSaliency(data, width, height, ctx) {
|
function paintSaliency(data, ctx, width, height) {
|
||||||
var cell_width = width / data[0].length
|
var cell_width = width / data[0].length
|
||||||
var cell_height = height / data.length
|
var cell_height = height / data.length
|
||||||
var r = 0
|
var r = 0
|
||||||
data.forEach(function(row) {
|
data.forEach(function(row) {
|
||||||
var c = 0
|
var c = 0
|
||||||
row.forEach(function(cell) {
|
row.forEach(function(cell) {
|
||||||
if (cell < 0.25) {
|
if (cell < 0) {
|
||||||
ctx.fillStyle = "white";
|
var color = [7,47,95];
|
||||||
} else if (cell < 0.5) {
|
|
||||||
ctx.fillStyle = "#add8ed";
|
|
||||||
} else if (cell < 0.75) {
|
|
||||||
ctx.fillStyle = "#5aa7d3";
|
|
||||||
} else {
|
} else {
|
||||||
ctx.fillStyle = "#072F5F";
|
var color = [112,62,8];
|
||||||
}
|
}
|
||||||
|
ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
|
||||||
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
|
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
|
||||||
c++;
|
c++;
|
||||||
})
|
})
|
||||||
@ -77,6 +74,29 @@ function paintSaliency(data, width, height, ctx) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
|
||||||
|
var doRatio = width / height;
|
||||||
|
var cRatio = containerWidth / containerHeight;
|
||||||
|
var targetWidth = 0;
|
||||||
|
var targetHeight = 0;
|
||||||
|
var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
|
||||||
|
|
||||||
|
if (test) {
|
||||||
|
targetWidth = containerWidth;
|
||||||
|
targetHeight = targetWidth / doRatio;
|
||||||
|
} else {
|
||||||
|
targetHeight = containerHeight;
|
||||||
|
targetWidth = targetHeight * doRatio;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
width: targetWidth,
|
||||||
|
height: targetHeight,
|
||||||
|
x: (containerWidth - targetWidth) / 2,
|
||||||
|
y: (containerHeight - targetHeight) / 2
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// val should be in the range [0.0, 1.0]
|
// val should be in the range [0.0, 1.0]
|
||||||
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
|
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
|
||||||
function interpolate(val, rgb1, rgb2) {
|
function interpolate(val, rgb1, rgb2) {
|
||||||
@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
|
|||||||
return rgb;
|
return rgb;
|
||||||
}
|
}
|
||||||
|
|
||||||
// quick helper function to convert the array into something we can use for css
|
|
||||||
function colorToString(rgb) {
|
function colorToString(rgb) {
|
||||||
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
|
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
|
||||||
}
|
}
|
||||||
|
48
demo/image_classifier.py
Normal file
48
demo/image_classifier.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import tensorflow as tf
|
||||||
|
# from vis.utils import utils
|
||||||
|
# from vis.visualization import visualize_cam
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
|
||||||
|
# # Download human-readable labels for ImageNet.
|
||||||
|
# response = requests.get("https://git.io/JJkYN")
|
||||||
|
# labels = response.text.split("\n")
|
||||||
|
labels = range(1000) # comment this later
|
||||||
|
|
||||||
|
mobile_net = tf.keras.applications.MobileNetV2()
|
||||||
|
|
||||||
|
|
||||||
|
def image_classifier(im):
|
||||||
|
arr = np.expand_dims(im, axis=0)
|
||||||
|
arr = tf.keras.applications.mobilenet.preprocess_input(arr)
|
||||||
|
prediction = mobile_net.predict(arr).flatten()
|
||||||
|
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||||
|
|
||||||
|
def image_explain(im):
|
||||||
|
model.layers[-1].activation = tf.keras.activations.linear
|
||||||
|
model = utils.apply_modifications(model)
|
||||||
|
penultimate_layer_idx = 2
|
||||||
|
class_idx = class_idxs_sorted[0]
|
||||||
|
seed_input = img
|
||||||
|
grad_top1 = visualize_cam(model, layer_idx, class_idx, seed_input,
|
||||||
|
penultimate_layer_idx = penultimate_layer_idx,#None,
|
||||||
|
backprop_modifier = None,
|
||||||
|
grad_modifier = None)
|
||||||
|
print(grad_top_1)
|
||||||
|
return grad_top1
|
||||||
|
|
||||||
|
|
||||||
|
image = gr.inputs.Image(shape=(224, 224))
|
||||||
|
label = gr.outputs.Label(num_top_classes=3)
|
||||||
|
|
||||||
|
gr.Interface(image_classifier, image, label,
|
||||||
|
capture_session=True,
|
||||||
|
interpretation="default",
|
||||||
|
examples=[
|
||||||
|
["images/cheetah1.jpg"],
|
||||||
|
["images/lion.jpg"]
|
||||||
|
]).launch();
|
13
demo/longest_word.py
Normal file
13
demo/longest_word.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
def longest_word(text):
|
||||||
|
words = text.split(" ")
|
||||||
|
lengths = [len(word) for word in words]
|
||||||
|
return max(lengths)
|
||||||
|
|
||||||
|
ex = "The quick brown fox jumped over the lazy dog."
|
||||||
|
|
||||||
|
io = gr.Interface(longest_word, "textbox", "label", interpretation="default", examples=[[ex]])
|
||||||
|
|
||||||
|
io.test_launch()
|
||||||
|
io.launch()
|
15
demo/sentiment_analysis.py
Normal file
15
demo/sentiment_analysis.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import nltk
|
||||||
|
from nltk.sentiment.vader import SentimentIntensityAnalyzer
|
||||||
|
nltk.download('vader_lexicon')
|
||||||
|
sid = SentimentIntensityAnalyzer()
|
||||||
|
|
||||||
|
def sentiment_analysis(text):
|
||||||
|
scores = sid.polarity_scores(text)
|
||||||
|
del scores["compound"]
|
||||||
|
return scores
|
||||||
|
|
||||||
|
io = gr.Interface(sentiment_analysis, "textbox", "label", interpretation="default")
|
||||||
|
|
||||||
|
io.test_launch()
|
||||||
|
io.launch()
|
@ -5,6 +5,7 @@ gradio/__init__.py
|
|||||||
gradio/component.py
|
gradio/component.py
|
||||||
gradio/inputs.py
|
gradio/inputs.py
|
||||||
gradio/interface.py
|
gradio/interface.py
|
||||||
|
gradio/interpretation.py
|
||||||
gradio/networking.py
|
gradio/networking.py
|
||||||
gradio/notebook.py
|
gradio/notebook.py
|
||||||
gradio/outputs.py
|
gradio/outputs.py
|
||||||
|
@ -5,10 +5,10 @@ interface using the input and output types.
|
|||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
from gradio.inputs import InputComponent
|
from gradio.inputs import InputComponent
|
||||||
from gradio.outputs import OutputComponent
|
from gradio.outputs import OutputComponent
|
||||||
from gradio import networking, strings, utils
|
from gradio import networking, strings, utils
|
||||||
|
import gradio.interpretation
|
||||||
import requests
|
import requests
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
@ -43,8 +43,9 @@ class Interface:
|
|||||||
|
|
||||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||||
live=False, show_input=True, show_output=True,
|
live=False, show_input=True, show_output=True,
|
||||||
capture_session=False, title=None, description=None,
|
capture_session=False, interpretation=None, title=None,
|
||||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
description=None, thumbnail=None, server_port=None,
|
||||||
|
server_name=networking.LOCALHOST_NAME,
|
||||||
allow_screenshot=True, allow_flagging=True,
|
allow_screenshot=True, allow_flagging=True,
|
||||||
flagging_dir="flagged", analytics_enabled=True):
|
flagging_dir="flagged", analytics_enabled=True):
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ class Interface:
|
|||||||
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
|
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
|
||||||
live (bool): whether the interface should automatically reload on change.
|
live (bool): whether the interface should automatically reload on change.
|
||||||
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
||||||
|
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
|
||||||
title (str): a title for the interface; if provided, appears above the input and output components.
|
title (str): a title for the interface; if provided, appears above the input and output components.
|
||||||
description (str): a description for the interface; if provided, appears above the input and output components.
|
description (str): a description for the interface; if provided, appears above the input and output components.
|
||||||
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
|
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
|
||||||
@ -98,6 +100,7 @@ class Interface:
|
|||||||
if not isinstance(fn, list):
|
if not isinstance(fn, list):
|
||||||
fn = [fn]
|
fn = [fn]
|
||||||
|
|
||||||
|
|
||||||
self.output_interfaces *= len(fn)
|
self.output_interfaces *= len(fn)
|
||||||
self.predict = fn
|
self.predict = fn
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -107,6 +110,7 @@ class Interface:
|
|||||||
self.show_output = show_output
|
self.show_output = show_output
|
||||||
self.flag_hash = random.getrandbits(32)
|
self.flag_hash = random.getrandbits(32)
|
||||||
self.capture_session = capture_session
|
self.capture_session = capture_session
|
||||||
|
self.interpretation = interpretation
|
||||||
self.session = None
|
self.session = None
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
self.title = title
|
self.title = title
|
||||||
@ -175,6 +179,7 @@ class Interface:
|
|||||||
"thumbnail": self.thumbnail,
|
"thumbnail": self.thumbnail,
|
||||||
"allow_screenshot": self.allow_screenshot,
|
"allow_screenshot": self.allow_screenshot,
|
||||||
"allow_flagging": self.allow_flagging,
|
"allow_flagging": self.allow_flagging,
|
||||||
|
"allow_interpretation": self.interpretation is not None
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||||
@ -187,8 +192,8 @@ class Interface:
|
|||||||
iface[1]["label"] = ret_name
|
iface[1]["label"] = ret_name
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
processed_examples = []
|
|
||||||
if self.examples is not None:
|
if self.examples is not None:
|
||||||
|
processed_examples = []
|
||||||
for example_set in self.examples:
|
for example_set in self.examples:
|
||||||
processed_set = []
|
processed_set = []
|
||||||
for iface, example in zip(self.input_interfaces, example_set):
|
for iface, example in zip(self.input_interfaces, example_set):
|
||||||
@ -197,19 +202,7 @@ class Interface:
|
|||||||
config["examples"] = processed_examples
|
config["examples"] = processed_examples
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def process(self, raw_input, predict_fn=None):
|
def run_prediction(self, processed_input, return_duration=False):
|
||||||
"""
|
|
||||||
:param raw_input: a list of raw inputs to process and apply the
|
|
||||||
prediction(s) on.
|
|
||||||
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
|
||||||
:return:
|
|
||||||
processed output: a list of processed outputs to return as the
|
|
||||||
prediction(s).
|
|
||||||
duration: a list of time deltas measuring inference time for each
|
|
||||||
prediction fn.
|
|
||||||
"""
|
|
||||||
processed_input = [input_interface.preprocess(raw_input[i])
|
|
||||||
for i, input_interface in enumerate(self.input_interfaces)]
|
|
||||||
predictions = []
|
predictions = []
|
||||||
durations = []
|
durations = []
|
||||||
for predict_fn in self.predict:
|
for predict_fn in self.predict:
|
||||||
@ -239,6 +232,27 @@ class Interface:
|
|||||||
prediction = [prediction]
|
prediction = [prediction]
|
||||||
durations.append(duration)
|
durations.append(duration)
|
||||||
predictions.extend(prediction)
|
predictions.extend(prediction)
|
||||||
|
|
||||||
|
if return_duration:
|
||||||
|
return predictions, durations
|
||||||
|
else:
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def process(self, raw_input, predict_fn=None):
|
||||||
|
"""
|
||||||
|
:param raw_input: a list of raw inputs to process and apply the
|
||||||
|
prediction(s) on.
|
||||||
|
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
||||||
|
:return:
|
||||||
|
processed output: a list of processed outputs to return as the
|
||||||
|
prediction(s).
|
||||||
|
duration: a list of time deltas measuring inference time for each
|
||||||
|
prediction fn.
|
||||||
|
"""
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(self.input_interfaces)]
|
||||||
|
predictions, durations = self.run_prediction(processed_input, return_duration=True)
|
||||||
processed_output = [output_interface.postprocess(
|
processed_output = [output_interface.postprocess(
|
||||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||||
return processed_output, durations
|
return processed_output, durations
|
||||||
@ -396,7 +410,6 @@ class Interface:
|
|||||||
|
|
||||||
return app, path_to_local_server, share_url
|
return app, path_to_local_server, share_url
|
||||||
|
|
||||||
|
|
||||||
def reset_all():
|
def reset_all():
|
||||||
for io in Interface.get_instances():
|
for io in Interface.get_instances():
|
||||||
io.close()
|
io.close()
|
||||||
|
103
gradio/interpretation.py
Normal file
103
gradio/interpretation.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
from gradio.inputs import Image, Textbox
|
||||||
|
from gradio.outputs import Label
|
||||||
|
from gradio import processing_utils
|
||||||
|
from skimage.segmentation import slic
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
expected_types = {
|
||||||
|
Image: "numpy",
|
||||||
|
Textbox: "str"
|
||||||
|
}
|
||||||
|
|
||||||
|
def default(separator=" ", n_segments=20):
|
||||||
|
"""
|
||||||
|
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
|
||||||
|
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
|
||||||
|
inputs and outputs, uses the first component.
|
||||||
|
"""
|
||||||
|
def tokenize_text(text):
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
tokens = text.split(separator)
|
||||||
|
for idx, _ in enumerate(tokens):
|
||||||
|
new_token_array = tokens.copy()
|
||||||
|
del new_token_array[idx]
|
||||||
|
leave_one_out_tokens.append(new_token_array)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def tokenize_image(image):
|
||||||
|
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
replace_color = np.mean(image, axis=(0, 1))
|
||||||
|
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||||
|
mask = segments_slic == segVal
|
||||||
|
white_screen = np.copy(image)
|
||||||
|
white_screen[segments_slic == segVal] = replace_color
|
||||||
|
leave_one_out_tokens.append((mask, white_screen))
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def score_text(interface, leave_one_out_tokens, text):
|
||||||
|
tokens = text.split(separator)
|
||||||
|
original_output = interface.run_prediction([text])
|
||||||
|
|
||||||
|
scores_by_words = []
|
||||||
|
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||||
|
perturbed_text = separator.join(input_text)
|
||||||
|
perturbed_output = interface.run_prediction([perturbed_text])
|
||||||
|
score = quantify_difference_in_label(interface, original_output, perturbed_output)
|
||||||
|
scores_by_words.append(score)
|
||||||
|
|
||||||
|
scores_by_char = []
|
||||||
|
for idx, token in enumerate(tokens):
|
||||||
|
if idx != 0:
|
||||||
|
scores_by_char.append((" ", 0))
|
||||||
|
for char in token:
|
||||||
|
scores_by_char.append((char, scores_by_words[idx]))
|
||||||
|
|
||||||
|
return scores_by_char
|
||||||
|
|
||||||
|
def score_image(interface, leave_one_out_tokens, image):
|
||||||
|
output_scores = np.zeros((image.shape[0], image.shape[1]))
|
||||||
|
original_output = interface.run_prediction([image])
|
||||||
|
|
||||||
|
for mask, perturbed_image in leave_one_out_tokens:
|
||||||
|
perturbed_output = interface.run_prediction([perturbed_image])
|
||||||
|
score = quantify_difference_in_label(interface, original_output, perturbed_output)
|
||||||
|
output_scores += score * mask
|
||||||
|
|
||||||
|
max_val, min_val = np.max(output_scores), np.min(output_scores)
|
||||||
|
if max_val > 0:
|
||||||
|
output_scores = (output_scores - min_val) / (max_val - min_val)
|
||||||
|
return output_scores.tolist()
|
||||||
|
|
||||||
|
def quantify_difference_in_label(interface, original_output, perturbed_output):
|
||||||
|
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
|
||||||
|
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
|
||||||
|
original_label = post_original_output["label"]
|
||||||
|
perturbed_label = post_perturbed_output["label"]
|
||||||
|
|
||||||
|
# Handle different return types of Label interface
|
||||||
|
if "confidences" in post_original_output:
|
||||||
|
original_confidence = original_output[0][original_label]
|
||||||
|
perturbed_confidence = perturbed_output[0][original_label]
|
||||||
|
score = original_confidence - perturbed_confidence
|
||||||
|
else:
|
||||||
|
try: # try computing numerical difference
|
||||||
|
score = float(original_label) - float(perturbed_label)
|
||||||
|
except ValueError: # otherwise, look at strict difference in label
|
||||||
|
score = int(not(perturbed_label == original_label))
|
||||||
|
return score
|
||||||
|
|
||||||
|
def default_interpretation(interface, x):
|
||||||
|
if isinstance(interface.input_interfaces[0], Textbox) \
|
||||||
|
and isinstance(interface.output_interfaces[0], Label):
|
||||||
|
leave_one_out_tokens = tokenize_text(x[0])
|
||||||
|
return [score_text(interface, leave_one_out_tokens, x[0])]
|
||||||
|
if isinstance(interface.input_interfaces[0], Image) \
|
||||||
|
and isinstance(interface.output_interfaces[0], Label):
|
||||||
|
leave_one_out_tokens = tokenize_image(x[0])
|
||||||
|
return [score_image(interface, leave_one_out_tokens, x[0])]
|
||||||
|
else:
|
||||||
|
print("Not valid input or output types for 'default' interpretation")
|
||||||
|
|
||||||
|
return default_interpretation
|
||||||
|
|
@ -9,7 +9,7 @@ from flask import Flask, request, jsonify, abort, send_file, render_template
|
|||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from distutils import dir_util
|
from distutils import dir_util
|
||||||
from gradio import inputs, outputs
|
import gradio as gr
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
from gradio.tunneling import create_tunnel
|
from gradio.tunneling import create_tunnel
|
||||||
@ -18,7 +18,7 @@ from shutil import copyfile
|
|||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
|
import copy
|
||||||
|
|
||||||
INITIAL_PORT_VALUE = int(os.getenv(
|
INITIAL_PORT_VALUE = int(os.getenv(
|
||||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||||
@ -72,17 +72,19 @@ def get_first_available_port(initial, final):
|
|||||||
|
|
||||||
|
|
||||||
@app.route("/", methods=["GET"])
|
@app.route("/", methods=["GET"])
|
||||||
def gradio():
|
def main():
|
||||||
return render_template("index.html",
|
return render_template("index.html",
|
||||||
title=app.app_globals["title"],
|
title=app.app_globals["title"],
|
||||||
description=app.app_globals["description"],
|
description=app.app_globals["description"],
|
||||||
thumbnail=app.app_globals["thumbnail"],
|
thumbnail=app.app_globals["thumbnail"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/config/", methods=["GET"])
|
@app.route("/config/", methods=["GET"])
|
||||||
def config():
|
def config():
|
||||||
return jsonify(app.app_globals["config"])
|
return jsonify(app.app_globals["config"])
|
||||||
|
|
||||||
|
|
||||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||||
def enable_sharing(path):
|
def enable_sharing(path):
|
||||||
if path == "None":
|
if path == "None":
|
||||||
@ -90,6 +92,7 @@ def enable_sharing(path):
|
|||||||
app.app_globals["config"]["share_url"] = path
|
app.app_globals["config"]["share_url"] = path
|
||||||
return jsonify(success=True)
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/predict/", methods=["POST"])
|
@app.route("/api/predict/", methods=["POST"])
|
||||||
def predict():
|
def predict():
|
||||||
raw_input = request.json["data"]
|
raw_input = request.json["data"]
|
||||||
@ -97,6 +100,7 @@ def predict():
|
|||||||
output = {"data": prediction, "durations": durations}
|
output = {"data": prediction, "durations": durations}
|
||||||
return jsonify(output)
|
return jsonify(output)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/flag/", methods=["POST"])
|
@app.route("/api/flag/", methods=["POST"])
|
||||||
def flag():
|
def flag():
|
||||||
os.makedirs(app.interface.flagging_dir, exist_ok=True)
|
os.makedirs(app.interface.flagging_dir, exist_ok=True)
|
||||||
@ -130,6 +134,25 @@ def flag():
|
|||||||
)
|
)
|
||||||
return jsonify(success=True)
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/interpret/", methods=["POST"])
|
||||||
|
def interpret():
|
||||||
|
raw_input = request.json["data"]
|
||||||
|
if app.interface.interpretation == "default":
|
||||||
|
interpreter = gr.interpretation.default()
|
||||||
|
processed_input = []
|
||||||
|
for i, x in enumerate(raw_input):
|
||||||
|
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
||||||
|
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
|
||||||
|
processed_input.append(input_interface.preprocess(x))
|
||||||
|
else:
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||||
|
interpreter = app.interface.interpretation
|
||||||
|
interpretation = interpreter(app.interface, processed_input)
|
||||||
|
return jsonify(interpretation)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/file/<path:path>", methods=["GET"])
|
@app.route("/file/<path:path>", methods=["GET"])
|
||||||
def file(path):
|
def file(path):
|
||||||
return send_file(os.path.join(os.getcwd(), path))
|
return send_file(os.path.join(os.getcwd(), path))
|
||||||
|
@ -84,9 +84,6 @@ input.submit {
|
|||||||
input.submit:hover {
|
input.submit:hover {
|
||||||
background-color: #f39c12;
|
background-color: #f39c12;
|
||||||
}
|
}
|
||||||
.flag {
|
|
||||||
visibility: hidden;
|
|
||||||
}
|
|
||||||
.flagged {
|
.flagged {
|
||||||
background-color: pink !important;
|
background-color: pink !important;
|
||||||
}
|
}
|
||||||
@ -111,9 +108,6 @@ input.submit:hover {
|
|||||||
.invisible {
|
.invisible {
|
||||||
display: none !important;
|
display: none !important;
|
||||||
}
|
}
|
||||||
.screenshot {
|
|
||||||
visibility: hidden;
|
|
||||||
}
|
|
||||||
.screenshot_logo {
|
.screenshot_logo {
|
||||||
display: none;
|
display: none;
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
|
@ -25,14 +25,10 @@
|
|||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
border: none;
|
border: none;
|
||||||
opacity: 1;
|
opacity: 1;
|
||||||
|
transition: opacity 0.2s ease;
|
||||||
}
|
}
|
||||||
.saliency > div {
|
.saliency:hover {
|
||||||
display: flex;
|
opacity: 0.4;
|
||||||
flex-grow: 1;
|
|
||||||
}
|
|
||||||
.saliency > div > div {
|
|
||||||
flex-grow: 1;
|
|
||||||
background-color: #e67e22;
|
|
||||||
}
|
}
|
||||||
.image_preview {
|
.image_preview {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
|
@ -82,9 +82,26 @@ var io_master_template = {
|
|||||||
data: JSON.stringify(post_data),
|
data: JSON.stringify(post_data),
|
||||||
dataType: 'json',
|
dataType: 'json',
|
||||||
contentType: 'application/json; charset=utf-8',
|
contentType: 'application/json; charset=utf-8',
|
||||||
success: function(output){
|
});
|
||||||
console.log("Flagging successful")
|
},
|
||||||
},
|
interpret: function() {
|
||||||
|
var io = this;
|
||||||
|
this.target.find(".loading").removeClass("invisible");
|
||||||
|
this.target.find(".loading_in_progress").show();
|
||||||
|
var post_data = {
|
||||||
|
'data': this.last_input
|
||||||
|
}
|
||||||
|
$.ajax({type: "POST",
|
||||||
|
url: "/api/interpret/",
|
||||||
|
data: JSON.stringify(post_data),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json; charset=utf-8',
|
||||||
|
success: function(data) {
|
||||||
|
for (let [idx, interpretation] of data.entries()) {
|
||||||
|
io.input_interfaces[idx].show_interpretation(interpretation);
|
||||||
|
}
|
||||||
|
io.target.find(".loading_in_progress").hide();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
io_master.last_output = null;
|
io_master.last_output = null;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (config["allow_screenshot"]) {
|
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
|
||||||
target.find(".screenshot").css("visibility", "visible");
|
target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
|
||||||
}
|
} else {
|
||||||
if (config["allow_flagging"]) {
|
if (!config["allow_screenshot"]) {
|
||||||
target.find(".flag").css("visibility", "visible");
|
target.find(".screenshot").hide();
|
||||||
}
|
}
|
||||||
if (config["allow_interpretation"]) {
|
if (!config["allow_flagging"]) {
|
||||||
target.find(".interpret").css("visibility", "visible");
|
target.find(".flag").hide();
|
||||||
|
}
|
||||||
|
if (!config["allow_interpretation"]) {
|
||||||
|
target.find(".interpret").hide();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (config["examples"]) {
|
if (config["examples"]) {
|
||||||
target.find(".examples").removeClass("invisible");
|
target.find(".examples").removeClass("invisible");
|
||||||
@ -231,12 +235,15 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
target.find(".flag").click(function() {
|
target.find(".flag").click(function() {
|
||||||
if (io_master.last_output) {
|
if (io_master.last_output) {
|
||||||
target.find(".flag").addClass("flagged");
|
target.find(".flag").addClass("flagged");
|
||||||
target.find(".flag").val("FLAGGED");
|
target.find(".flag").val("FLAGGED");
|
||||||
io_master.flag();
|
io_master.flag();
|
||||||
|
}
|
||||||
// io_master.flag($(".flag_message").val());
|
})
|
||||||
|
target.find(".interpret").click(function() {
|
||||||
|
if (io_master.last_output) {
|
||||||
|
io_master.interpret();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -29,6 +29,9 @@ const image_input = {
|
|||||||
<div class="image_preview_holder">
|
<div class="image_preview_holder">
|
||||||
<img class="image_preview" />
|
<img class="image_preview" />
|
||||||
</div>
|
</div>
|
||||||
|
<div class="saliency_holder hide">
|
||||||
|
<canvas class="saliency"></canvas>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
|
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
|
||||||
@ -180,6 +183,19 @@ const image_input = {
|
|||||||
this.cropper.destroy();
|
this.cropper.destroy();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
this.target.find(".saliency_holder").addClass("hide");
|
||||||
|
},
|
||||||
|
show_interpretation: function(data) {
|
||||||
|
if (this.target.find(".image_preview").attr("src")) {
|
||||||
|
var img = this.target.find(".image_preview")[0];
|
||||||
|
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight)
|
||||||
|
var width = size.width;
|
||||||
|
var height = size.height;
|
||||||
|
this.target.find(".saliency_holder").removeClass("hide").html(`
|
||||||
|
<canvas class="saliency" width=${width} height=${height}></canvas>`);
|
||||||
|
var ctx = this.target.find(".saliency")[0].getContext('2d');
|
||||||
|
paintSaliency(data, ctx, width, height);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
state: "NO_IMAGE",
|
state: "NO_IMAGE",
|
||||||
image_data: null,
|
image_data: null,
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
const textbox_input = {
|
const textbox_input = {
|
||||||
html: `<textarea class="input_text"></textarea>`,
|
html: `
|
||||||
|
<textarea class="input_text"></textarea>
|
||||||
|
<div class="output_text"></div>
|
||||||
|
`,
|
||||||
one_line_html: `<input type="text" class="input_text">`,
|
one_line_html: `<input type="text" class="input_text">`,
|
||||||
init: function(opts) {
|
init: function(opts) {
|
||||||
if (opts.lines > 1) {
|
if (opts.lines > 1) {
|
||||||
@ -13,6 +16,7 @@ const textbox_input = {
|
|||||||
if (opts.default) {
|
if (opts.default) {
|
||||||
this.target.find(".input_text").val(opts.default)
|
this.target.find(".input_text").val(opts.default)
|
||||||
}
|
}
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
},
|
},
|
||||||
submit: function() {
|
submit: function() {
|
||||||
text = this.target.find(".input_text").val();
|
text = this.target.find(".input_text").val();
|
||||||
@ -20,6 +24,24 @@ const textbox_input = {
|
|||||||
},
|
},
|
||||||
clear: function() {
|
clear: function() {
|
||||||
this.target.find(".input_text").val("");
|
this.target.find(".input_text").val("");
|
||||||
|
this.target.find(".output_text").empty();
|
||||||
|
this.target.find(".input_text").show();
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
|
},
|
||||||
|
show_interpretation: function(data) {
|
||||||
|
this.target.find(".input_text").hide();
|
||||||
|
this.target.find(".output_text").show();
|
||||||
|
let html = "";
|
||||||
|
for (let char_set of data) {
|
||||||
|
[char, value] = char_set;
|
||||||
|
if (value < 0) {
|
||||||
|
color = "8,241,255," + (-value * 2);
|
||||||
|
} else {
|
||||||
|
color = "230,126,34," + value * 2;
|
||||||
|
}
|
||||||
|
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
|
||||||
|
}
|
||||||
|
this.target.find(".output_text").html(html);
|
||||||
},
|
},
|
||||||
load_example: function(data) {
|
load_example: function(data) {
|
||||||
this.target.find(".input_text").val(data);
|
this.target.find(".input_text").val(data);
|
||||||
|
@ -54,22 +54,19 @@ function toStringIfObject(input) {
|
|||||||
return input;
|
return input;
|
||||||
}
|
}
|
||||||
|
|
||||||
function paintSaliency(data, width, height, ctx) {
|
function paintSaliency(data, ctx, width, height) {
|
||||||
var cell_width = width / data[0].length
|
var cell_width = width / data[0].length
|
||||||
var cell_height = height / data.length
|
var cell_height = height / data.length
|
||||||
var r = 0
|
var r = 0
|
||||||
data.forEach(function(row) {
|
data.forEach(function(row) {
|
||||||
var c = 0
|
var c = 0
|
||||||
row.forEach(function(cell) {
|
row.forEach(function(cell) {
|
||||||
if (cell < 0.25) {
|
if (cell < 0) {
|
||||||
ctx.fillStyle = "white";
|
var color = [7,47,95];
|
||||||
} else if (cell < 0.5) {
|
|
||||||
ctx.fillStyle = "#add8ed";
|
|
||||||
} else if (cell < 0.75) {
|
|
||||||
ctx.fillStyle = "#5aa7d3";
|
|
||||||
} else {
|
} else {
|
||||||
ctx.fillStyle = "#072F5F";
|
var color = [112,62,8];
|
||||||
}
|
}
|
||||||
|
ctx.fillStyle = colorToString(interpolate(cell, [255,255,255], color));
|
||||||
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
|
ctx.fillRect(c * cell_width, r * cell_height, cell_width, cell_height);
|
||||||
c++;
|
c++;
|
||||||
})
|
})
|
||||||
@ -77,6 +74,29 @@ function paintSaliency(data, width, height, ctx) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getObjectFitSize(contains /* true = contain, false = cover */, containerWidth, containerHeight, width, height){
|
||||||
|
var doRatio = width / height;
|
||||||
|
var cRatio = containerWidth / containerHeight;
|
||||||
|
var targetWidth = 0;
|
||||||
|
var targetHeight = 0;
|
||||||
|
var test = contains ? (doRatio > cRatio) : (doRatio < cRatio);
|
||||||
|
|
||||||
|
if (test) {
|
||||||
|
targetWidth = containerWidth;
|
||||||
|
targetHeight = targetWidth / doRatio;
|
||||||
|
} else {
|
||||||
|
targetHeight = containerHeight;
|
||||||
|
targetWidth = targetHeight * doRatio;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
width: targetWidth,
|
||||||
|
height: targetHeight,
|
||||||
|
x: (containerWidth - targetWidth) / 2,
|
||||||
|
y: (containerHeight - targetHeight) / 2
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// val should be in the range [0.0, 1.0]
|
// val should be in the range [0.0, 1.0]
|
||||||
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
|
// rgb1 and rgb2 should be an array of 3 values each in the range [0, 255]
|
||||||
function interpolate(val, rgb1, rgb2) {
|
function interpolate(val, rgb1, rgb2) {
|
||||||
@ -88,7 +108,6 @@ function interpolate(val, rgb1, rgb2) {
|
|||||||
return rgb;
|
return rgb;
|
||||||
}
|
}
|
||||||
|
|
||||||
// quick helper function to convert the array into something we can use for css
|
|
||||||
function colorToString(rgb) {
|
function colorToString(rgb) {
|
||||||
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
|
return "rgb(" + rgb[0] + ", " + rgb[1] + ", " + rgb[2] + ")";
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user