From ee26d3acc8b1877036e1ac6fc7d8c24b5e9ab4df Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 5 Mar 2019 22:55:40 -0800 Subject: [PATCH] added image mode to image upload input interface --- gradio/inputs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index 634d86408c..fb2fd8de7e 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -99,10 +99,11 @@ class Textbox(AbstractInput): class ImageUpload(AbstractInput): - def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): + def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB'): self.image_width = image_width self.image_height = image_height self.num_channels = num_channels + self.image_mode = image_mode super().__init__(preprocessing_fn=preprocessing_fn) def get_template_path(self): @@ -114,7 +115,7 @@ class ImageUpload(AbstractInput): """ content = inp.split(';')[1] image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB') + im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode) im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels) return array