sentiment analysis demo

This commit is contained in:
Ali Abid 2020-09-15 10:37:02 -07:00
parent af7d0c0655
commit 3b73a5522a
13 changed files with 318 additions and 59 deletions

View File

@ -0,0 +1,59 @@
from gradio.inputs import Textbox
from gradio.inputs import Image
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import numpy as np
def tokenize_text(text):
    """Build leave-one-out variants of a whitespace-tokenized string.

    Args:
        text: The string to split on whitespace.

    Returns:
        A list of token lists. The first entry is the full token list;
        each following entry is the token list with one token removed,
        in token order.
    """
    words = text.split()
    variants = [words]
    variants.extend(
        words[:skip] + words[skip + 1:] for skip in range(len(words))
    )
    return variants
def tokenize_image(image):
    """Build one masked copy of the image per SLIC segment.

    The image is subsampled (every second row and column), converted with
    `img_as_float`, and segmented with `slic`. For each distinct segment
    label a copy of the converted image is made in which that segment's
    pixels are overwritten with 255.

    Args:
        image: Array-like image supporting `image[::2, ::2]` slicing.

    Returns:
        A list of arrays, one per unique segment label.
    """
    subsampled = img_as_float(image[::2, ::2])
    labels = slic(subsampled, n_segments=20, compactness=10, sigma=1)

    def _occlude(label):
        # Work on a copy so each variant hides exactly one segment.
        occluded = np.copy(subsampled)
        occluded[labels == label] = 255
        return occluded

    return [_occlude(label) for label in np.unique(labels)]
def score(outputs):
    """Print the collected outputs.

    Only echoes `outputs` to stdout; nothing is computed or returned.
    """
    print(outputs)
def simple_explanation(interface, input_interfaces,
                       output_interfaces, input):
    """Run the interface on leave-one-out variants of the first input.

    For a `Textbox` input, the text is tokenized and each variant is joined
    back into a string before being processed. For an `Image` input, each
    masked image is processed as-is. All outputs are collected and handed
    to `score`.

    Args:
        interface: Object exposing `process(...)`.
        input_interfaces: Input components; only the first one is inspected.
        output_interfaces: Unused here; kept for interface compatibility.
        input: Raw inputs; only `input[0]` is read.
    """
    first_component = input_interfaces[0]
    if isinstance(first_component, Textbox):
        outputs = []
        for token_list in tokenize_text(input[0]):
            input_text = " ".join(token_list)
            print("Input Text: ", input_text)
            output = interface.process(input_text)
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    elif isinstance(first_component, Image):
        outputs = []
        for masked_image in tokenize_image(input[0]):
            # Masked images are arrays, not token lists: the previous
            # `" ".join(...)` copied from the text branch raised TypeError.
            # NOTE(review): confirm `interface.process` accepts the array
            # directly rather than an encoded / list-wrapped input.
            output = interface.process(masked_image)
            outputs.extend(output)
            print("Output: ", output)
        score(outputs)
    else:
        print("Not valid input type")

View File

@ -7,9 +7,13 @@ import tempfile
import webbrowser import webbrowser
from gradio.inputs import InputComponent from gradio.inputs import InputComponent
from gradio.inputs import Image
from gradio.inputs import Textbox
from gradio.outputs import OutputComponent from gradio.outputs import OutputComponent
from gradio import networking, strings, utils from gradio import networking, strings, utils, processing_utils
from distutils.version import StrictVersion from distutils.version import StrictVersion
from skimage.segmentation import slic
from skimage.util import img_as_float
import pkg_resources import pkg_resources
import requests import requests
import random import random
@ -20,6 +24,7 @@ import sys
import weakref import weakref
import analytics import analytics
import os import os
import numpy as np
PKG_VERSION_URL = "https://gradio.app/api/pkg-version" PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
@ -46,7 +51,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None, def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True, live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None, capture_session=False, explain_by=None, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME, thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True, allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True): flagging_dir="flagged", analytics_enabled=True):
@ -110,6 +115,7 @@ class Interface:
self.show_output = show_output self.show_output = show_output
self.flag_hash = random.getrandbits(32) self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session self.capture_session = capture_session
self.explain_by = explain_by
self.session = None self.session = None
self.server_name = server_name self.server_name = server_name
self.title = title self.title = title
@ -177,7 +183,8 @@ class Interface:
"description": self.description, "description": self.description,
"thumbnail": self.thumbnail, "thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot, "allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging "allow_flagging": self.allow_flagging,
"allow_interpretation": self.explain_by is not None
} }
try: try:
param_names = inspect.getfullargspec(self.predict[0])[0] param_names = inspect.getfullargspec(self.predict[0])[0]
@ -190,7 +197,6 @@ class Interface:
iface[1]["label"] = ret_name iface[1]["label"] = ret_name
except ValueError: except ValueError:
pass pass
return config return config
def process(self, raw_input, predict_fn=None): def process(self, raw_input, predict_fn=None):
@ -414,6 +420,88 @@ class Interface:
return httpd, path_to_local_server, share_url return httpd, path_to_local_server, share_url
def tokenize_text(self, text):
    """Split text on whitespace and build its leave-one-out variants.

    Args:
        text: The string to tokenize.

    Returns:
        A `(tokens, leave_one_out_tokens)` tuple: the full token list, and
        a list holding, for each token position, the token list with that
        token removed.
    """
    words = text.split()
    without_each = [
        words[:skip] + words[skip + 1:] for skip in range(len(words))
    ]
    return words, without_each
