commit 9a635980f6
Author: Your Name
Date:   2019-03-07 21:37:04 -08:00
9 changed files with 212 additions and 63 deletions

View File

@@ -1,5 +1,26 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'torchvision'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-5-82bc70f8d29b>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mmodels\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'torchvision'"
]
}
],
"source": [
"import torchvision.models as models\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
@@ -49,8 +70,8 @@
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7862/interface.html\n",
"Model available publicly for 8 hours at: https://d38e1bdf.ngrok.io/interface.html\n"
"Model is running locally at: http://localhost:7861/interface.html\n",
"Model available publicly for 8 hours at: http://f0cf4515.ngrok.io/interface.html\n"
]
},
{
@@ -60,14 +81,14 @@
" <iframe\n",
" width=\"1000\"\n",
" height=\"500\"\n",
" src=\"http://localhost:7862/interface.html\"\n",
" src=\"http://localhost:7861/interface.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x1ebc8181a20>"
"<IPython.lib.display.IFrame at 0x1f4f871aa90>"
]
},
"metadata": {},
@@ -77,26 +98,30 @@
"name": "stderr",
"output_type": "stream",
"text": [
"127.0.0.1 - - [06/Mar/2019 21:54:13] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:13] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] code 404, message File not found\n",
"127.0.0.1 - - [06/Mar/2019 21:54:14] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:52] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:53] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 21:55:54] code 404, message File not found\n",
"127.0.0.1 - - [06/Mar/2019 21:55:54] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:05] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:06] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:06] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:06] code 404, message File not found\n",
"127.0.0.1 - - [07/Mar/2019 12:46:06] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 12:46:27] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:34] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:35] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [07/Mar/2019 13:02:35] code 404, message File not found\n",
"127.0.0.1 - - [07/Mar/2019 13:02:35] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
]
}
],
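The failing cell above simply means torchvision was not installed in that environment (pip install torchvision fixes it); the remaining output is what a successful launch prints. For reference, a minimal sketch of the kind of cell that could produce output like this, assuming torchvision is available, that 'imageupload' and 'label' are valid interface names, and that launch() is the entry point that prints the local and ngrok URLs shown above:

# Minimal sketch of a notebook cell matching the output above (assumptions noted in the lead-in).
import torchvision.models as models
import gradio

model = models.squeezenet1_1(pretrained=True)  # any small pretrained classifier works as a stand-in
model.eval()

io = gradio.Interface(inputs='imageupload', outputs='label',
                      model=model, model_type='pytorch')
io.launch()  # prints the local URL and a temporary public ngrok URL, as in the output above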

View File

@@ -26,7 +26,8 @@ class Interface:
"""
# Dictionary in which each key is a valid `model_type` argument to the constructor and each value is its description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function',
'pytorch': 'PyTorch model'}
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
@@ -122,6 +123,12 @@ class Interface:
            return self.model_obj.predict(preprocessed_input)
        elif self.model_type=='function':
            return self.model_obj(preprocessed_input)
        elif self.model_type=='pytorch':
            import torch
            # Convert the NumPy input to a tensor, run a forward pass, and return a NumPy array.
            value = torch.from_numpy(preprocessed_input)
            value = torch.autograd.Variable(value)
            prediction = self.model_obj(value)
            return prediction.data.numpy()
        else:
            raise ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
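For reference, a minimal sketch of exercising the new 'pytorch' branch, mirroring the unit test added later in this commit; the interface names are copied from those tests (matching appears to be case-insensitive), and the float32 dtype matters because torch.from_numpy preserves the NumPy dtype:

# Sketch of the new 'pytorch' path: NumPy in, forward pass, NumPy out.
import numpy as np
import torch
from gradio import Interface

net = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.ReLU(), torch.nn.Linear(4, 5))
io = Interface(inputs='sketchpad', outputs='textbox', model=net, model_type='pytorch')
pred = io.predict(np.ones((1, 3), dtype=np.float32))  # NumPy array of shape (1, 5)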
@@ -142,6 +149,10 @@
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port)
networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file.
networking.set_interface_types_in_config_file(output_directory,
self.input_interface.__class__.__name__.lower(),
self.output_interface.__class__.__name__.lower())
if self.verbose:
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
print("Model is running locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP))

View File

@@ -14,6 +14,7 @@ from signal import SIGTERM # or SIGKILL
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import stat
import time
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import pkg_resources
@@ -32,6 +33,7 @@ STATIC_PATH_LIB = pkg_resources.resource_filename('gradio', 'static/')
STATIC_PATH_TEMP = 'static/'
TEMPLATE_TEMP = 'interface.html'
BASE_JS_FILE = 'static/js/all-io.js'
CONFIG_FILE = 'static/config.json'
NGROK_ZIP_URLS = {
@@ -78,25 +80,43 @@ def copy_files(src_dir, dest_dir):
dir_util.copy_tree(src_dir, dest_dir)
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_socket_url_in_js(temp_dir, socket_url):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
def render_template_with_tags(template_path, context):
"""
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
in double curly braces) with corresponding values.
:param template_path: a string with the path to the template file
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', value)
new_lines.append(line)
with open(template_path, 'w') as fout:
for line in new_lines:
fout.write(line)
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
js_file = os.path.join(temp_dir, BASE_JS_FILE)
render_template_with_tags(js_file, {'ngrok_socket_url': ngrok_socket_url})
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(config_file, {'ngrok_socket_url': ngrok_socket_url})
def set_socket_port_in_js(temp_dir, socket_port):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
js_file = os.path.join(temp_dir, BASE_JS_FILE)
render_template_with_tags(js_file, {'socket_port': str(socket_port)})
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
fout.write(line)
def set_interface_types_in_config_file(temp_dir, input_interface, output_interface):
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(config_file, {'input_interface_type': input_interface,
'output_interface_type': output_interface})
def get_first_available_port(initial, final):
@@ -143,7 +163,7 @@ def serve_files_in_background(port, directory_to_serve=None):
sys.stdout.flush()
httpd.serve_forever()
except KeyboardInterrupt:
pass
httpd.server_close()
thread = threading.Thread(target=serve_forever)
thread.start()
@@ -193,7 +213,7 @@ def setup_ngrok(server_port, websocket_port, output_directory):
kill_processes([4040, 4041]) #TODO(abidlabs): better way to do this
site_ngrok_url = create_ngrok_tunnel(server_port, NGROK_TUNNELS_API_URL)
socket_ngrok_url = create_ngrok_tunnel(websocket_port, NGROK_TUNNELS_API_URL2)
set_socket_url_in_js(output_directory, socket_ngrok_url)
set_ngrok_url_in_js(output_directory, socket_ngrok_url)
return site_ngrok_url
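For reference, a minimal sketch of the launch flow these helpers support, mirroring the calls made from the Interface class elsewhere in this commit; the output directory, port ranges, and the copy_files step are illustrative assumptions rather than the library's exact internals:

# Sketch of the serving/tunnel flow (paths and port ranges are illustrative).
import os
from gradio import networking

output_dir = '/tmp/gradio_demo'  # hypothetical temp directory for this launch
networking.copy_files(networking.STATIC_PATH_LIB,
                      os.path.join(output_dir, networking.STATIC_PATH_TEMP))

server_port = networking.get_first_available_port(7860, 7870)
websocket_port = networking.get_first_available_port(9200, 9210)

networking.set_socket_port_in_js(output_dir, websocket_port)
networking.set_interface_types_in_config_file(output_dir, 'sketchpad', 'label')
networking.serve_files_in_background(server_port, output_dir)

# Optional public link: tunnels both the HTTP server and the websocket via ngrok.
share_url = networking.setup_ngrok(server_port, websocket_port, output_dir)
print(share_url)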

View File

@@ -0,0 +1,5 @@
{
"input_interface_type": "{{input_interface_type}}",
"output_interface_type": "{{output_interface_type}}",
"ngrok_socket_url": "{{ngrok_socket_url}}"
}
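For reference, a minimal sketch of how render_template_with_tags fills this template in place; the path and the substituted values are illustrative only:

# Sketch: render_template_with_tags does plain {{tag}} -> value substitution on the file in place.
import json
from gradio import networking

config_path = '/tmp/gradio_demo/static/config.json'  # hypothetical copy of the template above
networking.render_template_with_tags(config_path, {
    'input_interface_type': 'sketchpad',
    'output_interface_type': 'label',
    'ngrok_socket_url': 'ws://f0cf4515.ngrok.io',
})

with open(config_path) as f:
    print(json.load(f))
# {'input_interface_type': 'sketchpad', 'output_interface_type': 'label',
#  'ngrok_socket_url': 'ws://f0cf4515.ngrok.io'}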

View File

@@ -1,5 +1,5 @@
var NGROK_URL = "ws://0f9bffb5.ngrok.io"
var SOCKET_PORT = 9200
var NGROK_URL = "{{ngrok_socket_url}}"
var SOCKET_PORT = "{{socket_port}}"
try {
var origin = window.location.origin;

View File

@@ -1,4 +1,5 @@
import unittest
import numpy as np
from gradio import Interface
import gradio.inputs
import gradio.outputs
@@ -20,6 +21,48 @@ class TestInterface(unittest.TestCase):
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
self.assertEqual(io.output_interface, out)
def test_keras_model(self):
        try:
            import tensorflow as tf
        except ImportError:
            raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='keras')
        pred = io.predict(np.ones(shape=(1, 3)))
self.assertEqual(pred.shape, (1, 5))
def test_func_model(self):
def model(x):
return 2*x
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='function')
pred = io.predict(np.ones(shape=(1, 3)))
self.assertEqual(pred.shape, (1, 3))
def test_pytorch_model(self):
        try:
            import torch
        except ImportError:
            raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
class TwoLayerNet(torch.nn.Module):
def __init__(self):
super(TwoLayerNet, self).__init__()
self.linear1 = torch.nn.Linear(3, 4)
self.linear2 = torch.nn.Linear(4, 5)
def forward(self, x):
h_relu = torch.nn.functional.relu(self.linear1(x))
y_pred = self.linear2(h_relu)
return y_pred
model = TwoLayerNet()
io = Interface(inputs='SketCHPad', outputs='textBOX', model=model, model_type='pytorch')
pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
self.assertEqual(pred.shape, (1, 5))
if __name__ == '__main__':
unittest.main()

View File

@@ -16,7 +16,7 @@ class TestLabel(unittest.TestCase):
def test_postprocessing_string(self):
string = 'happy'
out = outputs.Label()
label = out.postprocess(string)
label = json.loads(out.postprocess(string))
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
def test_postprocessing_1D_array(self):
@@ -28,21 +28,21 @@ class TestLabel(unittest.TestCase):
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
]}
out = outputs.Label()
label = out.postprocess(array)
label = json.loads(out.postprocess(array))
self.assertDictEqual(label, true_label)
def test_postprocessing_1D_array_no_confidences(self):
array = np.array([0.1, 0.2, 0, 0.7, 0])
true_label = {outputs.Label.LABEL_KEY: 3}
out = outputs.Label(show_confidences=False)
label = out.postprocess(array)
label = json.loads(out.postprocess(array))
self.assertDictEqual(label, true_label)
def test_postprocessing_int(self):
true_label_array = np.array([[[3]]])
true_label = {outputs.Label.LABEL_KEY: 3}
out = outputs.Label()
label = out.postprocess(true_label_array)
label = json.loads(out.postprocess(true_label_array))
self.assertDictEqual(label, true_label)
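These test changes reflect that Label.postprocess now returns a JSON-encoded string rather than a dict, so callers have to json.loads the result first. A minimal sketch of the new calling pattern, using the same class constants as the tests:

# Sketch of the updated contract: postprocess() returns a JSON string.
import json
import numpy as np
from gradio import outputs

out = outputs.Label(show_confidences=False)
raw = out.postprocess(np.array([0.1, 0.2, 0, 0.7, 0]))  # JSON string, not a dict
label = json.loads(raw)
assert label == {outputs.Label.LABEL_KEY: 3}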

View File

@@ -15,18 +15,18 @@
<div id="hero-section"><!--
--><div id="intro">
<img id="logo" src="img/logo.png"/>
<p>GradIO is a free, open-source python library that helps machine
learning engineers develop, interact with, and share their machine
learning models with only a couple lines of extra code.</p>
<p>With GradIO, you can generate user interfaces in your browser that
<p>Gradio is a free, open-source Python library that helps machine
learning researchers <strong>interact</strong> with and <strong>share</strong> their machine
learning models with collaborators and clients with only a few lines of extra code.</p>
<p>With Gradio, you can easily generate in-browser interfaces that
enable you to enter various forms of input for your model and explore
the behavior of your model immediately. GradIO also generates links
the behavior of your model immediately. Gradio also generates <strong>links</strong>
that can be shared with collaborators and other audiences, so they can
interact with the model without setting up any software or even having
any background in machine learning or software at all.</p>
any background in machine learning or software at all!</p>
<p>Visit the <a href="https://github.com/abidlabs/gradio"
target="_blank">GradIO GitHub >></a></p>
<p>GradIO was developed by researchers at Stanford University and is
target="_blank">Gradio GitHub >></a></p>
<p>Gradio was developed by researchers at Stanford University and is
under the Apache license.</p>
</div><!--
--><div id="demos">
@@ -68,9 +68,8 @@
</div>
<div id="gradio">
<div class="instructions">
In this demo, draw a digit from 0 to 9 in the input box. Click
submit, and the digit recognized by the model will be presented in
the output.
The code above produces an interface like this, in which you can draw a digit from 0 to 9 in the input box. Click
submit to get the prediction!
</div>
<div class="panel">
<div class="input sketchpad">
@@ -90,24 +89,24 @@
<div id="summaries">
<div id="setup" class="summary_box">
<h2>Fast, easy setup</h2>
<p>Using GradIO only requires adding a couple lines of code to your
project. You can install GradIO from pip and deploy your model in
<p>Using Gradio only requires adding a couple of lines of code to your
project. You can install Gradio from pip and deploy your model in
seconds. Once launched, you can choose from a variety of interface
types to interact with, iterate over, and improve your models.</p>
<p>More on <a href="getting_started.html">Getting Stared >></a><p>
</div>
<div id="present" class="summary_box">
<h2>Present and share</h2>
<p>GradIO present an interface that is intuitive to engineers and
<p>Gradio presents an interface that is intuitive to engineers and
non-engineers alike, and thus a valuable tool in sharing insights from
your models. When GradIO launches a model, it also creates a link you
your models. When Gradio launches a model, it also creates a link you
can share with colleagues that lets them interact with the model
on your computer remotely from their own devices.</p>
<p>More on <a href="sharing.html">Sharing >></a><p>
</div>
<div id="embed" class="summary_box">
<h2>Embed and go</h2>
<p>GradIO can be embedded in Jupyter and Colab notebooks, in blogs and
<p>Gradio can be embedded in Jupyter and Colab notebooks, in blogs and
websites, and screenshotted for use in research papers. These features
all help your models be more easily shared and consumed by a larger
audience.</p>
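The "fast, easy setup" copy above promises a pip install plus a couple of lines of code. A minimal sketch of that flow with a stand-in Keras model; the interface names and the launch() entry point are assumptions based on the rest of this commit:

# pip install gradio
# Sketch of the "couple of lines" claim; swap the stand-in model for your trained one.
import tensorflow as tf
import gradio

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,))  # stand-in for a trained MNIST model
])

io = gradio.Interface(inputs='sketchpad', outputs='label', model=model, model_type='keras')
io.launch()  # prints a local URL and a shareable public link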

web/js/models.js (new file, 46 lines)
View File

@@ -0,0 +1,46 @@
// NOTE: ml5.js must already be loaded by the page before this file, e.g.
// <script src="https://unpkg.com/ml5@0.1.3/dist/ml5.min.js"></script>
// (a raw <script> tag is not valid inside a .js file).

// Takes in the ID of the uploaded image and reports the top-3 MobileNet labels
// and confidences. ml5's predict() is asynchronous, so the result is delivered
// through the added `callback` parameter instead of a return value (the
// original synchronous return would always be undefined); the caller should be
// adjusted to match.
function imageupload_label(image, callback) {
    var classifier = ml5.imageClassifier('MobileNet', function() {
        console.log('Model Loaded!');
    });
    classifier.predict(image, function(err, results) {
        if (err) {
            console.error(err);
            return;
        }
        callback({
            'label': results[0].className,
            'confidences': [
                {'label': results[0].className,
                 'confidence': results[0].probability.toFixed(4)},
                {'label': results[1].className,
                 'confidence': results[1].probability.toFixed(4)},
                {'label': results[2].className,
                 'confidence': results[2].probability.toFixed(4)}
            ]
        });
    });
}

// Same as above, but for the sketchpad input: takes in the ID of the sketch
// image and passes the top-3 labels and confidences to `callback`.
function sketchpad_label(image, callback) {
    var classifier = ml5.imageClassifier('MobileNet', function() {
        console.log('Model Loaded!');
    });
    classifier.predict(image, function(err, results) {
        if (err) {
            console.error(err);
            return;
        }
        callback({
            'label': results[0].className,
            'confidences': [
                {'label': results[0].className,
                 'confidence': results[0].probability.toFixed(4)},
                {'label': results[1].className,
                 'confidence': results[1].probability.toFixed(4)},
                {'label': results[2].className,
                 'confidence': results[2].probability.toFixed(4)}
            ]
        });
    });
}