gradio/test/test_inputs.py
2020-08-05 10:42:52 -07:00

70 lines
24 KiB
Python

import unittest
import os
from gradio import inputs
# Fixture handed to Webcam.preprocess and Image.preprocess in the tests below;
# by its name, a base64-encoded image.
# NOTE(review): the value is empty here -- presumably the encoded payload was
# stripped from this copy of the file; confirm before relying on these tests.
BASE64_IMG = ""
# Fixture handed to Sketchpad.preprocess in the tests below; by its name, a
# base64-encoded sketch.
# NOTE(review): also empty here -- presumably stripped as well; confirm.
BASE64_SKETCH = ""
# Text fixture that the Textbox test expects preprocess to return unchanged.
RAND_STRING = "2wBDAAYEBQYFBAYGBQYHBwYIC"
# Directory component joined in front of each resource path in the
# test_path_exists checks (the join is relative, so it resolves against the
# current working directory).
PACKAGE_NAME = 'gradio'
# Where to find the static resources associated with each template.
# The `{}` placeholder is filled with the lowercased class name of the input.
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
class TestSketchpad(unittest.TestCase):
    """Tests for the ``inputs.Sketchpad`` input interface."""

    def test_path_exists(self):
        """The JS resource named after the lowercased class name must exist."""
        sketchpad = inputs.Sketchpad()
        component_name = sketchpad.__class__.__name__.lower()
        js_resource = BASE_INPUT_INTERFACE_JS_PATH.format(component_name)
        # Relative join: the lookup resolves against the working directory.
        resource_location = os.path.join(PACKAGE_NAME, js_resource)
        self.assertTrue(os.path.exists(resource_location))

    def test_preprocessing(self):
        """Preprocessing the sketch fixture must yield shape (1, 28, 28)."""
        sketchpad = inputs.Sketchpad()
        processed = sketchpad.preprocess(BASE64_SKETCH)
        expected_shape = (1, 28, 28)
        self.assertEqual(processed.shape, expected_shape)
class TestWebcam(unittest.TestCase):
    """Tests for the ``inputs.Webcam`` input interface."""

    def test_path_exists(self):
        """The JS resource named after the lowercased class name must exist."""
        webcam = inputs.Webcam()
        component_name = webcam.__class__.__name__.lower()
        js_resource = BASE_INPUT_INTERFACE_JS_PATH.format(component_name)
        # Relative join: the lookup resolves against the working directory.
        resource_location = os.path.join(PACKAGE_NAME, js_resource)
        self.assertTrue(os.path.exists(resource_location))

    def test_preprocessing(self):
        """Preprocessing the image fixture must yield shape (224, 224, 3)."""
        webcam = inputs.Webcam()
        processed = webcam.preprocess(BASE64_IMG)
        expected_shape = (224, 224, 3)
        self.assertEqual(processed.shape, expected_shape)
class TestTextbox(unittest.TestCase):
    """Tests for the ``inputs.Textbox`` input interface."""

    def test_path_exists(self):
        """The JS resource named after the lowercased class name must exist."""
        textbox = inputs.Textbox()
        component_name = textbox.__class__.__name__.lower()
        js_resource = BASE_INPUT_INTERFACE_JS_PATH.format(component_name)
        # Relative join: the lookup resolves against the working directory.
        resource_location = os.path.join(PACKAGE_NAME, js_resource)
        self.assertTrue(os.path.exists(resource_location))

    def test_preprocessing(self):
        """Preprocessing must hand the input string back unchanged."""
        textbox = inputs.Textbox()
        processed = textbox.preprocess(RAND_STRING)
        self.assertEqual(processed, RAND_STRING)
class TestImageUpload(unittest.TestCase):
    """Tests for the ``inputs.Image`` input interface."""

    def test_path_exists(self):
        """The JS resource named after the lowercased class name must exist."""
        inp = inputs.Image()
        path = BASE_INPUT_INTERFACE_JS_PATH.format(
            inp.__class__.__name__.lower())
        # Relative join: the lookup resolves against the working directory.
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_preprocessing(self):
        """Preprocessing with a default Image must yield shape (224, 224, 3)."""
        inp = inputs.Image()
        array = inp.preprocess(BASE64_IMG)
        self.assertEqual(array.shape, (224, 224, 3))

    # This method used to be a second ``test_preprocessing``. Because a class
    # body keeps only the last binding of a name, it replaced the default-size
    # test above, which therefore never ran. It now has its own name so both
    # checks are collected.
    def test_preprocessing_custom_size(self):
        """With image_height/image_width set to 48, shape must be (48, 48, 3)."""
        inp = inputs.Image()
        inp.image_height = 48
        inp.image_width = 48
        array = inp.preprocess(BASE64_IMG)
        self.assertEqual(array.shape, (48, 48, 3))
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()