added tests for outputs

This commit is contained in:
Abubakar Abid 2019-02-27 16:54:08 -08:00
parent 29f8e2703f
commit 28b4d3cfc3
4 changed files with 66 additions and 16 deletions

View File

@ -16,7 +16,7 @@ class AbstractInput(ABC):
if preprocessing_fn is not None:
if not callable(preprocessing_fn):
raise ValueError('`preprocessing_fn` must be a callable function')
self._pre_process = preprocessing_fn
self._preprocess = preprocessing_fn
super().__init__()
@abstractmethod
@ -27,9 +27,9 @@ class AbstractInput(ABC):
pass
@abstractmethod
def _pre_process(self, inp):
def _preprocess(self, inp):
"""
All interfaces should define a method that returns the path to its template.
All interfaces should define a default preprocessing method
"""
pass
@ -39,7 +39,7 @@ class Sketchpad(AbstractInput):
def _get_template_path(self):
return 'templates/sketchpad_input.html'
def _pre_process(self, inp):
def _preprocess(self, inp):
"""
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
"""
@ -56,7 +56,7 @@ class Webcam(AbstractInput):
def _get_template_path(self):
return 'templates/webcam_input.html'
def _pre_process(self, inp):
def _preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
@ -73,7 +73,7 @@ class Textbox(AbstractInput):
def _get_template_path(self):
return 'templates/textbox_input.html'
def _pre_process(self, inp):
def _preprocess(self, inp):
"""
By default, no pre-processing is applied to text.
"""
@ -85,7 +85,7 @@ class ImageUpload(AbstractInput):
def _get_template_path(self):
return 'templates/image_upload_input.html'
def _pre_process(self, inp):
def _preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""

View File

@ -12,7 +12,7 @@ class AbstractOutput(ABC):
"""
"""
if postprocessing_fn is not None:
self._post_process = postprocessing_fn
self._postprocess = postprocessing_fn
super().__init__()
@abstractmethod
@ -23,9 +23,9 @@ class AbstractOutput(ABC):
pass
@abstractmethod
def _post_process(self):
def _postprocess(self, prediction):
"""
All interfaces should define a method that returns the path to its template.
All interfaces should define a default postprocessing method
"""
pass
@ -35,7 +35,7 @@ class Label(AbstractOutput):
def _get_template_path(self):
return 'templates/label_output.html'
def _post_process(self, prediction):
def _postprocess(self, prediction):
"""
"""
if isinstance(prediction, np.ndarray):
@ -55,7 +55,7 @@ class Textbox(AbstractOutput):
def _get_template_path(self):
return 'templates/textbox_output.html'
def _post_process(self, prediction):
def _postprocess(self, prediction):
"""
"""
return prediction

View File

@ -15,7 +15,7 @@ class TestSketchpad(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Sketchpad()
array = inp._pre_process(BASE64_IMG)
array = inp._preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 28, 28, 1))
@ -27,7 +27,7 @@ class TestWebcam(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Webcam()
array = inp._pre_process(BASE64_IMG)
array = inp._preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 1))
@ -39,7 +39,7 @@ class TestTextbox(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Textbox()
string = inp._pre_process(RAND_STRING)
string = inp._preprocess(RAND_STRING)
self.assertEqual(string, RAND_STRING)
@ -51,7 +51,7 @@ class TestImageUpload(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.ImageUpload()
array = inp._pre_process(BASE64_IMG)
array = inp._preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 1))

View File

@ -0,0 +1,50 @@
import numpy as np
import unittest
import os
from gradio import outputs
PACKAGE_NAME = 'gradio'
class TestLabel(unittest.TestCase):
def test_path_exists(self):
out = outputs.Label()
path = out._get_template_path()
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_postprocessing_string(self):
string = 'happy'
out = outputs.Label()
label = out._postprocess(string)
self.assertEqual(label, string)
def test_postprocessing_one_hot(self):
one_hot = np.array([0, 0, 0, 1, 0])
true_label = 3
out = outputs.Label()
label = out._postprocess(one_hot)
self.assertEqual(label, true_label)
def test_postprocessing_int(self):
true_label_array = np.array([[[3]]])
true_label = 3
out = outputs.Label()
label = out._postprocess(true_label_array)
self.assertEqual(label, true_label)
class TestTextbox(unittest.TestCase):
def test_path_exists(self):
out = outputs.Textbox()
path = out._get_template_path()
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_postprocessing(self):
string = 'happy'
out = outputs.Textbox()
string = out._postprocess(string)
self.assertEqual(string, string)
if __name__ == '__main__':
unittest.main()