From 215bd05c0ea4af2a8b8b94f04b391dc23aad85b5 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 5 Mar 2019 22:51:36 -0800 Subject: [PATCH] added ability to accept preprocessing arguments in some input intraces + tests --- gradio/inputs.py | 31 +++++++++++++++++++++++-------- test/test_inputs.py | 9 +++++++-- test/test_interface.py | 2 +- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index 67d159761d..634d86408c 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -44,6 +44,10 @@ class AbstractInput(ABC): class Sketchpad(AbstractInput): + def __init__(self, preprocessing_fn=None, image_width=28, image_height=28): + self.image_width = image_width + self.image_height = image_height + super().__init__(preprocessing_fn=preprocessing_fn) def get_template_path(self): return 'templates/sketchpad_input.html' @@ -55,12 +59,17 @@ class Sketchpad(AbstractInput): content = inp.split(';')[1] image_encoded = content.split(',')[1] im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') - im = preprocessing_utils.resize_and_crop(im, (28, 28)) - array = np.array(im).flatten().reshape(1, 28, 28, 1) + im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) + array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1) return array class Webcam(AbstractInput): + def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): + self.image_width = image_width + self.image_height = image_height + self.num_channels = num_channels + super().__init__(preprocessing_fn=preprocessing_fn) def get_template_path(self): return 'templates/webcam_input.html' @@ -71,9 +80,9 @@ class Webcam(AbstractInput): """ content = inp.split(';')[1] image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') - im = preprocessing_utils.resize_and_crop(im, (48, 48)) - array = np.array(im).flatten().reshape(1, 48, 48, 1) + im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB') + im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) + array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels) return array @@ -90,6 +99,12 @@ class Textbox(AbstractInput): class ImageUpload(AbstractInput): + def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): + self.image_width = image_width + self.image_height = image_height + self.num_channels = num_channels + super().__init__(preprocessing_fn=preprocessing_fn) + def get_template_path(self): return 'templates/image_upload_input.html' @@ -99,9 +114,9 @@ class ImageUpload(AbstractInput): """ content = inp.split(';')[1] image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') - im = preprocessing_utils.resize_and_crop(im, (48, 48)) - array = np.array(im).flatten().reshape(1, 48, 48, 1) + im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB') + im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) + array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels) return array diff --git a/test/test_inputs.py b/test/test_inputs.py index 39924f3b1d..11a4ec1d98 100644 --- a/test/test_inputs.py +++ b/test/test_inputs.py @@ -28,7 +28,7 @@ class TestWebcam(unittest.TestCase): def test_preprocessing(self): inp = inputs.Webcam() array = inp.preprocess(BASE64_IMG) - self.assertEqual(array.shape, (1, 48, 48, 1)) + self.assertEqual(array.shape, (1, 224, 224, 3)) class TestTextbox(unittest.TestCase): @@ -52,7 +52,12 @@ class TestImageUpload(unittest.TestCase): def test_preprocessing(self): inp = inputs.ImageUpload() array = inp.preprocess(BASE64_IMG) - self.assertEqual(array.shape, (1, 48, 48, 1)) + self.assertEqual(array.shape, (1, 224, 224, 3)) + + def test_preprocessing(self): + inp = inputs.ImageUpload(image_height=48, image_width=48) + array = inp.preprocess(BASE64_IMG) + self.assertEqual(array.shape, (1, 48, 48, 3)) if __name__ == '__main__': diff --git a/test/test_interface.py b/test/test_interface.py index 995ee237a8..1a00987e61 100644 --- a/test/test_interface.py +++ b/test/test_interface.py @@ -18,7 +18,7 @@ class TestInterface(unittest.TestCase): def test_output_interface_is_instance(self): out = gradio.outputs.Label(show_confidences=False) io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function') - self.assertEqual(io.input_interface, inp) + self.assertEqual(io.output_interface, out) if __name__ == '__main__':