This commit is contained in:
aliabd 2020-06-22 15:24:36 -07:00
commit 2497a2e2a8
8 changed files with 128 additions and 89 deletions

View File

@ -66,7 +66,8 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput): class Sketchpad(AbstractInput):
def __init__(self, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0, def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
flatten=False, scale=1/255, shift=0,
dtype='float64', sample_inputs=None, label=None): dtype='float64', sample_inputs=None, label=None):
self.image_width = shape[0] self.image_width = shape[0]
self.image_height = shape[1] self.image_height = shape[1]
@ -272,8 +273,9 @@ class Checkbox(AbstractInput):
class ImageIn(AbstractInput): class ImageIn(AbstractInput):
def __init__(self, shape=(224, 224, 3), image_mode='RGB', def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None): scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
self.cast_to = cast_to
self.image_width = shape[0] self.image_width = shape[0]
self.image_height = shape[1] self.image_height = shape[1]
self.num_channels = shape[2] self.num_channels = shape[2]
@ -298,10 +300,29 @@ class ImageIn(AbstractInput):
**super().get_template_context() **super().get_template_context()
} }
def cast_to_base64(self, inp):
return inp
def cast_to_im(self, inp):
return preprocessing_utils.decode_base64_to_image(inp)
def cast_to_numpy(self, inp):
im = self.cast_to_im(inp)
arr = np.array(im).flatten()
return arr
def preprocess(self, inp): def preprocess(self, inp):
""" """
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
""" """
cast_to_type = {
"base64": self.cast_to_base64,
"numpy": self.cast_to_numpy,
"pillow": self.cast_to_im
}
if self.cast_to:
return cast_to_type[self.cast_to](inp)
im = preprocessing_utils.decode_base64_to_image(inp) im = preprocessing_utils.decode_base64_to_image(inp)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")

View File

@ -82,8 +82,8 @@ class Interface:
self.session = None self.session = None
self.server_name = server_name self.server_name = server_name
def update_config_file(self, output_directory): def get_config_file(self):
config = { return {
"input_interfaces": [ "input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context()) (iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces], for iface in self.input_interfaces],
@ -95,7 +95,38 @@ class Interface:
"show_input": self.show_input, "show_input": self.show_input,
"show_output": self.show_output, "show_output": self.show_output,
} }
networking.set_config(config, output_directory)
def process(self, raw_input):
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
predictions = []
for predict_fn in self.predict:
if self.context:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
self.context)
else:
prediction = predict_fn(*processed_input,
self.context)
else:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(self.output_interfaces) / \
len(self.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output
def validate(self): def validate(self):
if self.validate_flag: if self.validate_flag:
@ -181,8 +212,7 @@ class Interface:
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name) server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory) networking.build_template(output_directory)
networking.set_config(self.get_config_file(), output_directory)
self.update_config_file(output_directory)
self.status = "RUNNING" self.status = "RUNNING"
self.simple_server = httpd self.simple_server = httpd

View File