def tokenize_image(self, image):
    """Build one masked copy of the preprocessed image per SLIC segment.

    The raw image is run through the first input component's `preprocess`
    and segmented with `slic`. For each distinct segment label a copy of
    the preprocessed image is made in which that segment's pixels are
    overwritten with 255.

    Args:
        image: Raw image value accepted by the first input component.

    Returns:
        A list of arrays, one per unique segment label.
    """
    prepared = self.input_interfaces[0].preprocess(image)
    labels = slic(prepared, n_segments=20, compactness=10, sigma=1)

    def _occlude(label):
        # Work on a copy so each variant hides exactly one segment.
        occluded = np.copy(prepared)
        occluded[labels == label] = 255
        return occluded

    return [_occlude(label) for label in np.unique(labels)]
def score_text(self, tokens, leave_one_out_tokens, text):
    """Score each token by the confidence lost when it is removed.

    The first prediction function is called on the whitespace-normalized
    text to find the highest-valued label and its confidence. It is then
    called on every leave-one-out variant, and the drop in that label's
    confidence becomes the score for the removed token.

    Args:
        tokens: Ignored; the tokens are recomputed from `text`.
        leave_one_out_tokens: Token lists, one per token position, each
            missing the token at that position.
        text: The original input string.

    Returns:
        A list of `(character, score)` tuples. Every character of a token
        carries that token's score; tokens are separated by `(" ", 0)`.
    """
    predict = self.predict[0]
    tokens = text.split()

    baseline = predict(" ".join(tokens))
    original_label = max(baseline, key=baseline.get)
    original_confidence = baseline[original_label]

    scores = [
        original_confidence - predict(" ".join(remaining))[original_label]
        for remaining in leave_one_out_tokens
    ]

    scores_by_char = []
    for position, token in enumerate(tokens):
        if position != 0:
            scores_by_char.append((" ", 0))
        scores_by_char.extend((char, scores[position]) for char in token)
    return scores_by_char
def score_image(self, leave_one_out_tokens, image):
    """Build a score array from predictions on masked images.

    Runs `self.process` on the original input, records the label and
    confidence of the first entry in `confidences`, then processes each
    masked image and writes the confidence drop (times 100) into the
    pixels that were masked with 255.

    NOTE(review): `image` is passed whole to `self.process` but indexed as
    `image[0]` for `preprocess`, which suggests a list of raw inputs; the
    `simple_explanation` method passes `input[0]`. Confirm which shape is
    intended.

    Args:
        leave_one_out_tokens: Masked image arrays; mutated in place.
        image: Raw input (see the review note above).

    Returns:
        The `output_scores` array.
    """
    original_output = self.process(image)
    original_label = original_output[0][0]['confidences'][0][
        'label']
    original_confidence = original_output[0][0]['confidences'][0][
        'confidence']
    # 255 doubles as the "not filled in yet" marker for the result array.
    # NOTE(review): an integer fill value gives an integer array, so
    # fractional scores are truncated on assignment - confirm intended.
    output_scores = np.full(np.shape(self.input_interfaces[0].preprocess(
        image[0])), 255)
    for input_image in leave_one_out_tokens:
        input_image_base64 = processing_utils.encode_array_to_base64(
            input_image)
        # `process` is given a one-element list holding the encoded image.
        input_image_arr = []
        input_image_arr.append(input_image_base64)
        output = self.process(input_image_arr)
        # NOTE(review): changes NumPy's global print options on every
        # iteration although nothing is printed here - likely a debugging
        # leftover.
        np.set_printoptions(threshold=sys.maxsize)
        if output[0][0]['confidences'][0]['label'] == original_label:
            # Replace the masked pixels (value 255) with the scaled
            # confidence drop for this segment.
            input_image[input_image == 255] = (original_confidence -
            output[0][0][
            'confidences'][0][
            'confidence']) * 100
        # Copy values into every result position still holding 255.
        # NOTE(review): this copies all such positions, not just the
        # current segment, and genuine 255 pixels are indistinguishable
        # from the marker - verify.
        mask = (output_scores == 255)
        output_scores[mask] = input_image[mask]
    return output_scores
