mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
added tests for outputs
This commit is contained in:
parent
29f8e2703f
commit
28b4d3cfc3
@ -16,7 +16,7 @@ class AbstractInput(ABC):
|
||||
if preprocessing_fn is not None:
|
||||
if not callable(preprocessing_fn):
|
||||
raise ValueError('`preprocessing_fn` must be a callable function')
|
||||
self._pre_process = preprocessing_fn
|
||||
self._preprocess = preprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
@ -27,9 +27,9 @@ class AbstractInput(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _pre_process(self, inp):
|
||||
def _preprocess(self, inp):
|
||||
"""
|
||||
All interfaces should define a method that returns the path to its template.
|
||||
All interfaces should define a default preprocessing method
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -39,7 +39,7 @@ class Sketchpad(AbstractInput):
|
||||
def _get_template_path(self):
|
||||
return 'templates/sketchpad_input.html'
|
||||
|
||||
def _pre_process(self, inp):
|
||||
def _preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
|
||||
"""
|
||||
@ -56,7 +56,7 @@ class Webcam(AbstractInput):
|
||||
def _get_template_path(self):
|
||||
return 'templates/webcam_input.html'
|
||||
|
||||
def _pre_process(self, inp):
|
||||
def _preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
@ -73,7 +73,7 @@ class Textbox(AbstractInput):
|
||||
def _get_template_path(self):
|
||||
return 'templates/textbox_input.html'
|
||||
|
||||
def _pre_process(self, inp):
|
||||
def _preprocess(self, inp):
|
||||
"""
|
||||
By default, no pre-processing is applied to text.
|
||||
"""
|
||||
@ -85,7 +85,7 @@ class ImageUpload(AbstractInput):
|
||||
def _get_template_path(self):
|
||||
return 'templates/image_upload_input.html'
|
||||
|
||||
def _pre_process(self, inp):
|
||||
def _preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
|
@ -12,7 +12,7 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
"""
|
||||
if postprocessing_fn is not None:
|
||||
self._post_process = postprocessing_fn
|
||||
self._postprocess = postprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
@ -23,9 +23,9 @@ class AbstractOutput(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _post_process(self):
|
||||
def _postprocess(self, prediction):
|
||||
"""
|
||||
All interfaces should define a method that returns the path to its template.
|
||||
All interfaces should define a default postprocessing method
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -35,7 +35,7 @@ class Label(AbstractOutput):
|
||||
def _get_template_path(self):
|
||||
return 'templates/label_output.html'
|
||||
|
||||
def _post_process(self, prediction):
|
||||
def _postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
if isinstance(prediction, np.ndarray):
|
||||
@ -55,7 +55,7 @@ class Textbox(AbstractOutput):
|
||||
def _get_template_path(self):
|
||||
return 'templates/textbox_output.html'
|
||||
|
||||
def _post_process(self, prediction):
|
||||
def _postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
return prediction
|
||||
|
@ -15,7 +15,7 @@ class TestSketchpad(unittest.TestCase):
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Sketchpad()
|
||||
array = inp._pre_process(BASE64_IMG)
|
||||
array = inp._preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 28, 28, 1))
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ class TestWebcam(unittest.TestCase):
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Webcam()
|
||||
array = inp._pre_process(BASE64_IMG)
|
||||
array = inp._preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class TestTextbox(unittest.TestCase):
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Textbox()
|
||||
string = inp._pre_process(RAND_STRING)
|
||||
string = inp._preprocess(RAND_STRING)
|
||||
self.assertEqual(string, RAND_STRING)
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ class TestImageUpload(unittest.TestCase):
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.ImageUpload()
|
||||
array = inp._pre_process(BASE64_IMG)
|
||||
array = inp._preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
|
||||
|
||||
|
@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
import os
|
||||
from gradio import outputs
|
||||
|
||||
PACKAGE_NAME = 'gradio'
|
||||
|
||||
|
||||
class TestLabel(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Label()
|
||||
path = out._get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing_string(self):
|
||||
string = 'happy'
|
||||
out = outputs.Label()
|
||||
label = out._postprocess(string)
|
||||
self.assertEqual(label, string)
|
||||
|
||||
def test_postprocessing_one_hot(self):
|
||||
one_hot = np.array([0, 0, 0, 1, 0])
|
||||
true_label = 3
|
||||
out = outputs.Label()
|
||||
label = out._postprocess(one_hot)
|
||||
self.assertEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_int(self):
|
||||
true_label_array = np.array([[[3]]])
|
||||
true_label = 3
|
||||
out = outputs.Label()
|
||||
label = out._postprocess(true_label_array)
|
||||
self.assertEqual(label, true_label)
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Textbox()
|
||||
path = out._get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
string = 'happy'
|
||||
out = outputs.Textbox()
|
||||
string = out._postprocess(string)
|
||||
self.assertEqual(string, string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user