This commit is contained in:
aliabd 2020-06-22 15:24:36 -07:00
commit 2497a2e2a8
8 changed files with 128 additions and 89 deletions

View File

@ -66,7 +66,8 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
def __init__(self, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0,
def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
flatten=False, scale=1/255, shift=0,
dtype='float64', sample_inputs=None, label=None):
self.image_width = shape[0]
self.image_height = shape[1]
@ -272,8 +273,9 @@ class Checkbox(AbstractInput):
class ImageIn(AbstractInput):
def __init__(self, shape=(224, 224, 3), image_mode='RGB',
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
self.cast_to = cast_to
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = shape[2]
@ -298,10 +300,29 @@ class ImageIn(AbstractInput):
**super().get_template_context()
}
def cast_to_base64(self, inp):
return inp
def cast_to_im(self, inp):
return preprocessing_utils.decode_base64_to_image(inp)
def cast_to_numpy(self, inp):
im = self.cast_to_im(inp)
arr = np.array(im).flatten()
return arr
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
cast_to_type = {
"base64": self.cast_to_base64,
"numpy": self.cast_to_numpy,
"pillow": self.cast_to_im
}
if self.cast_to:
return cast_to_type[self.cast_to](inp)
im = preprocessing_utils.decode_base64_to_image(inp)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

View File

@ -82,8 +82,8 @@ class Interface:
self.session = None
self.server_name = server_name
def update_config_file(self, output_directory):
config = {
def get_config_file(self):
return {
"input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces],
@ -95,7 +95,38 @@ class Interface:
"show_input": self.show_input,
"show_output": self.show_output,
}
networking.set_config(config, output_directory)
def process(self, raw_input):
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
predictions = []
for predict_fn in self.predict:
if self.context:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
self.context)
else:
prediction = predict_fn(*processed_input,
self.context)
else:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(self.output_interfaces) / \
len(self.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output
def validate(self):
if self.validate_flag:
@ -181,8 +212,7 @@ class Interface:
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
self.update_config_file(output_directory)
networking.set_config(self.get_config_file(), output_directory)
self.status = "RUNNING"
self.simple_server = httpd

View File

@ -43,7 +43,8 @@ def build_template(temp_dir):
:param temp_dir: string with path to temp directory in which the html file should be built
"""
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
@ -81,6 +82,7 @@ def render_string_or_list_with_tags(old_lines, context):
new_lines.append(line)
return new_lines
def set_config(config, temp_dir):
config_file = os.path.join(temp_dir, CONFIG_FILE)
with open(config_file, "w") as output:
@ -133,36 +135,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
if self.path == "/api/predict/":
# Make the prediction.
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
predictions = []
for predict_fn in interface.predict:
if interface.context:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input,
interface.context)
else:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / \
len(interface.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
output = {"action": "output", "data": processed_output}
output = {"action": "output", "data": interface.process(raw_input)}
if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()
@ -182,34 +159,37 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
elif self.path == "/api/flag/":
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
flag_dir = os.path.join(FLAGGING_DIRECTORY,
str(interface.flag_hash))
os.makedirs(flag_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['input_data']) for i
flag_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['output_data']) for i
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))],
'message': msg['data']['message']}
'message': msg['data']['message']}
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
f.write(json.dumps(output))
f.write("\n")
#TODO(abidlabs): clean this up
# TODO(abidlabs): clean this up
elif self.path == "/api/auto/rotation":
from gradio import validation_data, preprocessing_utils
import numpy as np
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
img_orig = preprocessing_utils.decode_base64_to_image(
msg["data"])
img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224))
@ -219,8 +199,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
for deg in range(-180, 180+45, 45):
img = img_orig.rotate(deg)
img_array = np.array(img) / 127.5 - 1
prediction = interface.predict(np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(prediction)
prediction = interface.predict(
np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(
prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}),
@ -240,9 +222,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
from PIL import ImageEnhance
self._set_headers()
data_string = self.rfile.read(int(self.headers["Content-Length"]))
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
img_orig = preprocessing_utils.decode_base64_to_image(
msg["data"])
img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224))
enhancer = ImageEnhance.Brightness(img_orig)
@ -253,8 +237,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
for i in range(9):
img = enhancer.enhance(i/4)
img_array = np.array(img) / 127.5 - 1
prediction = interface.predict(np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(prediction)
prediction = interface.predict(
np.expand_dims(img_array, axis=0))
processed_output = interface.output_interface.postprocess(
prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}),
@ -299,7 +285,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
)
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
httpd = serve_files_in_background(
interface, port, directory_to_serve, server_name)
return port, httpd

View File

@ -3,7 +3,7 @@
align-items: center;
font-size: 18px;
}
.checkbox_group input {
.checkbox_group input, .checkbox {
margin: 0px 4px 0px 0px;
}
.checkbox_group label {

View File

@ -1,6 +1,5 @@
body {
font-family: 'Open Sans', sans-serif;
font-size: 12px;
margin: 0;
}
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {

View File

@ -1,5 +1,5 @@
const checkbox = {
html: `<input type="checkbox">`,
html: `<input class="checkbox" type="checkbox">`,
init: function(opts) {
this.target.css("height", "auto");
},

View File

@ -82,8 +82,8 @@ class Interface:
self.session = None
self.server_name = server_name
def update_config_file(self, output_directory):
config = {
def get_config_file(self):
return {
"input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces],
@ -95,7 +95,38 @@ class Interface:
"show_input": self.show_input,
"show_output": self.show_output,
}
networking.set_config(config, output_directory)
def process(self, raw_input):
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
predictions = []
for predict_fn in self.predict:
if self.context:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
self.context)
else:
prediction = predict_fn(*processed_input,
self.context)
else:
if self.capture_session:
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(self.output_interfaces) / \
len(self.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output
def validate(self):
if self.validate_flag:
@ -181,8 +212,7 @@ class Interface:
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
self.update_config_file(output_directory)
networking.set_config(self.get_config_file(), output_directory)
self.status = "RUNNING"
self.simple_server = httpd

View File

@ -139,35 +139,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
predictions = []
for predict_fn in interface.predict:
if interface.context:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input,
interface.context)
else:
if interface.capture_session:
graph, sess = interface.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / \
len(interface.predict) == 1:
prediction = [prediction]
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
output = {"action": "output", "data": processed_output}
output = {"action": "output", "data": interface.process(raw_input)}
if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()