This commit is contained in:
Ali Abid 2020-08-10 13:52:43 -07:00
parent c90fc29ebd
commit b8707fb628
3 changed files with 11 additions and 64 deletions

View File

@ -11,60 +11,22 @@ PACKAGE_NAME = 'gradio'
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
class TestSketchpad(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Sketchpad()
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
class TestImage(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Sketchpad()
inp = inputs.Image(shape=(20, 20))
array = inp.preprocess(BASE64_SKETCH)
self.assertEqual(array.shape, (1, 28, 28))
class TestWebcam(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Webcam()
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Webcam()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (224, 224, 3))
self.assertEqual(array.shape, (20, 20, 3))
inp2 = inputs.Image(shape=(20, 20), image_mode="L")
array2 = inp2.preprocess(BASE64_SKETCH)
self.assertEqual(array2.shape, (20, 20))
class TestTextbox(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Textbox()
path = BASE_INPUT_INTERFACE_JS_PATH.format(
inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Textbox()
string = inp.preprocess(RAND_STRING)
self.assertEqual(string, RAND_STRING)
class TestImageUpload(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Image()
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Image()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (224, 224, 3))
def test_preprocessing(self):
inp = inputs.Image()
inp.image_height = 48
inp.image_width = 48
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (48, 48, 3))
if __name__ == '__main__':
unittest.main()

View File

@ -7,24 +7,24 @@ import gradio.outputs
class TestInterface(unittest.TestCase):
def test_input_output_mapping(self):
io = gr.Interface(inputs='SketCHPad', outputs='TexT', fn=lambda x: x)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
io = gr.Interface(inputs='sketchpad', outputs='text', fn=lambda x: x)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Image)
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
def test_input_interface_is_instance(self):
inp = gradio.inputs.Image()
io = gr.Interface(inputs=inp, outputs='teXT', fn=lambda x: x)
io = gr.Interface(inputs=inp, outputs='text', fn=lambda x: x)
self.assertEqual(io.input_interfaces[0], inp)
def test_output_interface_is_instance(self):
out = gradio.outputs.Label()
io = gr.Interface(inputs='SketCHPad', outputs=out, fn=lambda x: x)
io = gr.Interface(inputs='sketchpad', outputs=out, fn=lambda x: x)
self.assertEqual(io.output_interfaces[0], out)
def test_prediction(self):
def model(x):
return 2*x
io = gr.Interface(inputs='textbox', outputs='TEXT', fn=model)
io = gr.Interface(inputs='textbox', outputs='text', fn=model)
self.assertEqual(io.predict[0](11), 22)

View File

@ -11,11 +11,6 @@ BASE64_IMG = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAYEBQYFBAY
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
class TestLabel(unittest.TestCase):
def test_path_exists(self):
out = outputs.Label()
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_postprocessing_string(self):
string = 'happy'
out = outputs.Label()
@ -52,11 +47,6 @@ class TestLabel(unittest.TestCase):
class TestTextbox(unittest.TestCase):
def test_path_exists(self):
out = outputs.Textbox()
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_postprocessing(self):
string = 'happy'
out = outputs.Textbox()
@ -65,11 +55,6 @@ class TestTextbox(unittest.TestCase):
class TestImage(unittest.TestCase):
def test_path_exists(self):
out = outputs.Image()
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__qualname__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_postprocessing(self):
string = BASE64_IMG
out = outputs.Textbox()