def simple_explanation(self, input):
    """Explain the first input with the built-in leave-one-out method.

    Dispatches on the type of the first input component: `Textbox` inputs
    are scored per character, `Image` inputs per masked segment.

    Args:
        input: Raw inputs; only `input[0]` is read.

    Returns:
        For text, a one-element list holding the `score_text` result; for
        images, the `score_image` result. For any other component type a
        message is printed and `None` is returned.
    """
    component = self.input_interfaces[0]
    if isinstance(component, Textbox):
        tokens, leave_one_out_tokens = self.tokenize_text(input[0])
        char_scores = self.score_text(tokens, leave_one_out_tokens, input[0])
        return [char_scores]
    if isinstance(component, Image):
        masked_images = self.tokenize_image(input[0])
        return self.score_image(masked_images, input[0])
    print("Not valid input type")
def explain(self, input):
    """Explain `input` with the configured explanation strategy.

    Args:
        input: Raw inputs forwarded unchanged to the explainer.

    Returns:
        The result of `self.simple_explanation(input)` when `explain_by`
        is the string "default"; otherwise the result of calling
        `self.explain_by(input)`.
    """
    use_builtin = self.explain_by == "default"
    explainer = self.simple_explanation if use_builtin else self.explain_by
    return explainer(input)
def reset_all(): def reset_all():
for io in Interface.get_instances(): for io in Interface.get_instances():

View File

@ -194,6 +194,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
dict(zip(headers, output["inputs"] + dict(zip(headers, output["inputs"] +
output["outputs"])) output["outputs"]))
) )
elif self.path == "/api/interpret/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
interpretation = interface.explain(msg["data"])
self.wfile.write(json.dumps(interpretation).encode())
else: else:
self.send_error(404, 'Path not found: {}'.format(self.path)) self.send_error(404, 'Path not found: {}'.format(self.path))

View File

@ -80,9 +80,24 @@ var io_master_template = {
$.ajax({type: "POST", $.ajax({type: "POST",
url: "/api/flag/", url: "/api/flag/",
data: JSON.stringify(post_data), data: JSON.stringify(post_data),
success: function(output){ });
console.log("Flagging successful") },
}, interpret: function() {
var io = this;
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
var post_data = {
'data': this.last_input
}
$.ajax({type: "POST",
url: "/api/interpret/",
data: JSON.stringify(post_data),
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();
}
}); });
} }
}; };

View File

