mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
tests
This commit is contained in:
parent
c90fc29ebd
commit
b8707fb628
@ -11,60 +11,22 @@ PACKAGE_NAME = 'gradio'
|
||||
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
|
||||
|
||||
|
||||
class TestSketchpad(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Sketchpad()
|
||||
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Sketchpad()
|
||||
inp = inputs.Image(shape=(20, 20))
|
||||
array = inp.preprocess(BASE64_SKETCH)
|
||||
self.assertEqual(array.shape, (1, 28, 28))
|
||||
|
||||
|
||||
class TestWebcam(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Webcam()
|
||||
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Webcam()
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (224, 224, 3))
|
||||
self.assertEqual(array.shape, (20, 20, 3))
|
||||
inp2 = inputs.Image(shape=(20, 20), image_mode="L")
|
||||
array2 = inp2.preprocess(BASE64_SKETCH)
|
||||
self.assertEqual(array2.shape, (20, 20))
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Textbox()
|
||||
path = BASE_INPUT_INTERFACE_JS_PATH.format(
|
||||
inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Textbox()
|
||||
string = inp.preprocess(RAND_STRING)
|
||||
self.assertEqual(string, RAND_STRING)
|
||||
|
||||
|
||||
class TestImageUpload(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Image()
|
||||
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Image()
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (224, 224, 3))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Image()
|
||||
inp.image_height = 48
|
||||
inp.image_width = 48
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (48, 48, 3))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -7,24 +7,24 @@ import gradio.outputs
|
||||
|
||||
class TestInterface(unittest.TestCase):
|
||||
def test_input_output_mapping(self):
|
||||
io = gr.Interface(inputs='SketCHPad', outputs='TexT', fn=lambda x: x)
|
||||
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
|
||||
io = gr.Interface(inputs='sketchpad', outputs='text', fn=lambda x: x)
|
||||
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Image)
|
||||
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
|
||||
|
||||
def test_input_interface_is_instance(self):
|
||||
inp = gradio.inputs.Image()
|
||||
io = gr.Interface(inputs=inp, outputs='teXT', fn=lambda x: x)
|
||||
io = gr.Interface(inputs=inp, outputs='text', fn=lambda x: x)
|
||||
self.assertEqual(io.input_interfaces[0], inp)
|
||||
|
||||
def test_output_interface_is_instance(self):
|
||||
out = gradio.outputs.Label()
|
||||
io = gr.Interface(inputs='SketCHPad', outputs=out, fn=lambda x: x)
|
||||
io = gr.Interface(inputs='sketchpad', outputs=out, fn=lambda x: x)
|
||||
self.assertEqual(io.output_interfaces[0], out)
|
||||
|
||||
def test_prediction(self):
|
||||
def model(x):
|
||||
return 2*x
|
||||
io = gr.Interface(inputs='textbox', outputs='TEXT', fn=model)
|
||||
io = gr.Interface(inputs='textbox', outputs='text', fn=model)
|
||||
self.assertEqual(io.predict[0](11), 22)
|
||||
|
||||
|
||||
|
@ -11,11 +11,6 @@ BASE64_IMG = "
|
||||
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
|
||||
|
||||
class TestLabel(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Label()
|
||||
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing_string(self):
|
||||
string = 'happy'
|
||||
out = outputs.Label()
|
||||
@ -52,11 +47,6 @@ class TestLabel(unittest.TestCase):
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Textbox()
|
||||
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
string = 'happy'
|
||||
out = outputs.Textbox()
|
||||
@ -65,11 +55,6 @@ class TestTextbox(unittest.TestCase):
|
||||
|
||||
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Image()
|
||||
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__qualname__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
string = BASE64_IMG
|
||||
out = outputs.Textbox()
|
||||
|
Loading…
Reference in New Issue
Block a user