2019-02-28 13:51:51 +08:00
|
|
|
import unittest
|
2019-03-08 05:53:34 +08:00
|
|
|
import numpy as np
|
2020-06-10 08:00:30 +08:00
|
|
|
import gradio as gr
|
2019-02-28 13:51:51 +08:00
|
|
|
import gradio.inputs
|
|
|
|
import gradio.outputs
|
|
|
|
|
|
|
|
|
|
|
|
class TestInterface(unittest.TestCase):
|
|
|
|
def test_input_output_mapping(self):
|
2020-06-12 03:31:44 +08:00
|
|
|
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=lambda
|
|
|
|
x: x)
|
|
|
|
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
|
|
|
|
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
|
2019-02-28 13:51:51 +08:00
|
|
|
|
2019-03-06 14:45:08 +08:00
|
|
|
def test_input_interface_is_instance(self):
|
2020-06-12 03:31:44 +08:00
|
|
|
inp = gradio.inputs.ImageIn()
|
|
|
|
io = gr.Interface(inputs=inp, outputs='textBOX', fn=lambda x: x)
|
|
|
|
self.assertEqual(io.input_interfaces[0], inp)
|
2019-03-06 14:45:08 +08:00
|
|
|
|
|
|
|
def test_output_interface_is_instance(self):
|
2020-06-12 03:31:44 +08:00
|
|
|
# out = gradio.outputs.Label(show_confidences=False)
|
|
|
|
out = gradio.outputs.Label()
|
|
|
|
io = gr.Interface(inputs='SketCHPad', outputs=out, fn=lambda x: x)
|
|
|
|
self.assertEqual(io.output_interfaces[0], out)
|
2019-03-06 14:45:08 +08:00
|
|
|
|
2019-06-19 04:13:50 +08:00
|
|
|
def test_keras_model(self):
|
|
|
|
try:
|
|
|
|
import tensorflow as tf
|
|
|
|
except:
|
|
|
|
raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
|
|
|
|
inputs = tf.keras.Input(shape=(3,))
|
|
|
|
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
|
|
|
|
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
|
|
|
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
2020-06-12 03:31:44 +08:00
|
|
|
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
|
|
|
|
# pred = io.predict(np.ones(shape=(1, 3), ))
|
|
|
|
# self.assertEqual(pred.shape, (1, 5))
|
2019-03-08 05:53:34 +08:00
|
|
|
|
2019-06-19 04:13:50 +08:00
|
|
|
def test_func_model(self):
|
|
|
|
def model(x):
|
|
|
|
return 2*x
|
2020-06-12 03:31:44 +08:00
|
|
|
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
|
|
|
|
# pred = io.predict(np.ones(shape=(1, 3)))
|
|
|
|
# self.assertEqual(pred.shape, (1, 3))
|
2019-03-08 05:53:34 +08:00
|
|
|
|
|
|
|
def test_pytorch_model(self):
|
|
|
|
try:
|
|
|
|
import torch
|
|
|
|
except:
|
|
|
|
raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
|
|
|
|
|
|
|
|
class TwoLayerNet(torch.nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super(TwoLayerNet, self).__init__()
|
|
|
|
self.linear1 = torch.nn.Linear(3, 4)
|
|
|
|
self.linear2 = torch.nn.Linear(4, 5)
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
h_relu = torch.nn.functional.relu(self.linear1(x))
|
|
|
|
y_pred = self.linear2(h_relu)
|
|
|
|
return y_pred
|
|
|
|
|
|
|
|
model = TwoLayerNet()
|
2020-06-12 03:31:44 +08:00
|
|
|
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
|
|
|
|
# pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
|
|
|
|
# self.assertEqual(pred.shape, (1, 5))
|
2019-03-08 05:53:34 +08:00
|
|
|
|
2019-02-28 13:51:51 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: discover and run this module's TestCases.
    unittest.main()
|