@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {
<div class="output_interfaces"> <div class="output_interfaces">
</div> </div>
<div class="panel_buttons"> <div class="panel_buttons">
<input class="interpret panel_button" type="button" value="INTERPRET"/>
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/> <input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
<div class="screenshot_logo"> <div class="screenshot_logo">
<img src="/static/img/logo_inline.png"> <img src="/static/img/logo_inline.png">
@ -165,17 +166,14 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null; io_master.last_output = null;
}); });
if (config["allow_screenshot"] && !config["allow_flagging"]) { if (config["allow_screenshot"]) {
target.find(".screenshot").css("visibility", "visible"); target.find(".screenshot").css("visibility", "visible");
target.find(".flag").css("display", "none")
} }
if (!config["allow_screenshot"] && config["allow_flagging"]) { if (config["allow_flagging"]) {
target.find(".flag").css("visibility", "visible"); target.find(".flag").css("visibility", "visible");
target.find(".screenshot").css("display", "none")
} }
if (config["allow_screenshot"] && config["allow_flagging"]) { if (config["allow_interpretation"]) {
target.find(".screenshot").css("visibility", "visible"); target.find(".interpret").css("visibility", "visible");
target.find(".flag").css("visibility", "visible")
} }
if (config["examples"]) { if (config["examples"]) {
target.find(".examples").removeClass("invisible"); target.find(".examples").removeClass("invisible");
@ -233,12 +231,15 @@ function gradio(config, fn, target, example_file_path) {
} }
target.find(".flag").click(function() { target.find(".flag").click(function() {
if (io_master.last_output) { if (io_master.last_output) {
target.find(".flag").addClass("flagged"); target.find(".flag").addClass("flagged");
target.find(".flag").val("FLAGGED"); target.find(".flag").val("FLAGGED");
io_master.flag(); io_master.flag();
}
// io_master.flag($(".flag_message").val()); })
target.find(".interpret").click(function() {
if (io_master.last_output) {
io_master.interpret();
} }
}) })
@ -250,6 +251,8 @@ function gradio_url(config, url, target, example_file_path) {
$.ajax({type: "POST", $.ajax({type: "POST",
url: url, url: url,
data: JSON.stringify({"data": data}), data: JSON.stringify({"data": data}),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
success: resolve, success: resolve,
error: reject, error: reject,
}); });

View File

@ -1,5 +1,8 @@
const textbox_input = { const textbox_input = {
html: `<textarea class="input_text"></textarea>`, html: `
<textarea class="input_text"></textarea>
<div class="output_text"></div>
`,
one_line_html: `<input type="text" class="input_text">`, one_line_html: `<input type="text" class="input_text">`,
init: function(opts) { init: function(opts) {
if (opts.lines > 1) { if (opts.lines > 1) {
@ -13,6 +16,7 @@ const textbox_input = {
if (opts.default) { if (opts.default) {
this.target.find(".input_text").val(opts.default) this.target.find(".input_text").val(opts.default)
} }
this.target.find(".output_text").hide();
}, },
submit: function() { submit: function() {
text = this.target.find(".input_text").val(); text = this.target.find(".input_text").val();
@ -20,6 +24,24 @@ const textbox_input = {
}, },
clear: function() { clear: function() {
this.target.find(".input_text").val(""); this.target.find(".input_text").val("");
this.target.find(".output_text").empty();
this.target.find(".input_text").show();
this.target.find(".output_text").hide();
},
show_interpretation: function(data) {
this.target.find(".input_text").hide();
this.target.find(".output_text").show();
let html = "";
for (let char_set of data) {
[char, value] = char_set;
if (value < 0) {
color = "8,241,255," + (-value * 2);
} else {
color = "230,126,34," + value * 2;
}
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
}
this.target.find(".output_text").html(html);
}, },
load_example: function(data) { load_example: function(data) {
this.target.find(".input_text").val(data); this.target.find(".input_text").val(data);

View File

@ -0,0 +1,13 @@
import gradio as gr
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
nltk.download('vader_lexicon')
sid = SentimentIntensityAnalyzer()
def sentiment_analysis(text):
    """Return the polarity scores for `text`.

    Delegates to the module-level `SentimentIntensityAnalyzer` instance
    `sid`.
    """
    return sid.polarity_scores(text)
# Textbox in, label out. explain_by="default" selects the interface's
# built-in explanation rather than a user-supplied callable.
io = gr.Interface(sentiment_analysis, "textbox", "label", explain_by="default")
io.test_launch()
io.launch()

View File

@ -3,6 +3,7 @@ README.md
setup.py setup.py
gradio/__init__.py gradio/__init__.py
gradio/component.py gradio/component.py
gradio/explain.py
gradio/inputs.py gradio/inputs.py
gradio/interface.py gradio/interface.py
gradio/networking.py gradio/networking.py

View File

@ -51,8 +51,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None, def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True, live=False, show_input=True, show_output=True,
capture_session=False, explain_by="default", title=None, capture_session=False, explain_by=None, title=None, description=None,
description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME, thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True, allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True): flagging_dir="flagged", analytics_enabled=True):
@ -184,7 +183,8 @@ class Interface:
"description": self.description, "description": self.description,
"thumbnail": self.thumbnail, "thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot, "allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging "allow_flagging": self.allow_flagging,
"allow_interpretation": self.explain_by is not None
} }
try: try:
param_names = inspect.getfullargspec(self.predict[0])[0] param_names = inspect.getfullargspec(self.predict[0])[0]
@ -423,12 +423,11 @@ class Interface:
def tokenize_text(self, text): def tokenize_text(self, text):
leave_one_out_tokens = [] leave_one_out_tokens = []
tokens = text.split() tokens = text.split()
leave_one_out_tokens.append(tokens)
for idx, _ in enumerate(tokens): for idx, _ in enumerate(tokens):
new_token_array = tokens.copy() new_token_array = tokens.copy()
del new_token_array[idx] del new_token_array[idx]
leave_one_out_tokens.append(new_token_array) leave_one_out_tokens.append(new_token_array)
return leave_one_out_tokens return tokens, leave_one_out_tokens
def tokenize_image(self, image): def tokenize_image(self, image):
image = self.input_interfaces[0].preprocess(image) image = self.input_interfaces[0].preprocess(image)
@ -440,24 +439,29 @@ class Interface:
leave_one_out_tokens.append(mask) leave_one_out_tokens.append(mask)
return leave_one_out_tokens return leave_one_out_tokens
def score_text(self, leave_one_out_tokens, text): def score_text(self, tokens, leave_one_out_tokens, text):
original_label = "" original_label = ""
original_confidence = 0 original_confidence = 0
tokens = text.split() tokens = text.split()
outputs = {}
input_text = " ".join(tokens)
output = self.predict[0](input_text)
original_label = max(output, key=output.get)
original_confidence = output[original_label]
scores = []
for idx, input_text in enumerate(leave_one_out_tokens): for idx, input_text in enumerate(leave_one_out_tokens):
input_text = " ".join(input_text) input_text = " ".join(input_text)
output = self.process(input_text) output = self.predict[0](input_text)
if idx == 0: scores.append(original_confidence - output[original_label])
original_label = output[0][0]['confidences'][0][
'label'] scores_by_char = []
original_confidence = output[0][0]['confidences'][0][ for idx, token in enumerate(tokens):
'confidence'] if idx != 0:
else: scores_by_char.append((" ", 0))
if output[0][0]['confidences'][0]['label'] == original_label: for char in token:
outputs[tokens[idx-1]] = original_confidence - output[0][0] scores_by_char.append((char, scores[idx]))
['confidences'][0]['confidence'] return scores_by_char
return outputs
def score_image(self, leave_one_out_tokens, image): def score_image(self, leave_one_out_tokens, image):
original_output = self.process(image) original_output = self.process(image)
@ -485,11 +489,11 @@ class Interface:
def simple_explanation(self, input): def simple_explanation(self, input):
if isinstance(self.input_interfaces[0], Textbox): if isinstance(self.input_interfaces[0], Textbox):
leave_one_out_tokens = self.tokenize_text(input[0]) tokens, leave_one_out_tokens = self.tokenize_text(input[0])
return self.score_text(leave_one_out_tokens, input[0]) return [self.score_text(tokens, leave_one_out_tokens, input[0])]
elif isinstance(self.input_interfaces[0], Image): elif isinstance(self.input_interfaces[0], Image):
leave_one_out_tokens = self.tokenize_image(input[0]) leave_one_out_tokens = self.tokenize_image(input[0])
return self.score_image(leave_one_out_tokens, input) return self.score_image(leave_one_out_tokens, input[0])
else: else:
print("Not valid input type") print("Not valid input type")

View File

@ -194,6 +194,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
dict(zip(headers, output["inputs"] + dict(zip(headers, output["inputs"] +
output["outputs"])) output["outputs"]))
) )
elif self.path == "/api/interpret/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
interpretation = interface.explain(msg["data"])
self.wfile.write(json.dumps(interpretation).encode())
else: else:
self.send_error(404, 'Path not found: {}'.format(self.path)) self.send_error(404, 'Path not found: {}'.format(self.path))

View File

@ -80,9 +80,24 @@ var io_master_template = {
$.ajax({type: "POST", $.ajax({type: "POST",
url: "/api/flag/", url: "/api/flag/",
data: JSON.stringify(post_data), data: JSON.stringify(post_data),
success: function(output){ });
console.log("Flagging successful") },
}, interpret: function() {
var io = this;
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
var post_data = {
'data': this.last_input
}
$.ajax({type: "POST",
url: "/api/interpret/",
data: JSON.stringify(post_data),
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();
}
}); });
} }
}; };

View File

@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {
<div class="output_interfaces"> <div class="output_interfaces">
</div> </div>
<div class="panel_buttons"> <div class="panel_buttons">
<input class="interpret panel_button" type="button" value="INTERPRET"/>
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/> <input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
<div class="screenshot_logo"> <div class="screenshot_logo">
<img src="/static/img/logo_inline.png"> <img src="/static/img/logo_inline.png">
@ -165,17 +166,14 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null; io_master.last_output = null;
}); });
if (config["allow_screenshot"] && !config["allow_flagging"]) { if (config["allow_screenshot"]) {
target.find(".screenshot").css("visibility", "visible"); target.find(".screenshot").css("visibility", "visible");
target.find(".flag").css("display", "none")
} }
if (!config["allow_screenshot"] && config["allow_flagging"]) { if (config["allow_flagging"]) {
target.find(".flag").css("visibility", "visible"); target.find(".flag").css("visibility", "visible");
target.find(".screenshot").css("display", "none")
} }
if (config["allow_screenshot"] && config["allow_flagging"]) { if (config["allow_interpretation"]) {
target.find(".screenshot").css("visibility", "visible"); target.find(".interpret").css("visibility", "visible");
target.find(".flag").css("visibility", "visible")
} }
if (config["examples"]) { if (config["examples"]) {
target.find(".examples").removeClass("invisible"); target.find(".examples").removeClass("invisible");
@ -233,12 +231,15 @@ function gradio(config, fn, target, example_file_path) {
} }
target.find(".flag").click(function() { target.find(".flag").click(function() {
if (io_master.last_output) { if (io_master.last_output) {
target.find(".flag").addClass("flagged"); target.find(".flag").addClass("flagged");
target.find(".flag").val("FLAGGED"); target.find(".flag").val("FLAGGED");
io_master.flag(); io_master.flag();
}
// io_master.flag($(".flag_message").val()); })
target.find(".interpret").click(function() {
if (io_master.last_output) {
io_master.interpret();
} }
}) })
@ -250,6 +251,8 @@ function gradio_url(config, url, target, example_file_path) {
$.ajax({type: "POST", $.ajax({type: "POST",
url: url, url: url,
data: JSON.stringify({"data": data}), data: JSON.stringify({"data": data}),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
success: resolve, success: resolve,
error: reject, error: reject,
}); });

View File

@ -1,5 +1,8 @@
const textbox_input = { const textbox_input = {
html: `<textarea class="input_text"></textarea>`, html: `
<textarea class="input_text"></textarea>
<div class="output_text"></div>
`,
one_line_html: `<input type="text" class="input_text">`, one_line_html: `<input type="text" class="input_text">`,
init: function(opts) { init: function(opts) {
if (opts.lines > 1) { if (opts.lines > 1) {
@ -13,6 +16,7 @@ const textbox_input = {
if (opts.default) { if (opts.default) {
this.target.find(".input_text").val(opts.default) this.target.find(".input_text").val(opts.default)
} }
this.target.find(".output_text").hide();
}, },
submit: function() { submit: function() {
text = this.target.find(".input_text").val(); text = this.target.find(".input_text").val();
@ -20,6 +24,24 @@ const textbox_input = {
}, },
clear: function() { clear: function() {
this.target.find(".input_text").val(""); this.target.find(".input_text").val("");
this.target.find(".output_text").empty();
this.target.find(".input_text").show();
this.target.find(".output_text").hide();
},
show_interpretation: function(data) {
this.target.find(".input_text").hide();
this.target.find(".output_text").show();
let html = "";
for (let char_set of data) {
[char, value] = char_set;
if (value < 0) {
color = "8,241,255," + (-value * 2);
} else {
color = "230,126,34," + value * 2;
}
html += `<span title="${value}" style="background-color: rgba(${color})">${char}</span>`
}
this.target.find(".output_text").html(html);
}, },
load_example: function(data) { load_example: function(data) {
this.target.find(".input_text").val(data); this.target.find(".input_text").val(data);