diff --git a/build/lib/gradio/inputs.py b/build/lib/gradio/inputs.py
index a4137dfbc4..67d927229f 100644
--- a/build/lib/gradio/inputs.py
+++ b/build/lib/gradio/inputs.py
@@ -66,7 +66,8 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
- def __init__(self, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0,
+ def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
+ flatten=False, scale=1/255, shift=0,
dtype='float64', sample_inputs=None, label=None):
self.image_width = shape[0]
self.image_height = shape[1]
@@ -272,8 +273,9 @@ class Checkbox(AbstractInput):
class ImageIn(AbstractInput):
- def __init__(self, shape=(224, 224, 3), image_mode='RGB',
+ def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
+ self.cast_to = cast_to
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = shape[2]
@@ -298,10 +300,29 @@ class ImageIn(AbstractInput):
**super().get_template_context()
}
+ def cast_to_base64(self, inp):
+ return inp
+
+ def cast_to_im(self, inp):
+ return preprocessing_utils.decode_base64_to_image(inp)
+
+ def cast_to_numpy(self, inp):
+ im = self.cast_to_im(inp)
+ arr = np.array(im).flatten()
+ return arr
+
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
+ cast_to_type = {
+ "base64": self.cast_to_base64,
+ "numpy": self.cast_to_numpy,
+ "pillow": self.cast_to_im
+ }
+ if self.cast_to:
+ return cast_to_type[self.cast_to](inp)
+
im = preprocessing_utils.decode_base64_to_image(inp)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
diff --git a/build/lib/gradio/interface.py b/build/lib/gradio/interface.py
index c52343e497..40117cd509 100644
--- a/build/lib/gradio/interface.py
+++ b/build/lib/gradio/interface.py
@@ -82,8 +82,8 @@ class Interface:
self.session = None
self.server_name = server_name
- def update_config_file(self, output_directory):
- config = {
+ def get_config_file(self):
+ return {
"input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces],
@@ -95,7 +95,38 @@ class Interface:
"show_input": self.show_input,
"show_output": self.show_output,
}
- networking.set_config(config, output_directory)
+
+ def process(self, raw_input):
+ processed_input = [input_interface.preprocess(
+ raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
+ predictions = []
+ for predict_fn in self.predict:
+ if self.context:
+ if self.capture_session:
+ graph, sess = self.session
+ with graph.as_default():
+ with sess.as_default():
+ prediction = predict_fn(*processed_input,
+ self.context)
+ else:
+ prediction = predict_fn(*processed_input,
+ self.context)
+ else:
+ if self.capture_session:
+ graph, sess = self.session
+ with graph.as_default():
+ with sess.as_default():
+ prediction = predict_fn(*processed_input)
+ else:
+ prediction = predict_fn(*processed_input)
+ if len(self.output_interfaces) / \
+ len(self.predict) == 1:
+ prediction = [prediction]
+ predictions.extend(prediction)
+ processed_output = [output_interface.postprocess(
+ predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
+ return processed_output
+
def validate(self):
if self.validate_flag:
@@ -181,8 +212,7 @@ class Interface:
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
-
- self.update_config_file(output_directory)
+ networking.set_config(self.get_config_file(), output_directory)
self.status = "RUNNING"
self.simple_server = httpd
diff --git a/build/lib/gradio/networking.py b/build/lib/gradio/networking.py
index 72912e9a9e..f39bf80653 100644
--- a/build/lib/gradio/networking.py
+++ b/build/lib/gradio/networking.py
@@ -43,7 +43,8 @@ def build_template(temp_dir):
:param temp_dir: string with path to temp directory in which the html file should be built
"""
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
- dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
+ dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
+ temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
@@ -81,6 +82,7 @@ def render_string_or_list_with_tags(old_lines, context):
new_lines.append(line)
return new_lines
+
def set_config(config, temp_dir):
config_file = os.path.join(temp_dir, CONFIG_FILE)
with open(config_file, "w") as output:
@@ -133,36 +135,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
if self.path == "/api/predict/":
# Make the prediction.
self._set_headers()
- data_string = self.rfile.read(int(self.headers["Content-Length"]))
+ data_string = self.rfile.read(
+ int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
- processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
- predictions = []
- for predict_fn in interface.predict:
- if interface.context:
- if interface.capture_session:
- graph, sess = interface.session
- with graph.as_default():
- with sess.as_default():
- prediction = predict_fn(*processed_input,
- interface.context)
- else:
- prediction = predict_fn(*processed_input,
- interface.context)
- else:
- if interface.capture_session:
- graph, sess = interface.session
- with graph.as_default():
- with sess.as_default():
- prediction = predict_fn(*processed_input)
- else:
- prediction = predict_fn(*processed_input)
- if len(interface.output_interfaces) / \
- len(interface.predict) == 1:
- prediction = [prediction]
- predictions.extend(prediction)
- processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
- output = {"action": "output", "data": processed_output}
+ output = {"action": "output", "data": interface.process(raw_input)}
if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()
@@ -182,34 +159,37 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
elif self.path == "/api/flag/":
self._set_headers()
- data_string = self.rfile.read(int(self.headers["Content-Length"]))
+ data_string = self.rfile.read(
+ int(self.headers["Content-Length"]))
msg = json.loads(data_string)
flag_dir = os.path.join(FLAGGING_DIRECTORY,
str(interface.flag_hash))
os.makedirs(flag_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
- flag_dir, msg['data']['input_data']) for i
+ flag_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
- 'outputs': [interface.output_interfaces[
- i].rebuild_flagged(
- flag_dir, msg['data']['output_data']) for i
+ 'outputs': [interface.output_interfaces[
+ i].rebuild_flagged(
+ flag_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))],
- 'message': msg['data']['message']}
+ 'message': msg['data']['message']}
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
f.write(json.dumps(output))
f.write("\n")
- #TODO(abidlabs): clean this up
+ # TODO(abidlabs): clean this up
elif self.path == "/api/auto/rotation":
from gradio import validation_data, preprocessing_utils
import numpy as np
self._set_headers()
- data_string = self.rfile.read(int(self.headers["Content-Length"]))
+ data_string = self.rfile.read(
+ int(self.headers["Content-Length"]))
msg = json.loads(data_string)
- img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
+ img_orig = preprocessing_utils.decode_base64_to_image(
+ msg["data"])
img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224))
@@ -219,8 +199,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
for deg in range(-180, 180+45, 45):
img = img_orig.rotate(deg)
img_array = np.array(img) / 127.5 - 1
- prediction = interface.predict(np.expand_dims(img_array, axis=0))
- processed_output = interface.output_interface.postprocess(prediction)
+ prediction = interface.predict(
+ np.expand_dims(img_array, axis=0))
+ processed_output = interface.output_interface.postprocess(
+ prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}),
@@ -240,9 +222,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
from PIL import ImageEnhance
self._set_headers()
- data_string = self.rfile.read(int(self.headers["Content-Length"]))
+ data_string = self.rfile.read(
+ int(self.headers["Content-Length"]))
msg = json.loads(data_string)
- img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
+ img_orig = preprocessing_utils.decode_base64_to_image(
+ msg["data"])
img_orig = img_orig.convert('RGB')
img_orig = img_orig.resize((224, 224))
enhancer = ImageEnhance.Brightness(img_orig)
@@ -253,8 +237,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
for i in range(9):
img = enhancer.enhance(i/4)
img_array = np.array(img) / 127.5 - 1
- prediction = interface.predict(np.expand_dims(img_array, axis=0))
- processed_output = interface.output_interface.postprocess(prediction)
+ prediction = interface.predict(
+ np.expand_dims(img_array, axis=0))
+ processed_output = interface.output_interface.postprocess(
+ prediction)
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
'output': interface.output_interface.rebuild_flagged(
flag_dir, {'data': {'output': processed_output}}),
@@ -299,7 +285,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
)
- httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
+ httpd = serve_files_in_background(
+ interface, port, directory_to_serve, server_name)
return port, httpd
diff --git a/build/lib/gradio/static/css/interfaces/input/checkbox_group.css b/build/lib/gradio/static/css/interfaces/input/checkbox_group.css
index e570830208..7d93fdcfd2 100644
--- a/build/lib/gradio/static/css/interfaces/input/checkbox_group.css
+++ b/build/lib/gradio/static/css/interfaces/input/checkbox_group.css
@@ -3,7 +3,7 @@
align-items: center;
font-size: 18px;
}
-.checkbox_group input {
+.checkbox_group input, .checkbox {
margin: 0px 4px 0px 0px;
}
.checkbox_group label {
diff --git a/build/lib/gradio/static/css/style.css b/build/lib/gradio/static/css/style.css
index ce100507eb..6a15f296df 100644
--- a/build/lib/gradio/static/css/style.css
+++ b/build/lib/gradio/static/css/style.css
@@ -1,6 +1,5 @@
body {
font-family: 'Open Sans', sans-serif;
- font-size: 12px;
margin: 0;
}
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {
diff --git a/build/lib/gradio/static/js/interfaces/input/checkbox.js b/build/lib/gradio/static/js/interfaces/input/checkbox.js
index 61f0694f86..3e533816ce 100644
--- a/build/lib/gradio/static/js/interfaces/input/checkbox.js
+++ b/build/lib/gradio/static/js/interfaces/input/checkbox.js
@@ -1,5 +1,5 @@
const checkbox = {
- html: ``,
+ html: ``,
init: function(opts) {
this.target.css("height", "auto");
},
diff --git a/gradio/interface.py b/gradio/interface.py
index c52343e497..40117cd509 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -82,8 +82,8 @@ class Interface:
self.session = None
self.server_name = server_name
- def update_config_file(self, output_directory):
- config = {
+ def get_config_file(self):
+ return {
"input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces],
@@ -95,7 +95,38 @@ class Interface:
"show_input": self.show_input,
"show_output": self.show_output,
}
- networking.set_config(config, output_directory)
+
+ def process(self, raw_input):
+ processed_input = [input_interface.preprocess(
+ raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
+ predictions = []
+ for predict_fn in self.predict:
+ if self.context:
+ if self.capture_session:
+ graph, sess = self.session
+ with graph.as_default():
+ with sess.as_default():
+ prediction = predict_fn(*processed_input,
+ self.context)
+ else:
+ prediction = predict_fn(*processed_input,
+ self.context)
+ else:
+ if self.capture_session:
+ graph, sess = self.session
+ with graph.as_default():
+ with sess.as_default():
+ prediction = predict_fn(*processed_input)
+ else:
+ prediction = predict_fn(*processed_input)
+ if len(self.output_interfaces) / \
+ len(self.predict) == 1:
+ prediction = [prediction]
+ predictions.extend(prediction)
+ processed_output = [output_interface.postprocess(
+ predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
+ return processed_output
+
def validate(self):
if self.validate_flag:
@@ -181,8 +212,7 @@ class Interface:
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
-
- self.update_config_file(output_directory)
+ networking.set_config(self.get_config_file(), output_directory)
self.status = "RUNNING"
self.simple_server = httpd
diff --git a/gradio/networking.py b/gradio/networking.py
index 7fbf5cecd0..f39bf80653 100644
--- a/gradio/networking.py
+++ b/gradio/networking.py
@@ -139,35 +139,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
- processed_input = [input_interface.preprocess(
- raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
- predictions = []
- for predict_fn in interface.predict:
- if interface.context:
- if interface.capture_session:
- graph, sess = interface.session
- with graph.as_default():
- with sess.as_default():
- prediction = predict_fn(*processed_input,
- interface.context)
- else:
- prediction = predict_fn(*processed_input,
- interface.context)
- else:
- if interface.capture_session:
- graph, sess = interface.session
- with graph.as_default():
- with sess.as_default():
- prediction = predict_fn(*processed_input)
- else:
- prediction = predict_fn(*processed_input)
- if len(interface.output_interfaces) / \
- len(interface.predict) == 1:
- prediction = [prediction]
- predictions.extend(prediction)
- processed_output = [output_interface.postprocess(
- predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
- output = {"action": "output", "data": processed_output}
+ output = {"action": "output", "data": interface.process(raw_input)}
if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()