added pytorch handling and tests

This commit is contained in:
Abubakar Abid 2019-03-07 13:53:34 -08:00
parent 7b82401cc0
commit 152f3e7ac5
4 changed files with 87 additions and 6 deletions

View File

@ -1,5 +1,26 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torchvision'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-5-82bc70f8d29b>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mmodels\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchvision'"
]
}
],
"source": [
"import torchvision.models as models\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
@ -90,7 +111,17 @@
"127.0.0.1 - - [07/Mar/2019 12:46:06] code 404, message File not found\n",
"127.0.0.1 - - [07/Mar/2019 12:46:06] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n"
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:35] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:35] code 404, message File not found\n",
"127.0.0.1 - - [07/Mar/2019 13:02:35] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
]
}
],

View File

@ -26,7 +26,8 @@ class Interface:
"""
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function',
'pytorch': 'PyTorch model'}
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
@ -122,6 +123,12 @@ class Interface:
return self.model_obj.predict(preprocessed_input)
elif self.model_type=='function':
return self.model_obj(preprocessed_input)
elif self.model_type=='pytorch':
import torch
value = torch.from_numpy(preprocessed_input)
value = torch.autograd.Variable(value)
prediction = self.model_obj(value)
return prediction.data.numpy()
else:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))

View File

@ -1,4 +1,5 @@
import unittest
import numpy as np
from gradio import Interface
import gradio.inputs
import gradio.outputs
@ -20,6 +21,48 @@ class TestInterface(unittest.TestCase):
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
self.assertEqual(io.output_interface, out)
def test_keras_model(self):
try:
import tensorflow as tf
except:
raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='keras')
pred = io.predict(np.ones(shape=(1, 3), ))
self.assertEqual(pred.shape, (1, 5))
def test_func_model(self):
def model(x):
return 2*x
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='function')
pred = io.predict(np.ones(shape=(1, 3)))
self.assertEqual(pred.shape, (1, 3))
def test_pytorch_model(self):
try:
import torch
except:
raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
class TwoLayerNet(torch.nn.Module):
def __init__(self):
super(TwoLayerNet, self).__init__()
self.linear1 = torch.nn.Linear(3, 4)
self.linear2 = torch.nn.Linear(4, 5)
def forward(self, x):
h_relu = torch.nn.functional.relu(self.linear1(x))
y_pred = self.linear2(h_relu)
return y_pred
model = TwoLayerNet()
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='pytorch')
pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
self.assertEqual(pred.shape, (1, 5))
if __name__ == '__main__':
unittest.main()

View File

@ -16,7 +16,7 @@ class TestLabel(unittest.TestCase):
def test_postprocessing_string(self):
string = 'happy'
out = outputs.Label()
label = out.postprocess(string)
label = json.loads(out.postprocess(string))
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
def test_postprocessing_1D_array(self):
@ -28,21 +28,21 @@ class TestLabel(unittest.TestCase):
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
]}
out = outputs.Label()
label = out.postprocess(array)
label = json.loads(out.postprocess(array))
self.assertDictEqual(label, true_label)
def test_postprocessing_1D_array_no_confidences(self):
array = np.array([0.1, 0.2, 0, 0.7, 0])
true_label = {outputs.Label.LABEL_KEY: 3}
out = outputs.Label(show_confidences=False)
label = out.postprocess(array)
label = json.loads(out.postprocess(array))
self.assertDictEqual(label, true_label)
def test_postprocessing_int(self):
true_label_array = np.array([[[3]]])
true_label = {outputs.Label.LABEL_KEY: 3}
out = outputs.Label()
label = out.postprocess(true_label_array)
label = json.loads(out.postprocess(true_label_array))
self.assertDictEqual(label, true_label)