This commit is contained in:
Ali Abid 2020-07-07 11:02:14 -07:00
parent a03f9a4706
commit 4debb0f829

View File

@ -67,7 +67,7 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
def __init__(self, shape=(28, 28), invert_colors=True,
def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
flatten=False, scale=1/255, shift=0,
dtype='float64', sample_inputs=None, label=None):
self.image_width = shape[0]
@ -110,10 +110,10 @@ class Sketchpad(AbstractInput):
class Webcam(AbstractInput):
def __init__(self, image_width=224, image_height=224, num_channels=3, label=None):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
def __init__(self, shape=(224, 224), label=None):
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = 3
super().__init__(label)
def get_validation_inputs(self):
@ -132,8 +132,7 @@ class Webcam(AbstractInput):
im = preprocessing_utils.decode_base64_to_image(inp)
im = im.convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels)
return array
return np.array(im)
class Textbox(AbstractInput):
@ -244,9 +243,11 @@ class Checkbox(AbstractInput):
class Image(AbstractInput):
def __init__(self, cast_to=None, shape=(224, 224), label=None):
def __init__(self, cast_to=None, shape=(224, 224), image_mode='RGB', label=None):
self.cast_to = cast_to
self.image_width = shape[0]
self.image_height = shape[1]
self.image_mode = image_mode
super().__init__(label)
def get_validation_inputs(self):
@ -264,24 +265,36 @@ class Image(AbstractInput):
**super().get_template_context()
}
def cast_to_base64(self, inp):
return inp
def cast_to_im(self, inp):
return preprocessing_utils.decode_base64_to_image(inp)
def cast_to_numpy(self, inp):
im = self.cast_to_im(inp)
arr = np.array(im).flatten()
return arr
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
cast_to_type = {
"base64": self.cast_to_base64,
"numpy": self.cast_to_numpy,
"pillow": self.cast_to_im
}
if self.cast_to:
return cast_to_type[self.cast_to](inp)
im = preprocessing_utils.decode_base64_to_image(inp)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
im = np.array(im).flatten()
im = im * self.scale + self.shift
if self.num_channels is None:
array = im.reshape(self.image_width, self.image_height)
else:
array = im.reshape(self.image_width, self.image_height, \
self.num_channels)
return array
return np.array(im)
def process_example(self, example):
if os.path.exists(example):
@ -305,4 +318,4 @@ class Microphone(AbstractInput):
shortcuts = {}
for cls in AbstractInput.__subclasses__():
for shortcut, parameters in cls.get_shortcut_implementations().items():
shortcuts[shortcut] = cls(**parameters)
shortcuts[shortcut] = cls(**parameters)