mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
sentiment analysis demo
This commit is contained in:
parent
af7d0c0655
commit
3b73a5522a
59
build/lib/gradio/explain.py
Normal file
59
build/lib/gradio/explain.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from gradio.inputs import Textbox
|
||||||
|
from gradio.inputs import Image
|
||||||
|
|
||||||
|
from skimage.color import rgb2gray
|
||||||
|
from skimage.filters import sobel
|
||||||
|
from skimage.segmentation import slic
|
||||||
|
from skimage.util import img_as_float
|
||||||
|
from skimage import io
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_text(text):
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
tokens = text.split()
|
||||||
|
leave_one_out_tokens.append(tokens)
|
||||||
|
for idx, _ in enumerate(tokens):
|
||||||
|
new_token_array = tokens.copy()
|
||||||
|
del new_token_array[idx]
|
||||||
|
leave_one_out_tokens.append(new_token_array)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def tokenize_image(image):
|
||||||
|
img = img_as_float(image[::2, ::2])
|
||||||
|
segments_slic = slic(img, n_segments=20, compactness=10, sigma=1)
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||||
|
mask = np.copy(img)
|
||||||
|
mask[segments_slic == segVal] = 255
|
||||||
|
leave_one_out_tokens.append(mask)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def score(outputs):
|
||||||
|
print(outputs)
|
||||||
|
|
||||||
|
def simple_explanation(interface, input_interfaces,
|
||||||
|
output_interfaces, input):
|
||||||
|
if isinstance(input_interfaces[0], Textbox):
|
||||||
|
leave_one_out_tokens = tokenize_text(input[0])
|
||||||
|
outputs = []
|
||||||
|
for input_text in leave_one_out_tokens:
|
||||||
|
input_text = " ".join(input_text)
|
||||||
|
print("Input Text: ", input_text)
|
||||||
|
output = interface.process(input_text)
|
||||||
|
outputs.extend(output)
|
||||||
|
print("Output: ", output)
|
||||||
|
score(outputs)
|
||||||
|
|
||||||
|
elif isinstance(input_interfaces[0], Image):
|
||||||
|
leave_one_out_tokens = tokenize_image(input[0])
|
||||||
|
outputs = []
|
||||||
|
for input_text in leave_one_out_tokens:
|
||||||
|
input_text = " ".join(input_text)
|
||||||
|
print("Input Text: ", input_text)
|
||||||
|
output = interface.process(input_text)
|
||||||
|
outputs.extend(output)
|
||||||
|
print("Output: ", output)
|
||||||
|
score(outputs)
|
||||||
|
else:
|
||||||
|
print("Not valid input type")
|
@ -7,9 +7,13 @@ import tempfile
|
|||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
from gradio.inputs import InputComponent
|
from gradio.inputs import InputComponent
|
||||||
|
from gradio.inputs import Image
|
||||||
|
from gradio.inputs import Textbox
|
||||||
from gradio.outputs import OutputComponent
|
from gradio.outputs import OutputComponent
|
||||||
from gradio import networking, strings, utils
|
from gradio import networking, strings, utils, processing_utils
|
||||||
from distutils.version import StrictVersion
|
from distutils.version import StrictVersion
|
||||||
|
from skimage.segmentation import slic
|
||||||
|
from skimage.util import img_as_float
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import requests
|
import requests
|
||||||
import random
|
import random
|
||||||
@ -20,6 +24,7 @@ import sys
|
|||||||
import weakref
|
import weakref
|
||||||
import analytics
|
import analytics
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||||
@ -46,7 +51,7 @@ class Interface:
|
|||||||
|
|
||||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||||
live=False, show_input=True, show_output=True,
|
live=False, show_input=True, show_output=True,
|
||||||
capture_session=False, title=None, description=None,
|
capture_session=False, explain_by=None, title=None, description=None,
|
||||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||||
allow_screenshot=True, allow_flagging=True,
|
allow_screenshot=True, allow_flagging=True,
|
||||||
flagging_dir="flagged", analytics_enabled=True):
|
flagging_dir="flagged", analytics_enabled=True):
|
||||||
@ -110,6 +115,7 @@ class Interface:
|
|||||||
self.show_output = show_output
|
self.show_output = show_output
|
||||||
self.flag_hash = random.getrandbits(32)
|
self.flag_hash = random.getrandbits(32)
|
||||||
self.capture_session = capture_session
|
self.capture_session = capture_session
|
||||||
|
self.explain_by = explain_by
|
||||||
self.session = None
|
self.session = None
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
self.title = title
|
self.title = title
|
||||||
@ -177,7 +183,8 @@ class Interface:
|
|||||||
"description": self.description,
|
"description": self.description,
|
||||||
"thumbnail": self.thumbnail,
|
"thumbnail": self.thumbnail,
|
||||||
"allow_screenshot": self.allow_screenshot,
|
"allow_screenshot": self.allow_screenshot,
|
||||||
"allow_flagging": self.allow_flagging
|
"allow_flagging": self.allow_flagging,
|
||||||
|
"allow_interpretation": self.explain_by is not None
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||||
@ -190,7 +197,6 @@ class Interface:
|
|||||||
iface[1]["label"] = ret_name
|
iface[1]["label"] = ret_name
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def process(self, raw_input, predict_fn=None):
|
def process(self, raw_input, predict_fn=None):
|
||||||
@ -414,6 +420,88 @@ class Interface:
|
|||||||
|
|
||||||
return httpd, path_to_local_server, share_url
|
return httpd, path_to_local_server, share_url
|
||||||
|
|
||||||
|
def tokenize_text(self, text):
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
tokens = text.split()
|
||||||
|
for idx, _ in enumerate(tokens):
|
||||||
|
new_token_array = tokens.copy()
|
||||||
|
del new_token_array[idx]
|
||||||
|
leave_one_out_tokens.append(new_token_array)
|
||||||
|
return tokens, leave_one_out_tokens
|
||||||
|
|
||||||
|
def tokenize_image(self, image):
|
||||||
|
image = self.input_interfaces[0].preprocess(image)
|
||||||
|
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||||
|
leave_one_out_tokens = []
|
||||||
|
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||||
|
mask = np.copy(image)
|
||||||
|
mask[segments_slic == segVal] = 255
|
||||||
|
leave_one_out_tokens.append(mask)
|
||||||
|
return leave_one_out_tokens
|
||||||
|
|
||||||
|
def score_text(self, tokens, leave_one_out_tokens, text):
|
||||||
|
original_label = ""
|
||||||
|
original_confidence = 0
|
||||||
|
tokens = text.split()
|
||||||
|
|
||||||
|
input_text = " ".join(tokens)
|
||||||
|
output = self.predict[0](input_text)
|
||||||
|
original_label = max(output, key=output.get)
|
||||||
|
original_confidence = output[original_label]
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||||
|
input_text = " ".join(input_text)
|
||||||
|
output = self.predict[0](input_text)
|
||||||
|
scores.append(original_confidence - output[original_label])
|
||||||
|
|
||||||
|
scores_by_char = []
|
||||||
|
for idx, token in enumerate(tokens):
|
||||||
|
if idx != 0:
|
||||||
|
scores_by_char.append((" ", 0))
|
||||||
|
for char in token:
|
||||||
|
scores_by_char.append((char, scores[idx]))
|
||||||
|
return scores_by_char
|
||||||
|
|
||||||
|
def score_image(self, leave_one_out_tokens, image):
|
||||||
|
original_output = self.process(image)
|
||||||
|
original_label = original_output[0][0]['confidences'][0][
|
||||||
|
'label']
|
||||||
|
original_confidence = original_output[0][0]['confidences'][0][
|
||||||
|
'confidence']
|
||||||
|
output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
|
||||||
|
image[0])), 255)
|
||||||
|
for input_image in leave_one_out_tokens:
|
||||||
|
input_image_base64 = processing_utils.encode_array_to_base64(
|
||||||
|
input_image)
|
||||||
|
input_image_arr = []
|
||||||
|
input_image_arr.append(input_image_base64)
|
||||||
|
output = self.process(input_image_arr)
|
||||||
|
np.set_printoptions(threshold=sys.maxsize)
|
||||||
|
if output[0][0]['confidences'][0]['label'] == original_label:
|
||||||
|
input_image[input_image == 255] = (original_confidence -
|
||||||
|
output[0][0][
|
||||||
|
'confidences'][0][
|
||||||
|
'confidence']) * 100
|
||||||
|
mask = (output_scores == 255)
|
||||||
|
output_scores[mask] = input_image[mask]
|
||||||
|
return output_scores
|
||||||
|
|
||||||
|
def simple_explanation(self, input):
|
||||||
|
if isinstance(self.input_interfaces[0], Textbox):
|
||||||
|
tokens, leave_one_out_tokens = self.tokenize_text(input[0])
|
||||||
|
return [self.score_text(tokens, leave_one_out_tokens, input[0])]
|
||||||
|
elif isinstance(self.input_interfaces[0], Image):
|
||||||
|
leave_one_out_tokens = self.tokenize_image(input[0])
|
||||||
|
return self.score_image(leave_one_out_tokens, input[0])
|
||||||
|
else:
|
||||||
|
print("Not valid input type")
|
||||||
|
|
||||||
|
def explain(self, input):
|
||||||
|
if self.explain_by == "default":
|
||||||
|
return self.simple_explanation(input)
|
||||||
|
else:
|
||||||
|
return self.explain_by(input)
|
||||||
|
|
||||||
def reset_all():
|
def reset_all():
|
||||||
for io in Interface.get_instances():
|
for io in Interface.get_instances():
|
||||||
|
@ -194,6 +194,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
|||||||
dict(zip(headers, output["inputs"] +
|
dict(zip(headers, output["inputs"] +
|
||||||
output["outputs"]))
|
output["outputs"]))
|
||||||
)
|
)
|
||||||
|
elif self.path == "/api/interpret/":
|
||||||
|
self._set_headers()
|
||||||
|
data_string = self.rfile.read(
|
||||||
|
int(self.headers["Content-Length"]))
|
||||||
|
msg = json.loads(data_string)
|
||||||
|
interpretation = interface.explain(msg["data"])
|
||||||
|
self.wfile.write(json.dumps(interpretation).encode())
|
||||||
else:
|
else:
|
||||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||||
|
|
||||||
|
@ -80,9 +80,24 @@ var io_master_template = {
|
|||||||
$.ajax({type: "POST",
|
$.ajax({type: "POST",
|
||||||
url: "/api/flag/",
|
url: "/api/flag/",
|
||||||
data: JSON.stringify(post_data),
|
data: JSON.stringify(post_data),
|
||||||
success: function(output){
|
});
|
||||||
console.log("Flagging successful")
|
},
|
||||||
},
|
interpret: function() {
|
||||||
|
var io = this;
|
||||||
|
this.target.find(".loading").removeClass("invisible");
|
||||||
|
this.target.find(".loading_in_progress").show();
|
||||||
|
var post_data = {
|
||||||
|
'data': this.last_input
|
||||||
|
}
|
||||||
|
$.ajax({type: "POST",
|
||||||
|
url: "/api/interpret/",
|
||||||
|
data: JSON.stringify(post_data),
|
||||||
|
success: function(data) {
|
||||||
|
for (let [idx, interpretation] of data.entries()) {
|
||||||
|
io.input_interfaces[idx].show_interpretation(interpretation);
|
||||||
|
}
|
||||||
|
io.target.find(".loading_in_progress").hide();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
<div class="output_interfaces">
|
<div class="output_interfaces">
|
||||||
</div>
|
</div>
|
||||||
<div class="panel_buttons">
|
<div class="panel_buttons">
|
||||||
|
<input class="interpret panel_button" type="button" value="INTERPRET"/>
|
||||||
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
|
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
|
||||||
<div class="screenshot_logo">
|
<div class="screenshot_logo">
|
||||||
<img src="/static/img/logo_inline.png">
|
<img src="/static/img/logo_inline.png">
|
||||||
@ -165,17 +166,14 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
io_master.last_output = null;
|
io_master.last_output = null;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (config["allow_screenshot"] && !config["allow_flagging"]) {
|
if (config["allow_screenshot"]) {
|
||||||
target.find(".screenshot").css("visibility", "visible");
|
target.find(".screenshot").css("visibility", "visible");
|
||||||
target.find(".flag").css("display", "none")
|
|
||||||
}
|
}
|
||||||
if (!config["allow_screenshot"] && config["allow_flagging"]) {
|
if (config["allow_flagging"]) {
|
||||||
target.find(".flag").css("visibility", "visible");
|
target.find(".flag").css("visibility", "visible");
|
||||||
target.find(".screenshot").css("display", "none")
|
|
||||||
}
|
}
|
||||||
if (config["allow_screenshot"] && config["allow_flagging"]) {
|
if (config["allow_interpretation"]) {
|
||||||
target.find(".screenshot").css("visibility", "visible");
|
target.find(".interpret").css("visibility", "visible");
|
||||||
target.find(".flag").css("visibility", "visible")
|
|
||||||
}
|
}
|
||||||
if (config["examples"]) {
|
if (config["examples"]) {
|
||||||
target.find(".examples").removeClass("invisible");
|
target.find(".examples").removeClass("invisible");
|
||||||
@ -233,12 +231,15 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
target.find(".flag").click(function() {
|
target.find(".flag").click(function() {
|
||||||
if (io_master.last_output) {
|
if (io_master.last_output) {
|
||||||
target.find(".flag").addClass("flagged");
|
target.find(".flag").addClass("flagged");
|
||||||
target.find(".flag").val("FLAGGED");
|
target.find(".flag").val("FLAGGED");
|
||||||
io_master.flag();
|
io_master.flag();
|
||||||
|
}
|
||||||
// io_master.flag($(".flag_message").val());
|
})
|
||||||
|
target.find(".interpret").click(function() {
|
||||||
|
if (io_master.last_output) {
|
||||||
|
io_master.interpret();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -250,6 +251,8 @@ function gradio_url(config, url, target, example_file_path) {
|
|||||||
$.ajax({type: "POST",
|
$.ajax({type: "POST",
|
||||||
url: url,
|
url: url,
|
||||||
data: JSON.stringify({"data": data}),
|
data: JSON.stringify({"data": data}),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json; charset=utf-8',
|
||||||
success: resolve,
|
success: resolve,
|
||||||
error: reject,
|
error: reject,
|
||||||
});
|
});
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
const textbox_input = {
|
const textbox_input = {
|
||||||
html: `<textarea class="input_text"></textarea>`,
|
html: `
|
||||||
|
<textarea class="input_text"></textarea>
|
||||||
|
<div class="output_text"></div>
|
||||||
|
`,
|
||||||
one_line_html: `<input type="text" class="input_text">`,
|
one_line_html: `<input type="text" class="input_text">`,
|
||||||
init: function(opts) {
|
init: function(opts) {
|
||||||
if (opts.lines > 1) {
|
if (opts.lines > 1) {
|
||||||
@ -13,6 +16,7 @@ const textbox_input = {
|
|||||||
if (opts.default) {
|
if (opts.default) {
|
||||||
this.target.find(".input_text").val(opts.default)
|
this.target.find(".input_text").val(opts.default)
|
||||||
}
|
}
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
},
|
},
|
||||||
submit: function() {
|
submit: function() {
|
||||||
text = this.target.find(".input_text").val();
|
text = this.target.find(".input_text").val();
|
||||||
@ -20,6 +24,24 @@ const textbox_input = {
|
|||||||
},
|
},
|
||||||
clear: function() {
|
clear: function() {
|
||||||
this.target.find(".input_text").val("");
|
this.target.find(".input_text").val("");
|
||||||
|
this.target.find(".output_text").empty();
|
||||||
|
this.target.find(".input_text").show();
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
|
},
|
||||||
|
show_interpretation: function(data) {
|
||||||
|
this.target.find(".input_text").hide();
|
||||||
|
this.target.find(".output_text").show();
|
||||||
|
let html = "";
|
||||||
|
for (let char_set of data) {
|
||||||
|
[char, value] = char_set;
|
||||||
|
if (value < 0) {
|
||||||
|
color = "8,241,255," + (-value * 2);
|
||||||
|
} else {
|
||||||
|
color = "230,126,34," + value * 2;
|
||||||
|
}
|
||||||
|
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
|
||||||
|
}
|
||||||
|
this.target.find(".output_text").html(html);
|
||||||
},
|
},
|
||||||
load_example: function(data) {
|
load_example: function(data) {
|
||||||
this.target.find(".input_text").val(data);
|
this.target.find(".input_text").val(data);
|
||||||
|
13
demo/sentiment_analysis.py
Normal file
13
demo/sentiment_analysis.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import gradio as gr
|
||||||
|
import nltk
|
||||||
|
from nltk.sentiment.vader import SentimentIntensityAnalyzer
|
||||||
|
nltk.download('vader_lexicon')
|
||||||
|
sid = SentimentIntensityAnalyzer()
|
||||||
|
|
||||||
|
def sentiment_analysis(text):
|
||||||
|
return sid.polarity_scores(text)
|
||||||
|
|
||||||
|
io = gr.Interface(sentiment_analysis, "textbox", "label", explain_by="default")
|
||||||
|
|
||||||
|
io.test_launch()
|
||||||
|
io.launch()
|
@ -3,6 +3,7 @@ README.md
|
|||||||
setup.py
|
setup.py
|
||||||
gradio/__init__.py
|
gradio/__init__.py
|
||||||
gradio/component.py
|
gradio/component.py
|
||||||
|
gradio/explain.py
|
||||||
gradio/inputs.py
|
gradio/inputs.py
|
||||||
gradio/interface.py
|
gradio/interface.py
|
||||||
gradio/networking.py
|
gradio/networking.py
|
||||||
|
@ -51,8 +51,7 @@ class Interface:
|
|||||||
|
|
||||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||||
live=False, show_input=True, show_output=True,
|
live=False, show_input=True, show_output=True,
|
||||||
capture_session=False, explain_by="default", title=None,
|
capture_session=False, explain_by=None, title=None, description=None,
|
||||||
description=None,
|
|
||||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||||
allow_screenshot=True, allow_flagging=True,
|
allow_screenshot=True, allow_flagging=True,
|
||||||
flagging_dir="flagged", analytics_enabled=True):
|
flagging_dir="flagged", analytics_enabled=True):
|
||||||
@ -184,7 +183,8 @@ class Interface:
|
|||||||
"description": self.description,
|
"description": self.description,
|
||||||
"thumbnail": self.thumbnail,
|
"thumbnail": self.thumbnail,
|
||||||
"allow_screenshot": self.allow_screenshot,
|
"allow_screenshot": self.allow_screenshot,
|
||||||
"allow_flagging": self.allow_flagging
|
"allow_flagging": self.allow_flagging,
|
||||||
|
"allow_interpretation": self.explain_by is not None
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||||
@ -423,12 +423,11 @@ class Interface:
|
|||||||
def tokenize_text(self, text):
|
def tokenize_text(self, text):
|
||||||
leave_one_out_tokens = []
|
leave_one_out_tokens = []
|
||||||
tokens = text.split()
|
tokens = text.split()
|
||||||
leave_one_out_tokens.append(tokens)
|
|
||||||
for idx, _ in enumerate(tokens):
|
for idx, _ in enumerate(tokens):
|
||||||
new_token_array = tokens.copy()
|
new_token_array = tokens.copy()
|
||||||
del new_token_array[idx]
|
del new_token_array[idx]
|
||||||
leave_one_out_tokens.append(new_token_array)
|
leave_one_out_tokens.append(new_token_array)
|
||||||
return leave_one_out_tokens
|
return tokens, leave_one_out_tokens
|
||||||
|
|
||||||
def tokenize_image(self, image):
|
def tokenize_image(self, image):
|
||||||
image = self.input_interfaces[0].preprocess(image)
|
image = self.input_interfaces[0].preprocess(image)
|
||||||
@ -440,24 +439,29 @@ class Interface:
|
|||||||
leave_one_out_tokens.append(mask)
|
leave_one_out_tokens.append(mask)
|
||||||
return leave_one_out_tokens
|
return leave_one_out_tokens
|
||||||
|
|
||||||
def score_text(self, leave_one_out_tokens, text):
|
def score_text(self, tokens, leave_one_out_tokens, text):
|
||||||
original_label = ""
|
original_label = ""
|
||||||
original_confidence = 0
|
original_confidence = 0
|
||||||
tokens = text.split()
|
tokens = text.split()
|
||||||
outputs = {}
|
|
||||||
|
input_text = " ".join(tokens)
|
||||||
|
output = self.predict[0](input_text)
|
||||||
|
original_label = max(output, key=output.get)
|
||||||
|
original_confidence = output[original_label]
|
||||||
|
|
||||||
|
scores = []
|
||||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||||
input_text = " ".join(input_text)
|
input_text = " ".join(input_text)
|
||||||
output = self.process(input_text)
|
output = self.predict[0](input_text)
|
||||||
if idx == 0:
|
scores.append(original_confidence - output[original_label])
|
||||||
original_label = output[0][0]['confidences'][0][
|
|
||||||
'label']
|
scores_by_char = []
|
||||||
original_confidence = output[0][0]['confidences'][0][
|
for idx, token in enumerate(tokens):
|
||||||
'confidence']
|
if idx != 0:
|
||||||
else:
|
scores_by_char.append((" ", 0))
|
||||||
if output[0][0]['confidences'][0]['label'] == original_label:
|
for char in token:
|
||||||
outputs[tokens[idx-1]] = original_confidence - output[0][0]
|
scores_by_char.append((char, scores[idx]))
|
||||||
['confidences'][0]['confidence']
|
return scores_by_char
|
||||||
return outputs
|
|
||||||
|
|
||||||
def score_image(self, leave_one_out_tokens, image):
|
def score_image(self, leave_one_out_tokens, image):
|
||||||
original_output = self.process(image)
|
original_output = self.process(image)
|
||||||
@ -485,11 +489,11 @@ class Interface:
|
|||||||
|
|
||||||
def simple_explanation(self, input):
|
def simple_explanation(self, input):
|
||||||
if isinstance(self.input_interfaces[0], Textbox):
|
if isinstance(self.input_interfaces[0], Textbox):
|
||||||
leave_one_out_tokens = self.tokenize_text(input[0])
|
tokens, leave_one_out_tokens = self.tokenize_text(input[0])
|
||||||
return self.score_text(leave_one_out_tokens, input[0])
|
return [self.score_text(tokens, leave_one_out_tokens, input[0])]
|
||||||
elif isinstance(self.input_interfaces[0], Image):
|
elif isinstance(self.input_interfaces[0], Image):
|
||||||
leave_one_out_tokens = self.tokenize_image(input[0])
|
leave_one_out_tokens = self.tokenize_image(input[0])
|
||||||
return self.score_image(leave_one_out_tokens, input)
|
return self.score_image(leave_one_out_tokens, input[0])
|
||||||
else:
|
else:
|
||||||
print("Not valid input type")
|
print("Not valid input type")
|
||||||
|
|
||||||
|
@ -194,6 +194,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
|||||||
dict(zip(headers, output["inputs"] +
|
dict(zip(headers, output["inputs"] +
|
||||||
output["outputs"]))
|
output["outputs"]))
|
||||||
)
|
)
|
||||||
|
elif self.path == "/api/interpret/":
|
||||||
|
self._set_headers()
|
||||||
|
data_string = self.rfile.read(
|
||||||
|
int(self.headers["Content-Length"]))
|
||||||
|
msg = json.loads(data_string)
|
||||||
|
interpretation = interface.explain(msg["data"])
|
||||||
|
self.wfile.write(json.dumps(interpretation).encode())
|
||||||
else:
|
else:
|
||||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||||
|
|
||||||
|
@ -80,9 +80,24 @@ var io_master_template = {
|
|||||||
$.ajax({type: "POST",
|
$.ajax({type: "POST",
|
||||||
url: "/api/flag/",
|
url: "/api/flag/",
|
||||||
data: JSON.stringify(post_data),
|
data: JSON.stringify(post_data),
|
||||||
success: function(output){
|
});
|
||||||
console.log("Flagging successful")
|
},
|
||||||
},
|
interpret: function() {
|
||||||
|
var io = this;
|
||||||
|
this.target.find(".loading").removeClass("invisible");
|
||||||
|
this.target.find(".loading_in_progress").show();
|
||||||
|
var post_data = {
|
||||||
|
'data': this.last_input
|
||||||
|
}
|
||||||
|
$.ajax({type: "POST",
|
||||||
|
url: "/api/interpret/",
|
||||||
|
data: JSON.stringify(post_data),
|
||||||
|
success: function(data) {
|
||||||
|
for (let [idx, interpretation] of data.entries()) {
|
||||||
|
io.input_interfaces[idx].show_interpretation(interpretation);
|
||||||
|
}
|
||||||
|
io.target.find(".loading_in_progress").hide();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
<div class="output_interfaces">
|
<div class="output_interfaces">
|
||||||
</div>
|
</div>
|
||||||
<div class="panel_buttons">
|
<div class="panel_buttons">
|
||||||
|
<input class="interpret panel_button" type="button" value="INTERPRET"/>
|
||||||
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
|
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
|
||||||
<div class="screenshot_logo">
|
<div class="screenshot_logo">
|
||||||
<img src="/static/img/logo_inline.png">
|
<img src="/static/img/logo_inline.png">
|
||||||
@ -165,17 +166,14 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
io_master.last_output = null;
|
io_master.last_output = null;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (config["allow_screenshot"] && !config["allow_flagging"]) {
|
if (config["allow_screenshot"]) {
|
||||||
target.find(".screenshot").css("visibility", "visible");
|
target.find(".screenshot").css("visibility", "visible");
|
||||||
target.find(".flag").css("display", "none")
|
|
||||||
}
|
}
|
||||||
if (!config["allow_screenshot"] && config["allow_flagging"]) {
|
if (config["allow_flagging"]) {
|
||||||
target.find(".flag").css("visibility", "visible");
|
target.find(".flag").css("visibility", "visible");
|
||||||
target.find(".screenshot").css("display", "none")
|
|
||||||
}
|
}
|
||||||
if (config["allow_screenshot"] && config["allow_flagging"]) {
|
if (config["allow_interpretation"]) {
|
||||||
target.find(".screenshot").css("visibility", "visible");
|
target.find(".interpret").css("visibility", "visible");
|
||||||
target.find(".flag").css("visibility", "visible")
|
|
||||||
}
|
}
|
||||||
if (config["examples"]) {
|
if (config["examples"]) {
|
||||||
target.find(".examples").removeClass("invisible");
|
target.find(".examples").removeClass("invisible");
|
||||||
@ -233,12 +231,15 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
target.find(".flag").click(function() {
|
target.find(".flag").click(function() {
|
||||||
if (io_master.last_output) {
|
if (io_master.last_output) {
|
||||||
target.find(".flag").addClass("flagged");
|
target.find(".flag").addClass("flagged");
|
||||||
target.find(".flag").val("FLAGGED");
|
target.find(".flag").val("FLAGGED");
|
||||||
io_master.flag();
|
io_master.flag();
|
||||||
|
}
|
||||||
// io_master.flag($(".flag_message").val());
|
})
|
||||||
|
target.find(".interpret").click(function() {
|
||||||
|
if (io_master.last_output) {
|
||||||
|
io_master.interpret();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -250,6 +251,8 @@ function gradio_url(config, url, target, example_file_path) {
|
|||||||
$.ajax({type: "POST",
|
$.ajax({type: "POST",
|
||||||
url: url,
|
url: url,
|
||||||
data: JSON.stringify({"data": data}),
|
data: JSON.stringify({"data": data}),
|
||||||
|
dataType: 'json',
|
||||||
|
contentType: 'application/json; charset=utf-8',
|
||||||
success: resolve,
|
success: resolve,
|
||||||
error: reject,
|
error: reject,
|
||||||
});
|
});
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
const textbox_input = {
|
const textbox_input = {
|
||||||
html: `<textarea class="input_text"></textarea>`,
|
html: `
|
||||||
|
<textarea class="input_text"></textarea>
|
||||||
|
<div class="output_text"></div>
|
||||||
|
`,
|
||||||
one_line_html: `<input type="text" class="input_text">`,
|
one_line_html: `<input type="text" class="input_text">`,
|
||||||
init: function(opts) {
|
init: function(opts) {
|
||||||
if (opts.lines > 1) {
|
if (opts.lines > 1) {
|
||||||
@ -13,6 +16,7 @@ const textbox_input = {
|
|||||||
if (opts.default) {
|
if (opts.default) {
|
||||||
this.target.find(".input_text").val(opts.default)
|
this.target.find(".input_text").val(opts.default)
|
||||||
}
|
}
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
},
|
},
|
||||||
submit: function() {
|
submit: function() {
|
||||||
text = this.target.find(".input_text").val();
|
text = this.target.find(".input_text").val();
|
||||||
@ -20,6 +24,24 @@ const textbox_input = {
|
|||||||
},
|
},
|
||||||
clear: function() {
|
clear: function() {
|
||||||
this.target.find(".input_text").val("");
|
this.target.find(".input_text").val("");
|
||||||
|
this.target.find(".output_text").empty();
|
||||||
|
this.target.find(".input_text").show();
|
||||||
|
this.target.find(".output_text").hide();
|
||||||
|
},
|
||||||
|
show_interpretation: function(data) {
|
||||||
|
this.target.find(".input_text").hide();
|
||||||
|
this.target.find(".output_text").show();
|
||||||
|
let html = "";
|
||||||
|
for (let char_set of data) {
|
||||||
|
[char, value] = char_set;
|
||||||
|
if (value < 0) {
|
||||||
|
color = "8,241,255," + (-value * 2);
|
||||||
|
} else {
|
||||||
|
color = "230,126,34," + value * 2;
|
||||||
|
}
|
||||||
|
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
|
||||||
|
}
|
||||||
|
this.target.find(".output_text").html(html);
|
||||||
},
|
},
|
||||||
load_example: function(data) {
|
load_example: function(data) {
|
||||||
this.target.find(".input_text").val(data);
|
this.target.find(".input_text").val(data);
|
||||||
|
Loading…
Reference in New Issue
Block a user