mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
added pytorch handling and tests
This commit is contained in:
parent
7b82401cc0
commit
152f3e7ac5
@ -1,5 +1,26 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ModuleNotFoundError",
|
||||
"evalue": "No module named 'torchvision'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m<ipython-input-5-82bc70f8d29b>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mmodels\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchvision'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torchvision.models as models\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
@ -90,7 +111,17 @@
|
||||
"127.0.0.1 - - [07/Mar/2019 12:46:06] code 404, message File not found\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 12:46:06] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n"
|
||||
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:35] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:35] code 404, message File not found\n",
|
||||
"127.0.0.1 - - [07/Mar/2019 13:02:35] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -26,7 +26,8 @@ class Interface:
|
||||
"""
|
||||
|
||||
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function',
|
||||
'pytorch': 'PyTorch model'}
|
||||
|
||||
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
|
||||
verbose=True):
|
||||
@ -122,6 +123,12 @@ class Interface:
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='function':
|
||||
return self.model_obj(preprocessed_input)
|
||||
elif self.model_type=='pytorch':
|
||||
import torch
|
||||
value = torch.from_numpy(preprocessed_input)
|
||||
value = torch.autograd.Variable(value)
|
||||
prediction = self.model_obj(value)
|
||||
return prediction.data.numpy()
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from gradio import Interface
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
@ -20,6 +21,48 @@ class TestInterface(unittest.TestCase):
|
||||
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
|
||||
self.assertEqual(io.output_interface, out)
|
||||
|
||||
def test_keras_model(self):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except:
|
||||
raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
|
||||
inputs = tf.keras.Input(shape=(3,))
|
||||
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
|
||||
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='keras')
|
||||
pred = io.predict(np.ones(shape=(1, 3), ))
|
||||
self.assertEqual(pred.shape, (1, 5))
|
||||
|
||||
def test_func_model(self):
|
||||
def model(x):
|
||||
return 2*x
|
||||
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='function')
|
||||
pred = io.predict(np.ones(shape=(1, 3)))
|
||||
self.assertEqual(pred.shape, (1, 3))
|
||||
|
||||
def test_pytorch_model(self):
|
||||
try:
|
||||
import torch
|
||||
except:
|
||||
raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
|
||||
|
||||
class TwoLayerNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(TwoLayerNet, self).__init__()
|
||||
self.linear1 = torch.nn.Linear(3, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 5)
|
||||
|
||||
def forward(self, x):
|
||||
h_relu = torch.nn.functional.relu(self.linear1(x))
|
||||
y_pred = self.linear2(h_relu)
|
||||
return y_pred
|
||||
|
||||
model = TwoLayerNet()
|
||||
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='pytorch')
|
||||
pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
|
||||
self.assertEqual(pred.shape, (1, 5))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -16,7 +16,7 @@ class TestLabel(unittest.TestCase):
|
||||
def test_postprocessing_string(self):
|
||||
string = 'happy'
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(string)
|
||||
label = json.loads(out.postprocess(string))
|
||||
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
|
||||
|
||||
def test_postprocessing_1D_array(self):
|
||||
@ -28,21 +28,21 @@ class TestLabel(unittest.TestCase):
|
||||
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
|
||||
]}
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(array)
|
||||
label = json.loads(out.postprocess(array))
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_1D_array_no_confidences(self):
|
||||
array = np.array([0.1, 0.2, 0, 0.7, 0])
|
||||
true_label = {outputs.Label.LABEL_KEY: 3}
|
||||
out = outputs.Label(show_confidences=False)
|
||||
label = out.postprocess(array)
|
||||
label = json.loads(out.postprocess(array))
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_int(self):
|
||||
true_label_array = np.array([[[3]]])
|
||||
true_label = {outputs.Label.LABEL_KEY: 3}
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(true_label_array)
|
||||
label = json.loads(out.postprocess(true_label_array))
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user