mid progress

This commit is contained in:
Ali Abid 2020-09-17 14:38:22 -07:00
parent 6d7b6e8122
commit 7dc1a83ce3
11 changed files with 187 additions and 41 deletions

View File

@ -292,6 +292,7 @@ class Image(InputComponent):
def preprocess(self, x):
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
@ -305,7 +306,7 @@ class Image(InputComponent):
elif self.type == "numpy":
return np.array(im)
elif self.type == "file":
file_obj = tempfile.NamedTemporaryFile()
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
im.save(file_obj.name)
return file_obj
else:

View File

@ -55,9 +55,9 @@ class Textbox(OutputComponent):
def postprocess(self, y):
if self.type == "str" or self.type == "auto":
return y
elif self.type == "number":
return str(y)
elif self.type == "number":
return y
else:
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")

View File

@ -24,6 +24,7 @@ function gradio(config, fn, target, example_file_path) {
<div class="output_interfaces">
</div>
<div class="panel_buttons">
<input class="interpret panel_button" type="button" value="INTERPRET"/>
<input class="screenshot panel_button" type="button" value="SCREENSHOT"/>
<div class="screenshot_logo">
<img src="/static/img/logo_inline.png">
@ -165,17 +166,14 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null;
});
if (config["allow_screenshot"] && !config["allow_flagging"]) {
if (config["allow_screenshot"]) {
target.find(".screenshot").css("visibility", "visible");
target.find(".flag").css("display", "none")
}
if (!config["allow_screenshot"] && config["allow_flagging"]) {
if (config["allow_flagging"]) {
target.find(".flag").css("visibility", "visible");
target.find(".screenshot").css("display", "none")
}
if (config["allow_screenshot"] && config["allow_flagging"]) {
target.find(".screenshot").css("visibility", "visible");
target.find(".flag").css("visibility", "visible")
if (config["allow_interpretation"]) {
target.find(".interpret").css("visibility", "visible");
}
if (config["examples"]) {
target.find(".examples").removeClass("invisible");

File diff suppressed because one or more lines are too long

View File

@ -11,5 +11,5 @@ def reverse_audio(audio):
io = gr.Interface(reverse_audio, "microphone", "audio")
io.test_launch()
# io.test_launch()
io.launch()

View File

@ -292,6 +292,7 @@ class Image(InputComponent):
def preprocess(self, x):
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
@ -305,7 +306,7 @@ class Image(InputComponent):
elif self.type == "numpy":
return np.array(im)
elif self.type == "file":
file_obj = tempfile.NamedTemporaryFile()
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
im.save(file_obj.name)
return file_obj
else:

View File

@ -7,9 +7,15 @@ import tempfile
import webbrowser
from gradio.inputs import InputComponent
from gradio.inputs import Image
from gradio.inputs import Textbox
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils
from gradio import networking, strings, utils, processing_utils
from distutils.version import StrictVersion
from skimage.segmentation import slic
from skimage.util import img_as_float
from gradio import processing_utils
import PIL
import pkg_resources
import requests
import random
@ -20,6 +26,7 @@ import sys
import weakref
import analytics
import os
import numpy as np
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
@ -46,7 +53,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
capture_session=False, explain_by=None, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True):
@ -110,6 +117,7 @@ class Interface:
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.explain_by = explain_by
self.session = None
self.server_name = server_name
self.title = title
@ -177,7 +185,8 @@ class Interface:
"description": self.description,
"thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging
"allow_flagging": self.allow_flagging,
"allow_interpretation": self.explain_by is not None
}
try:
param_names = inspect.getfullargspec(self.predict[0])[0]
@ -190,7 +199,6 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
return config
def process(self, raw_input, predict_fn=None):
@ -210,7 +218,7 @@ class Interface:
durations = []
for predict_fn in self.predict:
start = time.time()
if self.capture_session and not (self.session is None):
if self.capture_session and self.session is not None:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
@ -284,7 +292,7 @@ class Interface:
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd, thread = networking.start_simple_server(
self, output_directory, self.server_name, server_port=self.server_port)
self, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
@ -414,6 +422,94 @@ class Interface:
return httpd, path_to_local_server, share_url
def tokenize_text(self, text):
    """Split *text* on whitespace and build its leave-one-out variants.

    Returns a tuple ``(tokens, leave_one_out_tokens)`` where the second
    element holds, for each token position, a copy of ``tokens`` with
    that single token removed.
    """
    tokens = text.split()
    leave_one_out_tokens = [
        tokens[:index] + tokens[index + 1:] for index in range(len(tokens))
    ]
    return tokens, leave_one_out_tokens
def tokenize_image(self, image):
    """Segment a base64 image into superpixels and build leave-one-out variants.

    For every SLIC superpixel, produce a ``(mask, white_screen)`` pair
    where ``mask`` is the boolean pixel mask of that segment and
    ``white_screen`` is a copy of the image with that segment painted
    white (255).
    """
    pixels = np.array(processing_utils.decode_base64_to_image(image))
    segments = slic(pixels, n_segments=20, compactness=10, sigma=1)
    leave_one_out_tokens = []
    for segment_value in np.unique(segments):
        mask = segments == segment_value
        whited_out = np.copy(pixels)
        whited_out[mask] = 255
        leave_one_out_tokens.append((mask, whited_out))
    return leave_one_out_tokens
def score_text(self, tokens, leave_one_out_tokens, text):
    """Score each character of *text* by its token's leave-one-out impact.

    Runs the interface once on the full token sequence to get the
    predicted label and its confidence, then once per leave-one-out
    variant; a token's score is the confidence drop caused by removing
    it. Returns a list of ``(char, score)`` pairs covering the original
    text, with separating spaces scored 0.
    """
    # Fix: use the `tokens` parameter instead of shadowing it with
    # `text.split()` (the values coincide for callers in this file, but
    # the parameter was previously dead); also drop the unused
    # initializations of original_label/original_confidence.
    original_output = self.process([" ".join(tokens)])
    confidences = {result["label"]: result["confidence"]
                   for result in original_output[0][0]['confidences']}
    original_label = original_output[0][0]["label"]
    original_confidence = confidences[original_label]

    # Confidence drop for each leave-one-out variant.
    scores = []
    for variant_tokens in leave_one_out_tokens:
        raw_output = self.process([" ".join(variant_tokens)])
        variant_confidences = {result["label"]: result["confidence"]
                               for result in raw_output[0][0]['confidences']}
        scores.append(original_confidence - variant_confidences[original_label])

    # Expand per-token scores to per-character pairs, separated by spaces.
    scores_by_char = []
    for idx, token in enumerate(tokens):
        if idx != 0:
            scores_by_char.append((" ", 0))
        for char in token:
            scores_by_char.append((char, scores[idx]))
    return scores_by_char
def score_image(self, leave_one_out_tokens, image):
    """Score each pixel by the leave-one-out impact of its superpixel.

    Runs the interface on the original base64 image to get the predicted
    label and its confidence, then once per ``(mask, whited-out image)``
    pair from :meth:`tokenize_image`; each segment's confidence drop is
    painted onto its mask. Returns a 2-D (height x width) list of
    per-pixel scores normalized so the largest magnitude is 1.
    """
    # Fix: removed the unused local `image_interface`.
    original_output = self.process([image])
    confidences = {result["label"]: result["confidence"]
                   for result in original_output[0][0]['confidences']}
    original_label = original_output[0][0]["label"]
    original_confidence = confidences[original_label]

    # PIL's .size is (width, height); the score grid is (height, width).
    shape = processing_utils.decode_base64_to_image(image).size
    output_scores = np.full((shape[1], shape[0]), 0.0)
    for mask, whited_out in leave_one_out_tokens:
        input_image_base64 = processing_utils.encode_array_to_base64(
            whited_out)
        raw_output = self.process([input_image_base64])
        variant_confidences = {result["label"]: result["confidence"]
                               for result in raw_output[0][0]['confidences']}
        score = original_confidence - variant_confidences[original_label]
        output_scores += score * mask

    # Normalize to [-1, 1]; skip when all scores are zero.
    max_val = np.max(np.abs(output_scores))
    if max_val > 0:
        output_scores = output_scores / max_val
    return output_scores.tolist()
def simple_explanation(self, x):
    """Leave-one-out explanation for the first input interface.

    Dispatches on the type of the first input component: Textbox inputs
    get per-character scores, Image inputs get per-pixel scores.

    Raises:
        ValueError: if the first input component is neither a Textbox
            nor an Image.
    """
    first_input = self.input_interfaces[0]
    if isinstance(first_input, Textbox):
        tokens, leave_one_out_tokens = self.tokenize_text(x[0])
        return [self.score_text(tokens, leave_one_out_tokens, x[0])]
    elif isinstance(first_input, Image):
        leave_one_out_tokens = self.tokenize_image(x[0])
        return [self.score_image(leave_one_out_tokens, x[0])]
    else:
        # Fix: previously only printed a message and returned None,
        # which hid the misconfiguration from the caller. Raise instead,
        # matching the ValueError convention used elsewhere in the file.
        raise ValueError(
            "Input interface type cannot be explained: "
            + type(first_input).__name__)
def explain(self, x):
    """Produce an interpretation of the raw input list *x*.

    When ``explain_by`` is the string ``"default"``, fall back to the
    built-in leave-one-out explanation; otherwise treat ``explain_by``
    as a user-supplied callable and invoke it on the preprocessed
    inputs, one per input interface.
    """
    if self.explain_by != "default":
        preprocessed = []
        for raw_value, interface in zip(x, self.input_interfaces):
            preprocessed.append(interface(raw_value))
        return self.explain_by(*preprocessed)
    return self.simple_explanation(x)
def reset_all():
for io in Interface.get_instances():

View File

@ -128,7 +128,7 @@ def file(path):
return send_file(os.path.join(os.getcwd(), path))
def start_server(interface, directory_to_serve=None, server_name=None, server_port=None):
def start_server(interface, server_name=None, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(

View File

@ -54,10 +54,10 @@ class Textbox(OutputComponent):
}
def postprocess(self, y):
if self.type == "str":
return y
elif self.type == "number" or self.type == "auto":
if self.type == "str" or self.type == "auto":
return str(y)
elif self.type == "number":
return y
else:
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")

File diff suppressed because one or more lines are too long

View File

@ -1,47 +1,97 @@
import unittest
import gradio as gr
import PIL
import numpy as np
class TestTextbox(unittest.TestCase):
def test_component(self):
box = gr.inputs.Textbox()
assert box.preprocess("Hello") == "Hello"
box = gr.inputs.Textbox(type="str")
assert box.preprocess(125) == 125
def test_interface(self):
iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
assert interface.process("Hello") == "olleH"
assert iface.process(["Hello"])[0] == ["olleH"]
iface = gr.Interface(lambda x: x*x, "number", "number")
assert interface.process(5) == 25
assert iface.process(["5"])[0] == [25]
class TestSlider(unittest.TestCase):
pass
def test_interface(self):
iface = gr.Interface(lambda x: str(x) + " cats", "slider", "textbox")
assert iface.process([4])[0] == ["4 cats"]
class TestCheckbox(unittest.TestCase):
pass
def test_interface(self):
iface = gr.Interface(lambda x: "yes" if x else "no", "checkbox", "textbox")
assert iface.process([False])[0] == ["no"]
class TestCheckboxGroup(unittest.TestCase):
pass
def test_interface(self):
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"])
iface = gr.Interface(lambda x: "|".join(x), checkboxes, "textbox")
assert iface.process([["a", "c"]])[0] == ["a|c"]
assert iface.process([[]])[0] == [""]
checkboxes = gr.inputs.CheckboxGroup(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: "|".join(map(str, x)), checkboxes, "textbox")
assert iface.process([["a", "c"]])[0] == ["0|2"]
class TestRadio(unittest.TestCase):
pass
def test_interface(self):
radio = gr.inputs.Radio(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, radio, "textbox")
assert iface.process(["c"])[0] == ["cc"]
radio = gr.inputs.Radio(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: 2 * x, radio, "number")
assert iface.process(["c"])[0] == [4]
class TestDropdown(unittest.TestCase):
pass
def test_interface(self):
dropdown = gr.inputs.Dropdown(["a", "b", "c"])
iface = gr.Interface(lambda x: 2 * x, dropdown, "textbox")
assert iface.process(["c"])[0] == ["cc"]
dropdown = gr.inputs.Dropdown(["a", "b", "c"], type="index")
iface = gr.Interface(lambda x: 2 * x, dropdown, "number")
assert iface.process(["c"])[0] == [4]
class TestImage(unittest.TestCase):
pass
def test_component(self):
x_img = gr.test_data.BASE64_IMAGE
image_input = gr.inputs.Image()
assert image_input.preprocess(x_img).shape == (68, 61 ,3)
image_input = gr.inputs.Image(image_mode="L", shape=(25, 25))
assert image_input.preprocess(x_img).shape == (25, 25)
image_input = gr.inputs.Image(shape=(30, 10), type="pil")
assert image_input.preprocess(x_img).size == (30, 10)
def test_interface(self):
x_img = gr.test_data.BASE64_IMAGE
def open_and_rotate(img_file):
img = PIL.Image.open(img_file)
return img.rotate(90, expand=True)
iface = gr.Interface(
open_and_rotate,
gr.inputs.Image(shape=(30, 10), type="file"),
"image")
output = iface.process([x_img])[0][0]
assert gr.processing_utils.decode_base64_to_image(output).size == (10, 30)
class TestAudio(unittest.TestCase):
pass
def test_component(self):
x_wav = gr.test_data.BASE64_AUDIO
audio_input = gr.inputs.Audio()
output = audio_input.preprocess(x_wav)
print(output[0])
print(output[1].shape)
assert output[0] == 44000
assert output[1].shape == (100, 2)
def test_interface(self):
pass
class TestFile(unittest.TestCase):