sample-inputs for sketchpad

This commit is contained in:
dawoodkhan82 2019-07-22 13:55:02 -07:00
parent b43e76b148
commit 8ef628d203
6 changed files with 60 additions and 9 deletions

View File

@ -6,3 +6,4 @@ psutil
paramiko paramiko
scipy scipy
IPython IPython
scikit-image

View File

@ -53,6 +53,12 @@ class AbstractInput(ABC):
""" """
return {} return {}
def sample_inputs(self):
"""
An interface can optionally implement a method that sends a list of sample inputs for inference.
"""
return []
@abstractmethod @abstractmethod
def get_name(self): def get_name(self):
""" """
@ -77,7 +83,7 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput): class Sketchpad(AbstractInput):
def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0, def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0,
dtype='float64'): dtype='float64', sample_inputs=None):
self.image_width = shape[0] self.image_width = shape[0]
self.image_height = shape[1] self.image_height = shape[1]
self.invert_colors = invert_colors self.invert_colors = invert_colors
@ -85,8 +91,10 @@ class Sketchpad(AbstractInput):
self.scale = scale self.scale = scale
self.shift = shift self.shift = shift
self.dtype = dtype self.dtype = dtype
self.sample_inputs = sample_inputs
super().__init__(preprocessing_fn=preprocessing_fn) super().__init__(preprocessing_fn=preprocessing_fn)
def get_name(self): def get_name(self):
return 'sketchpad' return 'sketchpad'
@ -121,6 +129,13 @@ class Sketchpad(AbstractInput):
im.save(f'{dir}/{filename}', 'PNG') im.save(f'{dir}/{filename}', 'PNG')
return filename return filename
def get_sample_inputs(self):
encoded_images = []
if self.sample_inputs is not None:
for input in self.sample_inputs:
encoded_images.append(preprocessing_utils.encode_array_to_base64(input))
return encoded_images
class Webcam(AbstractInput): class Webcam(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):

View File

@ -130,6 +130,22 @@ class Interface:
raise ValueError("model_type could not be inferred, please specify parameter `model_type`") raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
@staticmethod
def update_config_file(self, output_directory):
networking.set_interface_types_in_config_file(
output_directory,
self.input_interface.__class__.__name__.lower(),
self.output_interface.__class__.__name__.lower(),
)
if self.input_interface.__class__.__name__.lower() == "sketchpad" or self.input_interface.__class__.__name__.lower() == "textbox":
networking.set_sample_data_in_config_file(
output_directory,
self.input_interface.get_sample_inputs()
)
def predict(self, preprocessed_input): def predict(self, preprocessed_input):
""" """
Method that calls the relevant method of the model object to make a prediction. Method that calls the relevant method of the model object to make a prediction.
@ -235,11 +251,8 @@ class Interface:
output_directory, self.input_interface, self.output_interface output_directory, self.input_interface, self.output_interface
) )
networking.set_interface_types_in_config_file( self.update_config_file(self, output_directory)
output_directory,
self.input_interface.__class__.__name__.lower(),
self.output_interface.__class__.__name__.lower(),
)
self.status = self.STATUS_TYPES["RUNNING"] self.status = self.STATUS_TYPES["RUNNING"]
self.simple_server = httpd self.simple_server = httpd

View File

@ -117,6 +117,16 @@ def set_share_url_in_config_file(temp_dir, share_url):
) )
def set_sample_data_in_config_file(temp_dir, sample_inputs):
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(
config_file,
{
"sample_inputs": sample_inputs
},
)
def get_first_available_port(initial, final): def get_first_available_port(initial, final):
""" """
Gets the first open port in a specified range of port numbers Gets the first open port in a specified range of port numbers

View File

@ -5,6 +5,7 @@ import tempfile
import scipy.io.wavfile import scipy.io.wavfile
from scipy.fftpack import dct from scipy.fftpack import dct
import numpy as np import numpy as np
import skimage
######################### #########################
@ -16,6 +17,16 @@ def decode_base64_to_image(encoding):
return Image.open(BytesIO(base64.b64decode(image_encoded))) return Image.open(BytesIO(base64.b64decode(image_encoded)))
def encode_array_to_base64(image_array):
with BytesIO() as output_bytes:
PIL_image = Image.fromarray(skimage.img_as_ubyte(image_array))
PIL_image.save(output_bytes, 'PNG')
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
return base64_str
def resize_and_crop(img, size, crop_type='top'): def resize_and_crop(img, size, crop_type='top'):
""" """
Resize and crop an image to fit the specified size. Resize and crop an image to fit the specified size.
@ -24,8 +35,8 @@ def resize_and_crop(img, size, crop_type='top'):
modified_path: path to store the modified image. modified_path: path to store the modified image.
size: `(width, height)` tuple. size: `(width, height)` tuple.
crop_type: can be 'top', 'middle' or 'bottom', depending on this crop_type: can be 'top', 'middle' or 'bottom', depending on this
value, the image will cropped getting the 'top/left', 'midle' or value, the image will cropped getting the 'top/left', 'middle' or
'bottom/rigth' of the image to fit the size. 'bottom/right' of the image to fit the size.
raises: raises:
Exception: if can not open the file in img_path of there is problems Exception: if can not open the file in img_path of there is problems
to save the image. to save the image.

View File

@ -1,5 +1,6 @@
{ {
"input_interface_type": "{{input_interface_type}}", "input_interface_type": "{{input_interface_type}}",
"output_interface_type": "{{output_interface_type}}", "output_interface_type": "{{output_interface_type}}",
"share_url": "{{share_url}}" "share_url": "{{share_url}}",
"sample_inputs": "{{sample_inputs}}"
} }