@ -43,7 +43,8 @@ def build_template(temp_dir):
:param temp_dir: string with path to temp directory in which the html file should be built :param temp_dir: string with path to temp directory in which the html file should be built
""" """
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir) dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP)) dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory. # Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC), copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
@ -81,6 +82,7 @@ def render_string_or_list_with_tags(old_lines, context):
new_lines.append(line) new_lines.append(line)
return new_lines return new_lines
def set_config(config, temp_dir): def set_config(config, temp_dir):
config_file = os.path.join(temp_dir, CONFIG_FILE) config_file = os.path.join(temp_dir, CONFIG_FILE)
with open(config_file, "w") as output: with open(config_file, "w") as output:
@ -133,36 +135,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
if self.path == "/api/predict/": if self.path == "/api/predict/":
# Make the prediction. # Make the prediction.
self._set_headers() self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"])) data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string) msg = json.loads(data_string)
raw_input = msg["data"] raw_input = msg["data"]
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)] output = {"action": "output", "data": interface.process(raw_input)}
predictions = []
for predict_fn in interface.predict:
if interface.context:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input,
interface.context)
else:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / \
len(interface.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
output = {"action": "output", "data": processed_output}
if interface.saliency is not None: if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction) saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist() output['saliency'] = saliency.tolist()
@ -182,7 +159,8 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
elif self.path == "/api/flag/": elif self.path == "/api/flag/":
self._set_headers() self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"])) data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string) msg = json.loads(data_string)
flag_dir = os.path.join(FLAGGING_DIRECTORY, flag_dir = os.path.join(FLAGGING_DIRECTORY,
str(interface.flag_hash)) str(interface.flag_hash))
@ -207,9 +185,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
import numpy as np import numpy as np
self._set_headers() self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"])) data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string) msg = json.loads(data_string)
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"]) img_orig = preprocessing_utils.decode_base64_to_image(
msg["data"])
img_orig = img_orig.convert('RGB') img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224)) img_orig = img_orig.resize((224, 224))
@ -219,8 +199,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
for deg in range(-180, 180+45, 45): for deg in range(-180, 180+45, 45):
img = img_orig.rotate(deg) img = img_orig.rotate(deg)
img_array = np.array(img) / 127.5 - 1 img_array = np.array(img) / 127.5 - 1
prediction = interface.predict(np.expand_dims(img_array, axis=0)) prediction = interface.predict(
processed_output = interface.output_interface.postprocess(prediction) np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(
prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img), output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged( 'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}), flag_dir, {'data': {'output': processed_output}}),
@ -240,9 +222,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
from PIL import ImageEnhance from PIL import ImageEnhance
self._set_headers() self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"])) data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string) msg = json.loads(data_string)
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"]) img_orig = preprocessing_utils.decode_base64_to_image(
msg["data"])
img_orig = img_orig.convert('RGB') img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224)) img_orig = img_orig.resize((224, 224))
enhancer = ImageEnhance.Brightness(img_orig) enhancer = ImageEnhance.Brightness(img_orig)
@ -253,8 +237,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
for i in range(9): for i in range(9):
img = enhancer.enhance(i/4) img = enhancer.enhance(i/4)
img_array = np.array(img) / 127.5 - 1 img_array = np.array(img) / 127.5 - 1
prediction = interface.predict(np.expand_dims(img_array, axis=0)) prediction = interface.predict(
processed_output = interface.output_interface.postprocess(prediction) np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(
prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img), output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged( 'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}), flag_dir, {'data': {'output': processed_output}}),
@ -299,7 +285,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
port = get_first_available_port( port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
) )
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name) httpd = serve_files_in_background(
interface, port, directory_to_serve, server_name)
return port, httpd return port, httpd

View File

@ -3,7 +3,7 @@
align-items: center; align-items: center;
font-size: 18px; font-size: 18px;
} }
.checkbox_group input { .checkbox_group input, .checkbox {
margin: 0px 4px 0px 0px; margin: 0px 4px 0px 0px;
} }
.checkbox_group label { .checkbox_group label {

View File

@ -1,6 +1,5 @@
body { body {
font-family: 'Open Sans', sans-serif; font-family: 'Open Sans', sans-serif;
font-size: 12px;
margin: 0; margin: 0;
} }
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] { button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {

View File

@ -1,5 +1,5 @@
const checkbox = { const checkbox = {
html: `<input type="checkbox">`, html: `<input class="checkbox" type="checkbox">`,
init: function(opts) { init: function(opts) {
this.target.css("height", "auto"); this.target.css("height", "auto");
}, },

View File

@ -82,8 +82,8 @@ class Interface:
self.session = None self.session = None
self.server_name = server_name self.server_name = server_name
def update_config_file(self, output_directory): def get_config_file(self):
config = { return {
"input_interfaces": [ "input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context()) (iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces], for iface in self.input_interfaces],
@ -95,7 +95,38 @@ class Interface:
"show_input": self.show_input, "show_input": self.show_input,
"show_output": self.show_output, "show_output": self.show_output,
} }
networking.set_config(config, output_directory)
def process(self, raw_input):
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
predictions = []
for predict_fn in self.predict:
if self.context:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
self.context)
else:
prediction = predict_fn(*processed_input,
self.context)
else:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(self.output_interfaces) / \
len(self.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output
def validate(self): def validate(self):
if self.validate_flag: if self.validate_flag:
@ -181,8 +212,7 @@ class Interface:
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name) server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory) networking.build_template(output_directory)
networking.set_config(self.get_config_file(), output_directory)
self.update_config_file(output_directory)
self.status = "RUNNING" self.status = "RUNNING"
self.simple_server = httpd self.simple_server = httpd

View File

@ -139,35 +139,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
int(self.headers["Content-Length"])) int(self.headers["Content-Length"]))
msg = json.loads(data_string) msg = json.loads(data_string)
raw_input = msg["data"] raw_input = msg["data"]
processed_input = [input_interface.preprocess( output = {"action": "output", "data": interface.process(raw_input)}
raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
predictions = []
for predict_fn in interface.predict:
if interface.context:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input,
interface.context)
else:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / \
len(interface.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
output = {"action": "output", "data": processed_output}
if interface.saliency is not None: if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction) saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist() output['saliency'] = saliency.tolist()