diff --git a/gradio/inputs.py b/gradio/inputs.py index 6a9a6c766e..5df840e527 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -16,7 +16,7 @@ class AbstractInput(ABC): if preprocessing_fn is not None: if not callable(preprocessing_fn): raise ValueError('`preprocessing_fn` must be a callable function') - self._pre_process = preprocessing_fn + self._preprocess = preprocessing_fn super().__init__() @abstractmethod @@ -27,9 +27,9 @@ class AbstractInput(ABC): pass @abstractmethod - def _pre_process(self, inp): + def _preprocess(self, inp): """ - All interfaces should define a method that returns the path to its template. + All interfaces should define a default preprocessing method """ pass @@ -39,7 +39,7 @@ class Sketchpad(AbstractInput): def _get_template_path(self): return 'templates/sketchpad_input.html' - def _pre_process(self, inp): + def _preprocess(self, inp): """ Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28 """ @@ -56,7 +56,7 @@ class Webcam(AbstractInput): def _get_template_path(self): return 'templates/webcam_input.html' - def _pre_process(self, inp): + def _preprocess(self, inp): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ @@ -73,7 +73,7 @@ class Textbox(AbstractInput): def _get_template_path(self): return 'templates/textbox_input.html' - def _pre_process(self, inp): + def _preprocess(self, inp): """ By default, no pre-processing is applied to text. """ @@ -85,7 +85,7 @@ class ImageUpload(AbstractInput): def _get_template_path(self): return 'templates/image_upload_input.html' - def _pre_process(self, inp): + def _preprocess(self, inp): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ diff --git a/gradio/outputs.py b/gradio/outputs.py index 6d239b3d56..37819a1537 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -12,7 +12,7 @@ class AbstractOutput(ABC): """ """ if postprocessing_fn is not None: - self._post_process = postprocessing_fn + self._postprocess = postprocessing_fn super().__init__() @abstractmethod @@ -23,9 +23,9 @@ class AbstractOutput(ABC): pass @abstractmethod - def _post_process(self): + def _postprocess(self, prediction): """ - All interfaces should define a method that returns the path to its template. + All interfaces should define a default postprocessing method """ pass @@ -35,7 +35,7 @@ class Label(AbstractOutput): def _get_template_path(self): return 'templates/label_output.html' - def _post_process(self, prediction): + def _postprocess(self, prediction): """ """ if isinstance(prediction, np.ndarray): @@ -55,7 +55,7 @@ class Textbox(AbstractOutput): def _get_template_path(self): return 'templates/textbox_output.html' - def _post_process(self, prediction): + def _postprocess(self, prediction): """ """ return prediction diff --git a/test/test_inputs.py b/test/test_inputs.py index e9f4cf1291..e8ef2a6fe4 100644 --- a/test/test_inputs.py +++ b/test/test_inputs.py @@ -15,7 +15,7 @@ class TestSketchpad(unittest.TestCase): def test_preprocessing(self): inp = inputs.Sketchpad() - array = inp._pre_process(BASE64_IMG) + array = inp._preprocess(BASE64_IMG) self.assertEqual(array.shape, (1, 28, 28, 1)) @@ -27,7 +27,7 @@ class TestWebcam(unittest.TestCase): def test_preprocessing(self): inp = inputs.Webcam() - array = inp._pre_process(BASE64_IMG) + array = inp._preprocess(BASE64_IMG) self.assertEqual(array.shape, (1, 48, 48, 1)) @@ -39,7 +39,7 @@ class TestTextbox(unittest.TestCase): def test_preprocessing(self): inp = inputs.Textbox() - string = inp._pre_process(RAND_STRING) + string = inp._preprocess(RAND_STRING) self.assertEqual(string, RAND_STRING) @@ -51,7 +51,7 @@ class TestImageUpload(unittest.TestCase): def test_preprocessing(self): inp = inputs.ImageUpload() - array = inp._pre_process(BASE64_IMG) + array = inp._preprocess(BASE64_IMG) self.assertEqual(array.shape, (1, 48, 48, 1)) diff --git a/test/test_outputs.py b/test/test_outputs.py index e69de29bb2..1a97cd237b 100644 --- a/test/test_outputs.py +++ b/test/test_outputs.py @@ -0,0 +1,50 @@ +import numpy as np +import unittest +import os +from gradio import outputs + +PACKAGE_NAME = 'gradio' + + +class TestLabel(unittest.TestCase): + def test_path_exists(self): + out = outputs.Label() + path = out._get_template_path() + self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) + + def test_postprocessing_string(self): + string = 'happy' + out = outputs.Label() + label = out._postprocess(string) + self.assertEqual(label, string) + + def test_postprocessing_one_hot(self): + one_hot = np.array([0, 0, 0, 1, 0]) + true_label = 3 + out = outputs.Label() + label = out._postprocess(one_hot) + self.assertEqual(label, true_label) + + def test_postprocessing_int(self): + true_label_array = np.array([[[3]]]) + true_label = 3 + out = outputs.Label() + label = out._postprocess(true_label_array) + self.assertEqual(label, true_label) + + +class TestTextbox(unittest.TestCase): + def test_path_exists(self): + out = outputs.Textbox() + path = out._get_template_path() + self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) + + def test_postprocessing(self): + string = 'happy' + out = outputs.Textbox() + string = out._postprocess(string) + self.assertEqual(string, string) + + +if __name__ == '__main__': + unittest.main()