gradio/test/test_interface.py

69 lines
2.7 KiB
Python
Raw Normal View History

2019-02-28 13:51:51 +08:00
import unittest
2019-03-08 05:53:34 +08:00
import numpy as np
2019-02-28 13:51:51 +08:00
from gradio import Interface
import gradio.inputs
import gradio.outputs
class TestInterface(unittest.TestCase):
def test_input_output_mapping(self):
io = Interface(inputs='SketCHPad', outputs='textBOX', model=lambda x: x, model_type='function')
self.assertIsInstance(io.input_interface, gradio.inputs.Sketchpad)
self.assertIsInstance(io.output_interface, gradio.outputs.Textbox)
def test_input_interface_is_instance(self):
inp = gradio.inputs.ImageUpload()
io = Interface(inputs=inp, outputs='textBOX', model=lambda x: x, model_type='function')
self.assertEqual(io.input_interface, inp)
def test_output_interface_is_instance(self):
out = gradio.outputs.Label(show_confidences=False)
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
self.assertEqual(io.output_interface, out)
2019-06-14 17:13:13 +08:00
# def test_keras_model(self):
# try:
# import tensorflow as tf
# except:
# raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
# inputs = tf.keras.Input(shape=(3,))
# x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
# outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
# model = tf.keras.Model(inputs=inputs, outputs=outputs)
# io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='keras')
# pred = io.predict(np.ones(shape=(1, 3), ))
# self.assertEqual(pred.shape, (1, 5))
2019-03-08 05:53:34 +08:00
2019-06-14 17:13:13 +08:00
# def test_func_model(self):
# def model(x):
# return 2*x
# io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='function')
# pred = io.predict(np.ones(shape=(1, 3)))
# self.assertEqual(pred.shape, (1, 3))
2019-03-08 05:53:34 +08:00
def test_pytorch_model(self):
try:
import torch
except:
raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
class TwoLayerNet(torch.nn.Module):
def __init__(self):
super(TwoLayerNet, self).__init__()
self.linear1 = torch.nn.Linear(3, 4)
self.linear2 = torch.nn.Linear(4, 5)
def forward(self, x):
h_relu = torch.nn.functional.relu(self.linear1(x))
y_pred = self.linear2(h_relu)
return y_pred
model = TwoLayerNet()
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='pytorch')
pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
self.assertEqual(pred.shape, (1, 5))
2019-02-28 13:51:51 +08:00
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()