diff --git a/gradio/inputs.py b/gradio/inputs.py index 585c6dc979..2f4afb59ef 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -5,9 +5,7 @@ automatically added to a registry, which allows them to be easily referenced in """ from abc import ABC, abstractmethod -import base64 from gradio import preprocessing_utils, validation_data -from io import BytesIO import numpy as np from PIL import Image, ImageOps @@ -67,10 +65,9 @@ class AbstractInput(ABC): class Sketchpad(AbstractInput): - def __init__(self, preprocessing_fn=None, image_width=28, image_height=28, - invert_colors=True): - self.image_width = image_width - self.image_height = image_height + def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True): + self.image_width = shape[0] + self.image_height = shape[1] self.invert_colors = invert_colors super().__init__(preprocessing_fn=preprocessing_fn) @@ -81,9 +78,8 @@ class Sketchpad(AbstractInput): """ Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28 """ - content = inp.split(';')[1] - image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') + im = preprocessing_utils.encoding_to_image(inp) + im = im.convert('L') if self.invert_colors: im = ImageOps.invert(im) im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) @@ -108,9 +104,8 @@ class Webcam(AbstractInput): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ - content = inp.split(';')[1] - image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB') + im = preprocessing_utils.encoding_to_image(inp) + im = im.convert('RGB') im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels) return array @@ -131,15 +126,15 @@ class Textbox(AbstractInput): class ImageUpload(AbstractInput): - def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB', - scale=1/127.5, shift=-1, aspect_ratio="false"): - self.image_width = image_width - self.image_height = image_height - self.num_channels = num_channels + def __init__(self, preprocessing_fn=None, shape=(224, 224, 3), image_mode='RGB', + scale=1/127.5, shift=-1, cropper_aspect_ratio=None): + self.image_width = shape[0] + self.image_height = shape[1] + self.num_channels = shape[2] self.image_mode = image_mode self.scale = scale self.shift = shift - self.aspect_ratio = aspect_ratio + self.cropper_aspect_ratio = "false" if cropper_aspect_ratio is None else cropper_aspect_ratio super().__init__(preprocessing_fn=preprocessing_fn) def get_validation_inputs(self): @@ -149,15 +144,14 @@ class ImageUpload(AbstractInput): return 'image_upload' def get_js_context(self): - return {'aspect_ratio': self.aspect_ratio} + return {'aspect_ratio': self.cropper_aspect_ratio} def preprocess(self, inp): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ - content = inp.split(';')[1] - image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode) + im = preprocessing_utils.encoding_to_image(inp) + im = im.convert(self.image_mode) im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) im = np.array(im).flatten() im = im * self.scale + self.shift diff --git a/gradio/interface.py b/gradio/interface.py index 39d124884d..b3f9585c0e 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -28,7 +28,7 @@ class Interface: """ # Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description. - VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function', + VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'pyfunc': 'python function', 'pytorch': 'PyTorch model'} STATUS_TYPES = {'OFF': 'off', 'RUNNING': 'running'} @@ -94,7 +94,7 @@ class Interface: pass if callable(model): - return 'function' + return 'pyfunc' raise ValueError("model_type could not be inferred, please specify parameter `model_type`") @@ -127,7 +127,7 @@ class Interface: return self.model_obj.predict(preprocessed_input) elif self.model_type=='keras': return self.model_obj.predict(preprocessed_input) - elif self.model_type=='function': + elif self.model_type=='pyfunc': return self.model_obj(preprocessed_input) elif self.model_type=='pytorch': import torch diff --git a/gradio/preprocessing_utils.py b/gradio/preprocessing_utils.py index c933892c8b..3fc2f4e1b6 100644 --- a/gradio/preprocessing_utils.py +++ b/gradio/preprocessing_utils.py @@ -1,6 +1,13 @@ from PIL import Image +from io import BytesIO +import base64 +def encoding_to_image(encoding): + content = encoding.split(';')[1] + image_encoded = content.split(',')[1] + return Image.open(BytesIO(base64.b64decode(image_encoded))) + def resize_and_crop(img, size, crop_type='top'): """ Resize and crop an image to fit the specified size. diff --git a/web/getting_started.html b/web/getting_started.html index be463474ef..f33159335c 100644 --- a/web/getting_started.html +++ b/web/getting_started.html @@ -7,6 +7,10 @@ + + +