mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
added ability to accept preprocessing arguments in some input intraces + tests
This commit is contained in:
parent
2a9814af4e
commit
215bd05c0e
@ -44,6 +44,10 @@ class AbstractInput(ABC):
|
||||
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/sketchpad_input.html'
|
||||
@ -55,12 +59,17 @@ class Sketchpad(AbstractInput):
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
im = preprocessing_utils.resize_and_crop(im, (28, 28))
|
||||
array = np.array(im).flatten().reshape(1, 28, 28, 1)
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1)
|
||||
return array
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/webcam_input.html'
|
||||
@ -71,9 +80,9 @@ class Webcam(AbstractInput):
|
||||
"""
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
im = preprocessing_utils.resize_and_crop(im, (48, 48))
|
||||
array = np.array(im).flatten().reshape(1, 48, 48, 1)
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
@ -90,6 +99,12 @@ class Textbox(AbstractInput):
|
||||
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/image_upload_input.html'
|
||||
|
||||
@ -99,9 +114,9 @@ class ImageUpload(AbstractInput):
|
||||
"""
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
im = preprocessing_utils.resize_and_crop(im, (48, 48))
|
||||
array = np.array(im).flatten().reshape(1, 48, 48, 1)
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ class TestWebcam(unittest.TestCase):
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Webcam()
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
self.assertEqual(array.shape, (1, 224, 224, 3))
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
@ -52,7 +52,12 @@ class TestImageUpload(unittest.TestCase):
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.ImageUpload()
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
self.assertEqual(array.shape, (1, 224, 224, 3))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.ImageUpload(image_height=48, image_width=48)
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -18,7 +18,7 @@ class TestInterface(unittest.TestCase):
|
||||
def test_output_interface_is_instance(self):
|
||||
out = gradio.outputs.Label(show_confidences=False)
|
||||
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
|
||||
self.assertEqual(io.input_interface, inp)
|
||||
self.assertEqual(io.output_interface, out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user