gradio/test/test_inputs.py

60 lines
14 KiB
Python
Raw Normal View History

2019-02-28 08:43:20 +08:00
import unittest
import os
from gradio import inputs
# Image fixture (a base64 string, per its name) passed to preprocess() of the
# image-based inputs (Sketchpad, Webcam, ImageUpload) in the tests.
BASE64_IMG = "
# Text fixture for the Textbox round-trip test: preprocess() is expected to
# return it unchanged.
RAND_STRING = "2wBDAAYEBQYFBAYGBQYHBwYIC"
# Package directory name; template paths returned by get_template_path() are
# joined onto this to check that the template file exists.
PACKAGE_NAME = 'gradio'
class TestSketchpad(unittest.TestCase):
    """Unit tests for the Sketchpad input interface."""

    def test_path_exists(self):
        """get_template_path() can be called without raising.

        Only the call itself is exercised; the on-disk existence check is
        currently disabled for this input.
        """
        # TODO(review): the check
        #   os.path.exists(os.path.join(PACKAGE_NAME, <template path>))
        # is disabled for Sketchpad; re-enable once it is confirmed to pass.
        inputs.Sketchpad().get_template_path()

    def test_preprocessing(self):
        """Preprocessing the image fixture yields a (1, 28, 28, 1) array."""
        expected_shape = (1, 28, 28, 1)
        processed = inputs.Sketchpad().preprocess(BASE64_IMG)
        self.assertEqual(processed.shape, expected_shape)
class TestWebcam(unittest.TestCase):
    """Unit tests for the Webcam input interface."""

    def test_path_exists(self):
        """get_template_path() can be called without raising.

        Only the call itself is exercised; the on-disk existence check is
        currently disabled for this input.
        """
        # TODO(review): the check
        #   os.path.exists(os.path.join(PACKAGE_NAME, <template path>))
        # is disabled for Webcam; re-enable once it is confirmed to pass.
        inputs.Webcam().get_template_path()

    def test_preprocessing(self):
        """Preprocessing the image fixture yields a (1, 48, 48, 1) array."""
        expected_shape = (1, 48, 48, 1)
        processed = inputs.Webcam().preprocess(BASE64_IMG)
        self.assertEqual(processed.shape, expected_shape)
class TestTextbox(unittest.TestCase):
    """Unit tests for the Textbox input interface."""

    def test_path_exists(self):
        """get_template_path() can be called without raising.

        Only the call itself is exercised; the on-disk existence check is
        currently disabled for this input.
        """
        # TODO(review): the check
        #   os.path.exists(os.path.join(PACKAGE_NAME, <template path>))
        # is disabled for Textbox; re-enable once it is confirmed to pass.
        inputs.Textbox().get_template_path()

    def test_preprocessing(self):
        """Textbox preprocessing passes the input string through unchanged."""
        self.assertEqual(inputs.Textbox().preprocess(RAND_STRING), RAND_STRING)
class TestImageUpload(unittest.TestCase):
    """Unit tests for the ImageUpload input interface."""

    def test_path_exists(self):
        """The template path reported by ImageUpload exists inside the package.

        The path is resolved against the directory of the imported package
        rather than the current working directory. The test therefore passes
        regardless of where the test runner was launched from.
        """
        import importlib

        inp = inputs.ImageUpload()
        path = inp.get_template_path()
        # Joining onto the bare PACKAGE_NAME only works when cwd is the repo
        # root; anchor on the package's real location instead.
        package_dir = os.path.dirname(importlib.import_module(PACKAGE_NAME).__file__)
        full_path = os.path.join(package_dir, path)
        self.assertTrue(os.path.exists(full_path), msg=full_path)

    def test_preprocessing(self):
        """Preprocessing the image fixture yields a (1, 48, 48, 1) array."""
        inp = inputs.ImageUpload()
        array = inp.preprocess(BASE64_IMG)
        self.assertEqual(array.shape, (1, 48, 48, 1))
if __name__ == '__main__':
    # Only when this file is executed directly: run the TestCase classes
    # defined in this module through unittest's command-line runner.
    unittest.main()