2019-02-28 08:43:20 +08:00
|
|
|
import unittest
|
|
|
|
import os
|
|
|
|
from gradio import inputs
|
|
|
|
|
|
|
|
BASE64_IMG = "
|
2019-07-22 09:46:00 +08:00
|
|
|
BASE64_SKETCH = "
|
2019-02-28 08:43:20 +08:00
|
|
|
# Arbitrary text fixture; TestTextbox expects Textbox.preprocess to return it unchanged.
RAND_STRING = "2wBDAAYEBQYFBAYGBQYHBwYIC"

# Package directory joined onto interface JS paths in the path-existence tests.
# NOTE(review): resolved relative to the CWD, so tests presumably run from the repo root.
PACKAGE_NAME = "gradio"
|
|
|
|
|
|
|
|
|
|
|
|
class TestSketchpad(unittest.TestCase):
    """Checks for the ``Sketchpad`` input interface."""

    def test_path_exists(self):
        """The Sketchpad interface's JS file exists under the package directory."""
        sketchpad = inputs.Sketchpad()
        js_file = os.path.join(
            PACKAGE_NAME,
            inputs.BASE_INPUT_INTERFACE_JS_PATH.format(sketchpad.get_name()),
        )
        self.assertTrue(os.path.exists(js_file))

    def test_preprocessing(self):
        """Preprocessing the sample sketch yields an array of shape (1, 28, 28)."""
        processed = inputs.Sketchpad().preprocess(BASE64_SKETCH)
        self.assertEqual(processed.shape, (1, 28, 28))
|
2019-02-28 08:43:20 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TestWebcam(unittest.TestCase):
    """Checks for the ``Webcam`` input interface."""

    def test_path_exists(self):
        """The Webcam interface has no JS file yet, so the path must be absent."""
        webcam = inputs.Webcam()
        js_file = os.path.join(
            PACKAGE_NAME,
            inputs.BASE_INPUT_INTERFACE_JS_PATH.format(webcam.get_name()),
        )
        # Not implemented yet: flip to assertTrue once the Webcam JS ships.
        self.assertFalse(os.path.exists(js_file))

    def test_preprocessing(self):
        """Preprocessing the sample image yields an array of shape (1, 224, 224, 3)."""
        processed = inputs.Webcam().preprocess(BASE64_IMG)
        self.assertEqual(processed.shape, (1, 224, 224, 3))
|
2019-02-28 08:43:20 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TestTextbox(unittest.TestCase):
    """Checks for the ``Textbox`` input interface."""

    def test_path_exists(self):
        """The Textbox interface's JS file exists under the package directory."""
        textbox = inputs.Textbox()
        js_file = os.path.join(
            PACKAGE_NAME,
            inputs.BASE_INPUT_INTERFACE_JS_PATH.format(textbox.get_name()),
        )
        self.assertTrue(os.path.exists(js_file))

    def test_preprocessing(self):
        """Textbox preprocessing passes the input text through unchanged."""
        result = inputs.Textbox().preprocess(RAND_STRING)
        self.assertEqual(result, RAND_STRING)
|
|
|
|
|
|
|
|
|
|
|
|
class TestImageUpload(unittest.TestCase):
    """Tests for the ``ImageUpload`` input interface."""

    def test_path_exists(self):
        """The ImageUpload interface's JS file exists under the package directory."""
        inp = inputs.ImageUpload()
        path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.get_name())
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_preprocessing(self):
        """With default settings, preprocess returns shape (1, 224, 224, 3)."""
        inp = inputs.ImageUpload()
        array = inp.preprocess(BASE64_IMG)
        self.assertEqual(array.shape, (1, 224, 224, 3))

    def test_preprocessing_custom_size(self):
        """Overriding image_height/image_width changes the preprocessed shape."""
        # Distinct name on purpose: a second ``test_preprocessing`` def would
        # rebind the name in the class body and the default-size test above
        # would never be collected or run.
        inp = inputs.ImageUpload()
        inp.image_height = 48
        inp.image_width = 48
        array = inp.preprocess(BASE64_IMG)
        self.assertEqual(array.shape, (1, 48, 48, 3))
|
2019-02-28 08:43:20 +08:00
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Only when run as a script (not when imported): run this module's tests.
    unittest.main()
|