diff --git a/build/lib/gradio/inputs.py b/build/lib/gradio/inputs.py index a4137dfbc4..67d927229f 100644 --- a/build/lib/gradio/inputs.py +++ b/build/lib/gradio/inputs.py @@ -66,7 +66,8 @@ class AbstractInput(ABC): class Sketchpad(AbstractInput): - def __init__(self, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0, + def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True, + flatten=False, scale=1/255, shift=0, dtype='float64', sample_inputs=None, label=None): self.image_width = shape[0] self.image_height = shape[1] @@ -272,8 +273,9 @@ class Checkbox(AbstractInput): class ImageIn(AbstractInput): - def __init__(self, shape=(224, 224, 3), image_mode='RGB', + def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB', scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None): + self.cast_to = cast_to self.image_width = shape[0] self.image_height = shape[1] self.num_channels = shape[2] @@ -298,10 +300,29 @@ class ImageIn(AbstractInput): **super().get_template_context() } + def cast_to_base64(self, inp): + return inp + + def cast_to_im(self, inp): + return preprocessing_utils.decode_base64_to_image(inp) + + def cast_to_numpy(self, inp): + im = self.cast_to_im(inp) + arr = np.array(im).flatten() + return arr + def preprocess(self, inp): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ + cast_to_type = { + "base64": self.cast_to_base64, + "numpy": self.cast_to_numpy, + "pillow": self.cast_to_im + } + if self.cast_to: + return cast_to_type[self.cast_to](inp) + im = preprocessing_utils.decode_base64_to_image(inp) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/build/lib/gradio/interface.py b/build/lib/gradio/interface.py index c52343e497..40117cd509 100644 --- a/build/lib/gradio/interface.py +++ b/build/lib/gradio/interface.py @@ -82,8 +82,8 @@ class Interface: self.session = None self.server_name = server_name - def update_config_file(self, output_directory): - config = { + def get_config_file(self): + return { "input_interfaces": [ (iface.__class__.__name__.lower(), iface.get_template_context()) for iface in self.input_interfaces], @@ -95,7 +95,38 @@ class Interface: "show_input": self.show_input, "show_output": self.show_output, } - networking.set_config(config, output_directory) + + def process(self, raw_input): + processed_input = [input_interface.preprocess( + raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)] + predictions = [] + for predict_fn in self.predict: + if self.context: + if self.capture_session: + graph, sess = self.session + with graph.as_default(): + with sess.as_default(): + prediction = predict_fn(*processed_input, + self.context) + else: + prediction = predict_fn(*processed_input, + self.context) + else: + if self.capture_session: + graph, sess = self.session + with graph.as_default(): + with sess.as_default(): + prediction = predict_fn(*processed_input) + else: + prediction = predict_fn(*processed_input) + if len(self.output_interfaces) / \ + len(self.predict) == 1: + prediction = [prediction] + predictions.extend(prediction) + processed_output = [output_interface.postprocess( + predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] + return processed_output + def validate(self): if self.validate_flag: @@ -181,8 +212,7 @@ class Interface: server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port) networking.build_template(output_directory) - - self.update_config_file(output_directory) + networking.set_config(self.get_config_file(), output_directory) self.status = "RUNNING" self.simple_server = httpd diff --git a/build/lib/gradio/networking.py b/build/lib/gradio/networking.py index 72912e9a9e..f39bf80653 100644 --- a/build/lib/gradio/networking.py +++ b/build/lib/gradio/networking.py @@ -43,7 +43,8 @@ def build_template(temp_dir): :param temp_dir: string with path to temp directory in which the html file should be built """ dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir) - dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP)) + dir_util.copy_tree(STATIC_PATH_LIB, os.path.join( + temp_dir, STATIC_PATH_TEMP)) # Move association file to root of temporary directory. copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC), @@ -81,6 +82,7 @@ def render_string_or_list_with_tags(old_lines, context): new_lines.append(line) return new_lines + def set_config(config, temp_dir): config_file = os.path.join(temp_dir, CONFIG_FILE) with open(config_file, "w") as output: @@ -133,36 +135,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n if self.path == "/api/predict/": # Make the prediction. self._set_headers() - data_string = self.rfile.read(int(self.headers["Content-Length"])) + data_string = self.rfile.read( + int(self.headers["Content-Length"])) msg = json.loads(data_string) raw_input = msg["data"] - processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)] - predictions = [] - for predict_fn in interface.predict: - if interface.context: - if interface.capture_session: - graph, sess = interface.session - with graph.as_default(): - with sess.as_default(): - prediction = predict_fn(*processed_input, - interface.context) - else: - prediction = predict_fn(*processed_input, - interface.context) - else: - if interface.capture_session: - graph, sess = interface.session - with graph.as_default(): - with sess.as_default(): - prediction = predict_fn(*processed_input) - else: - prediction = predict_fn(*processed_input) - if len(interface.output_interfaces) / \ - len(interface.predict) == 1: - prediction = [prediction] - predictions.extend(prediction) - processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)] - output = {"action": "output", "data": processed_output} + output = {"action": "output", "data": interface.process(raw_input)} if interface.saliency is not None: saliency = interface.saliency(raw_input, prediction) output['saliency'] = saliency.tolist() @@ -182,34 +159,37 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n elif self.path == "/api/flag/": self._set_headers() - data_string = self.rfile.read(int(self.headers["Content-Length"])) + data_string = self.rfile.read( + int(self.headers["Content-Length"])) msg = json.loads(data_string) flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.flag_hash)) os.makedirs(flag_dir, exist_ok=True) output = {'inputs': [interface.input_interfaces[ i].rebuild_flagged( - flag_dir, msg['data']['input_data']) for i + flag_dir, msg['data']['input_data']) for i in range(len(interface.input_interfaces))], - 'outputs': [interface.output_interfaces[ - i].rebuild_flagged( - flag_dir, msg['data']['output_data']) for i + 'outputs': [interface.output_interfaces[ + i].rebuild_flagged( + flag_dir, msg['data']['output_data']) for i in range(len(interface.output_interfaces))], - 'message': msg['data']['message']} + 'message': msg['data']['message']} with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f: f.write(json.dumps(output)) f.write("\n") - #TODO(abidlabs): clean this up + # TODO(abidlabs): clean this up elif self.path == "/api/auto/rotation": from gradio import validation_data, preprocessing_utils import numpy as np self._set_headers() - data_string = self.rfile.read(int(self.headers["Content-Length"])) + data_string = self.rfile.read( + int(self.headers["Content-Length"])) msg = json.loads(data_string) - img_orig = preprocessing_utils.decode_base64_to_image(msg["data"]) + img_orig = preprocessing_utils.decode_base64_to_image( + msg["data"]) img_orig = img_orig.convert('RGB') img_orig = img_orig.resize((224, 224)) @@ -219,8 +199,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n for deg in range(-180, 180+45, 45): img = img_orig.rotate(deg) img_array = np.array(img) / 127.5 - 1 - prediction = interface.predict(np.expand_dims(img_array, axis=0)) - processed_output = interface.output_interface.postprocess(prediction) + prediction = interface.predict( + np.expand_dims(img_array, axis=0)) + processed_output = interface.output_interface.postprocess( + prediction) output = {'input': interface.input_interface.save_to_file(flag_dir, img), 'output': interface.output_interface.rebuild_flagged( flag_dir, {'data': {'output': processed_output}}), @@ -240,9 +222,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n from PIL import ImageEnhance self._set_headers() - data_string = self.rfile.read(int(self.headers["Content-Length"])) + data_string = self.rfile.read( + int(self.headers["Content-Length"])) msg = json.loads(data_string) - img_orig = preprocessing_utils.decode_base64_to_image(msg["data"]) + img_orig = preprocessing_utils.decode_base64_to_image( + msg["data"]) img_orig = img_orig.convert('RGB') img_orig = img_orig.resize((224, 224)) enhancer = ImageEnhance.Brightness(img_orig) @@ -253,8 +237,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n for i in range(9): img = enhancer.enhance(i/4) img_array = np.array(img) / 127.5 - 1 - prediction = interface.predict(np.expand_dims(img_array, axis=0)) - processed_output = interface.output_interface.postprocess(prediction) + prediction = interface.predict( + np.expand_dims(img_array, axis=0)) + processed_output = interface.output_interface.postprocess( + prediction) output = {'input': interface.input_interface.save_to_file(flag_dir, img), 'output': interface.output_interface.rebuild_flagged( flag_dir, {'data': {'output': processed_output}}), @@ -299,7 +285,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None): port = get_first_available_port( INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS ) - httpd = serve_files_in_background(interface, port, directory_to_serve, server_name) + httpd = serve_files_in_background( + interface, port, directory_to_serve, server_name) return port, httpd diff --git a/build/lib/gradio/static/css/interfaces/input/checkbox_group.css b/build/lib/gradio/static/css/interfaces/input/checkbox_group.css index e570830208..7d93fdcfd2 100644 --- a/build/lib/gradio/static/css/interfaces/input/checkbox_group.css +++ b/build/lib/gradio/static/css/interfaces/input/checkbox_group.css @@ -3,7 +3,7 @@ align-items: center; font-size: 18px; } -.checkbox_group input { +.checkbox_group input, .checkbox { margin: 0px 4px 0px 0px; } .checkbox_group label { diff --git a/build/lib/gradio/static/css/style.css b/build/lib/gradio/static/css/style.css index ce100507eb..6a15f296df 100644 --- a/build/lib/gradio/static/css/style.css +++ b/build/lib/gradio/static/css/style.css @@ -1,6 +1,5 @@ body { font-family: 'Open Sans', sans-serif; - font-size: 12px; margin: 0; } button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] { diff --git a/build/lib/gradio/static/js/interfaces/input/checkbox.js b/build/lib/gradio/static/js/interfaces/input/checkbox.js index 61f0694f86..3e533816ce 100644 --- a/build/lib/gradio/static/js/interfaces/input/checkbox.js +++ b/build/lib/gradio/static/js/interfaces/input/checkbox.js @@ -1,5 +1,5 @@ const checkbox = { - html: ``, + html: ``, init: function(opts) { this.target.css("height", "auto"); }, diff --git a/gradio/interface.py b/gradio/interface.py index c52343e497..40117cd509 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -82,8 +82,8 @@ class Interface: self.session = None self.server_name = server_name - def update_config_file(self, output_directory): - config = { + def get_config_file(self): + return { "input_interfaces": [ (iface.__class__.__name__.lower(), iface.get_template_context()) for iface in self.input_interfaces], @@ -95,7 +95,38 @@ class Interface: "show_input": self.show_input, "show_output": self.show_output, } - networking.set_config(config, output_directory) + + def process(self, raw_input): + processed_input = [input_interface.preprocess( + raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)] + predictions = [] + for predict_fn in self.predict: + if self.context: + if self.capture_session: + graph, sess = self.session + with graph.as_default(): + with sess.as_default(): + prediction = predict_fn(*processed_input, + self.context) + else: + prediction = predict_fn(*processed_input, + self.context) + else: + if self.capture_session: + graph, sess = self.session + with graph.as_default(): + with sess.as_default(): + prediction = predict_fn(*processed_input) + else: + prediction = predict_fn(*processed_input) + if len(self.output_interfaces) / \ + len(self.predict) == 1: + prediction = [prediction] + predictions.extend(prediction) + processed_output = [output_interface.postprocess( + predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] + return processed_output + def validate(self): if self.validate_flag: @@ -181,8 +212,7 @@ class Interface: server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port) networking.build_template(output_directory) - - self.update_config_file(output_directory) + networking.set_config(self.get_config_file(), output_directory) self.status = "RUNNING" self.simple_server = httpd diff --git a/gradio/networking.py b/gradio/networking.py index 7fbf5cecd0..f39bf80653 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -139,35 +139,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n int(self.headers["Content-Length"])) msg = json.loads(data_string) raw_input = msg["data"] - processed_input = [input_interface.preprocess( - raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)] - predictions = [] - for predict_fn in interface.predict: - if interface.context: - if interface.capture_session: - graph, sess = interface.session - with graph.as_default(): - with sess.as_default(): - prediction = predict_fn(*processed_input, - interface.context) - else: - prediction = predict_fn(*processed_input, - interface.context) - else: - if interface.capture_session: - graph, sess = interface.session - with graph.as_default(): - with sess.as_default(): - prediction = predict_fn(*processed_input) - else: - prediction = predict_fn(*processed_input) - if len(interface.output_interfaces) / \ - len(interface.predict) == 1: - prediction = [prediction] - predictions.extend(prediction) - processed_output = [output_interface.postprocess( - predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)] - output = {"action": "output", "data": processed_output} + output = {"action": "output", "data": interface.process(raw_input)} if interface.saliency is not None: saliency = interface.saliency(raw_input, prediction) output['saliency'] = saliency.tolist()