added ability to accept preprocessing arguments in some input intraces + tests

This commit is contained in:
Abubakar Abid 2019-03-05 22:51:36 -08:00
parent 2a9814af4e
commit 215bd05c0e
3 changed files with 31 additions and 11 deletions

View File

@ -44,6 +44,10 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28):
self.image_width = image_width
self.image_height = image_height
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/sketchpad_input.html'
@ -55,12 +59,17 @@ class Sketchpad(AbstractInput):
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (28, 28))
array = np.array(im).flatten().reshape(1, 28, 28, 1)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1)
return array
class Webcam(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/webcam_input.html'
@ -71,9 +80,9 @@ class Webcam(AbstractInput):
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (48, 48))
array = np.array(im).flatten().reshape(1, 48, 48, 1)
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array
@ -90,6 +99,12 @@ class Textbox(AbstractInput):
class ImageUpload(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/image_upload_input.html'
@ -99,9 +114,9 @@ class ImageUpload(AbstractInput):
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (48, 48))
array = np.array(im).flatten().reshape(1, 48, 48, 1)
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array

View File

@ -28,7 +28,7 @@ class TestWebcam(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Webcam()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 1))
self.assertEqual(array.shape, (1, 224, 224, 3))
class TestTextbox(unittest.TestCase):
@ -52,7 +52,12 @@ class TestImageUpload(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.ImageUpload()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 1))
self.assertEqual(array.shape, (1, 224, 224, 3))
def test_preprocessing(self):
inp = inputs.ImageUpload(image_height=48, image_width=48)
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 3))
if __name__ == '__main__':

View File

@ -18,7 +18,7 @@ class TestInterface(unittest.TestCase):
def test_output_interface_is_instance(self):
out = gradio.outputs.Label(show_confidences=False)
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
self.assertEqual(io.input_interface, inp)
self.assertEqual(io.output_interface, out)
if __name__ == '__main__':