From 6caeca54d84f4b9ac5e64294f37f12102b0c693a Mon Sep 17 00:00:00 2001 From: dawoodkhan82 Date: Mon, 22 Jul 2019 13:55:02 -0700 Subject: [PATCH] sample-inputs for sketchpad --- gradio.egg-info/requires.txt | 1 + gradio/inputs.py | 17 ++++++++++++++++- gradio/interface.py | 23 ++++++++++++++++++----- gradio/networking.py | 10 ++++++++++ gradio/preprocessing_utils.py | 15 +++++++++++++-- gradio/static/config.json | 3 ++- 6 files changed, 60 insertions(+), 9 deletions(-) diff --git a/gradio.egg-info/requires.txt b/gradio.egg-info/requires.txt index 704fad562a..9f67db07f8 100644 --- a/gradio.egg-info/requires.txt +++ b/gradio.egg-info/requires.txt @@ -6,3 +6,4 @@ psutil paramiko scipy IPython +scikit-image diff --git a/gradio/inputs.py b/gradio/inputs.py index bdaae47002..bb91dcf8b4 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -53,6 +53,12 @@ class AbstractInput(ABC): """ return {} + def sample_inputs(self): + """ + An interface can optionally implement a method that sends a list of sample inputs for inference. + """ + return [] + @abstractmethod def get_name(self): """ @@ -77,7 +83,7 @@ class AbstractInput(ABC): class Sketchpad(AbstractInput): def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0, - dtype='float64'): + dtype='float64', sample_inputs=None): self.image_width = shape[0] self.image_height = shape[1] self.invert_colors = invert_colors @@ -85,8 +91,10 @@ class Sketchpad(AbstractInput): self.scale = scale self.shift = shift self.dtype = dtype + self.sample_inputs = sample_inputs super().__init__(preprocessing_fn=preprocessing_fn) + def get_name(self): return 'sketchpad' @@ -121,6 +129,13 @@ class Sketchpad(AbstractInput): im.save(f'{dir}/{filename}', 'PNG') return filename + def get_sample_inputs(self): + encoded_images = [] + if self.sample_inputs is not None: + for input in self.sample_inputs: + encoded_images.append(preprocessing_utils.encode_array_to_base64(input)) + return encoded_images + class Webcam(AbstractInput): def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): diff --git a/gradio/interface.py b/gradio/interface.py index b07dfdd570..be84b44572 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -130,6 +130,22 @@ class Interface: raise ValueError("model_type could not be inferred, please specify parameter `model_type`") + + @staticmethod + def update_config_file(self, output_directory): + networking.set_interface_types_in_config_file( + output_directory, + self.input_interface.__class__.__name__.lower(), + self.output_interface.__class__.__name__.lower(), + ) + + if self.input_interface.__class__.__name__.lower() == "sketchpad" or self.input_interface.__class__.__name__.lower() == "textbox": + networking.set_sample_data_in_config_file( + output_directory, + self.input_interface.get_sample_inputs() + ) + + def predict(self, preprocessed_input): """ Method that calls the relevant method of the model object to make a prediction. @@ -235,11 +251,8 @@ class Interface: output_directory, self.input_interface, self.output_interface ) - networking.set_interface_types_in_config_file( - output_directory, - self.input_interface.__class__.__name__.lower(), - self.output_interface.__class__.__name__.lower(), - ) + self.update_config_file(self, output_directory) + self.status = self.STATUS_TYPES["RUNNING"] self.simple_server = httpd diff --git a/gradio/networking.py b/gradio/networking.py index 441ea4ad04..26bbb38813 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -117,6 +117,16 @@ def set_share_url_in_config_file(temp_dir, share_url): ) +def set_sample_data_in_config_file(temp_dir, sample_inputs): + config_file = os.path.join(temp_dir, CONFIG_FILE) + render_template_with_tags( + config_file, + { + "sample_inputs": sample_inputs + }, + ) + + def get_first_available_port(initial, final): """ Gets the first open port in a specified range of port numbers diff --git a/gradio/preprocessing_utils.py b/gradio/preprocessing_utils.py index bb3025b9cc..60f4cb0815 100644 --- a/gradio/preprocessing_utils.py +++ b/gradio/preprocessing_utils.py @@ -5,6 +5,7 @@ import tempfile import scipy.io.wavfile from scipy.fftpack import dct import numpy as np +import skimage ######################### @@ -16,6 +17,16 @@ def decode_base64_to_image(encoding): return Image.open(BytesIO(base64.b64decode(image_encoded))) +def encode_array_to_base64(image_array): + with BytesIO() as output_bytes: + PIL_image = Image.fromarray(skimage.img_as_ubyte(image_array)) + PIL_image.save(output_bytes, 'PNG') + bytes_data = output_bytes.getvalue() + + base64_str = str(base64.b64encode(bytes_data), 'utf-8') + return base64_str + + def resize_and_crop(img, size, crop_type='top'): """ Resize and crop an image to fit the specified size. @@ -24,8 +35,8 @@ def resize_and_crop(img, size, crop_type='top'): modified_path: path to store the modified image. size: `(width, height)` tuple. crop_type: can be 'top', 'middle' or 'bottom', depending on this - value, the image will cropped getting the 'top/left', 'midle' or - 'bottom/rigth' of the image to fit the size. + value, the image will cropped getting the 'top/left', 'middle' or + 'bottom/right' of the image to fit the size. raises: Exception: if can not open the file in img_path of there is problems to save the image. diff --git a/gradio/static/config.json b/gradio/static/config.json index 8ee60a2a23..dff5d34881 100644 --- a/gradio/static/config.json +++ b/gradio/static/config.json @@ -1,5 +1,6 @@ { "input_interface_type": "{{input_interface_type}}", "output_interface_type": "{{output_interface_type}}", - "share_url": "{{share_url}}" + "share_url": "{{share_url}}", + "sample_inputs": "{{sample_inputs}}" }