merge
This commit is contained in:
Ali Abdalla 2019-03-09 18:16:18 -08:00
commit f57b93a1b4
18 changed files with 196 additions and 90 deletions

View File

@ -9,7 +9,7 @@ import base64
from gradio import preprocessing_utils
from io import BytesIO
import numpy as np
from PIL import Image
from PIL import Image, ImageOps
class AbstractInput(ABC):
"""
@ -43,13 +43,15 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28):
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28,
invert_colors=True):
self.image_width = image_width
self.image_height = image_height
self.invert_colors = invert_colors
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/sketchpad_input.html'
return 'templates/input/sketchpad.html'
def preprocess(self, inp):
"""
@ -58,8 +60,10 @@ class Sketchpad(AbstractInput):
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
if self.invert_colors:
im = ImageOps.invert(im)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1)
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height)
return array
@ -71,7 +75,7 @@ class Webcam(AbstractInput):
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/webcam_input.html'
return 'templates/input/webcam.html'
def preprocess(self, inp):
"""
@ -88,7 +92,7 @@ class Webcam(AbstractInput):
class Textbox(AbstractInput):
def get_template_path(self):
return 'templates/textbox_input.html'
return 'templates/input/textbox.html'
def preprocess(self, inp):
"""
@ -109,7 +113,7 @@ class ImageUpload(AbstractInput):
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/image_upload_input.html'
return 'templates/input/image_upload.html'
def preprocess(self, inp):
"""

View File

@ -11,6 +11,7 @@ import gradio.inputs
import gradio.outputs
from gradio import networking
import tempfile
import threading
nest_asyncio.apply()
@ -26,7 +27,8 @@ class Interface:
"""
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function',
'pytorch': 'PyTorch model'}
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
@ -122,6 +124,12 @@ class Interface:
return self.model_obj.predict(preprocessed_input)
elif self.model_type=='function':
return self.model_obj(preprocessed_input)
elif self.model_type=='pytorch':
import torch
value = torch.from_numpy(preprocessed_input)
value = torch.autograd.Variable(value)
prediction = self.model_obj(value)
return prediction.data.numpy()
else:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
@ -133,7 +141,7 @@ class Interface:
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port = networking.start_simple_server(output_directory)
server_port, httpd = networking.start_simple_server(output_directory)
path_to_server = 'http://localhost:{}/'.format(server_port)
networking.build_template(output_directory, self.input_interface, self.output_interface)
@ -142,9 +150,20 @@ class Interface:
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port)
networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file.
networking.set_interface_types_in_config_file(output_directory,
self.input_interface.__class__.__name__.lower(),
self.output_interface.__class__.__name__.lower())
try: # Check if running interactively using ipython.
from_ipynb = get_ipython()
if 'google.colab' in str(from_ipynb):
is_colab = True
except NameError:
is_colab = False
if self.verbose:
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
print("Model is running locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP))
if not is_colab:
print("Model is running locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP))
if share:
site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory)
@ -155,17 +174,19 @@ class Interface:
if self.verbose:
print("To create a public link, set `share=True` in the argument to `launch()`")
site_ngrok_url = None
if is_colab:
site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory)
# Keep the server running in the background.
asyncio.get_event_loop().run_until_complete(start_server)
try:
asyncio.get_event_loop().run_forever()
except RuntimeError: # Runtime errors are thrown in jupyter notebooks because of async.
pass
_ = get_ipython()
except NameError: # Runtime errors are thrown in jupyter notebooks because of async.
t = threading.Thread(target=asyncio.get_event_loop().run_forever, daemon=True)
t.start()
if inline is None:
try: # Check if running interactively using ipython.
_ = get_ipython()
from_ipynb = get_ipython()
inline = True
if browser is None:
browser = False
@ -176,10 +197,15 @@ class Interface:
else:
if browser is None:
browser = False
if browser:
if browser and not is_colab:
webbrowser.open(path_to_server + networking.TEMPLATE_TEMP) # Open a browser tab with the interface.
if inline:
from IPython.display import IFrame
display(IFrame(path_to_server + networking.TEMPLATE_TEMP, width=1000, height=500))
if is_colab:
print("Cannot display Interface inline on google colab, public link created at: {} and displayed below.".format(
site_ngrok_url + '/' + networking.TEMPLATE_TEMP))
display(IFrame(site_ngrok_url + '/' + networking.TEMPLATE_TEMP, width=1000, height=500))
else:
display(IFrame(path_to_server + networking.TEMPLATE_TEMP, width=1000, height=500))
return path_to_server + networking.TEMPLATE_TEMP, site_ngrok_url
return httpd, path_to_server + networking.TEMPLATE_TEMP, site_ngrok_url

View File

@ -14,6 +14,7 @@ from signal import SIGTERM # or SIGKILL
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import stat
import time
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import pkg_resources
@ -32,6 +33,7 @@ STATIC_PATH_LIB = pkg_resources.resource_filename('gradio', 'static/')
STATIC_PATH_TEMP = 'static/'
TEMPLATE_TEMP = 'interface.html'
BASE_JS_FILE = 'static/js/all-io.js'
CONFIG_FILE = 'static/config.json'
NGROK_ZIP_URLS = {
@ -78,25 +80,43 @@ def copy_files(src_dir, dest_dir):
dir_util.copy_tree(src_dir, dest_dir)
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_socket_url_in_js(temp_dir, socket_url):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
def render_template_with_tags(template_path, context):
"""
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
in double curly braces) with corresponding values.
:param template_path: a string with the path to the template file
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', value)
new_lines.append(line)
with open(template_path, 'w') as fout:
for line in new_lines:
fout.write(line)
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
js_file = os.path.join(temp_dir, BASE_JS_FILE)
render_template_with_tags(js_file, {'ngrok_socket_url': ngrok_socket_url})
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(config_file, {'ngrok_socket_url': ngrok_socket_url})
def set_socket_port_in_js(temp_dir, socket_port):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
js_file = os.path.join(temp_dir, BASE_JS_FILE)
render_template_with_tags(js_file, {'socket_port': str(socket_port)})
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
fout.write(line)
def set_interface_types_in_config_file(temp_dir, input_interface, output_interface):
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(config_file, {'input_interface_type': input_interface,
'output_interface_type': output_interface})
def get_first_available_port(initial, final):
@ -127,6 +147,9 @@ def serve_files_in_background(port, directory_to_serve=None):
fullpath = os.path.join(self.server.base_path, relpath)
return fullpath
def log_message(self, format, *args):
return
class HTTPServer(BaseHTTPServer):
"""The main server, you pass in base_path which is the path you want to serve requests from"""
@ -134,25 +157,28 @@ def serve_files_in_background(port, directory_to_serve=None):
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
httpd = HTTPServer(directory_to_serve, (LOCALHOST_NAME, port))
# Now loop forever
def serve_forever():
try:
while True:
sys.stdout.flush()
# sys.stdout.flush()
httpd.serve_forever()
except KeyboardInterrupt:
pass
httpd.server_close()
thread = threading.Thread(target=serve_forever)
thread = threading.Thread(target=serve_forever, daemon=True)
thread.start()
return httpd
def start_simple_server(directory_to_serve=None):
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
serve_files_in_background(port, directory_to_serve)
return port
httpd = serve_files_in_background(port, directory_to_serve)
return port, httpd
def download_ngrok():
@ -193,7 +219,7 @@ def setup_ngrok(server_port, websocket_port, output_directory):
kill_processes([4040, 4041]) #TODO(abidlabs): better way to do this
site_ngrok_url = create_ngrok_tunnel(server_port, NGROK_TUNNELS_API_URL)
socket_ngrok_url = create_ngrok_tunnel(websocket_port, NGROK_TUNNELS_API_URL2)
set_socket_url_in_js(output_directory, socket_ngrok_url)
set_ngrok_url_in_js(output_directory, socket_ngrok_url)
return site_ngrok_url

View File

@ -63,7 +63,7 @@ class Label(AbstractOutput):
return name
def get_template_path(self):
return 'templates/label_output.html'
return 'templates/output/label.html'
def postprocess(self, prediction):
"""
@ -94,7 +94,7 @@ class Label(AbstractOutput):
class Textbox(AbstractOutput):
def get_template_path(self):
return 'templates/textbox_output.html'
return 'templates/output/textbox.html'
def postprocess(self, prediction):
"""

View File

@ -0,0 +1,5 @@
{
"input_interface_type": "{{input_interface_type}}",
"output_interface_type": "{{output_interface_type}}",
"ngrok_socket_url": "{{ngrok_socket_url}}"
}

View File

@ -32,12 +32,14 @@
.input.text .role, .output.text .role {
margin-left: 1px;
}
.submit, .clear {
.submit, .clear, .flag, .message, .send-message {
background-color: #F6F6F6 !important;
padding: 8px !important;
box-sizing: border-box;
width: calc(50% - 8px);
text-transform: uppercase;
font-weight: bold;
border: 0 none;
}
.clear {
background-color: #F6F6F6 !important;
@ -53,10 +55,15 @@
.clear {
margin-left: 8px;
}
/*.flag:focus {
background-color: pink !important;
}
*/
.input_text, .output_text {
background: transparent;
resize: none;
border: 0 none;
resize: none
border: 0 none;
width: 100%;
font-size: 18px;
outline: none;
@ -132,3 +139,49 @@
.confidence_intervals > * {
vertical-align: bottom;
}
.flag.flagged {
background-color: pink !important;
}
.sketchpad canvas {
background-color: white;
}
.sketch_tools {
flex: 0 1 auto;
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 16px;
}
.brush {
border-radius: 50%;
background-color: #AAA;
margin: 0px 20px;
cursor: pointer;
}
.brush.selected, .brush:hover {
background-color: black;
}
#brush_1 {
height: 8px;
width: 8px;
}
#brush_2 {
height: 16px;
width: 16px;
}
#brush_3 {
height: 24px;
width: 24px;
}
.canvas_holder {
flex: 1 1 auto;
text-align: center;
}
canvas {
border: solid 1px black;
}
textarea {
resize: none;
}

View File

@ -3,7 +3,7 @@ body {
font-size: 12px;
margin: 0;
}
button, input[type="submit"], input[type="reset"] {
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {
background: none;
color: inherit;
border: none;
@ -11,7 +11,15 @@ button, input[type="submit"], input[type="reset"] {
font: inherit;
cursor: pointer;
outline: inherit;
-webkit-appearance: none;
border-radius: 0;
}
input[type="text"] {
cursor: text;
text-transform: none;
}
body > div, body > nav {
margin-left: 60px;
margin-right: 60px;
@ -33,5 +41,5 @@ nav img {
.panel {
min-width: 300px;
margin: 40px 0 0;
flex-grow: 1;
flex-grow: 1;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

View File

@ -1,5 +1,24 @@
var NGROK_URL = "ws://0f9bffb5.ngrok.io"
var SOCKET_PORT = 9200
var NGROK_URL = "{{ngrok_socket_url}}"
var SOCKET_PORT = "{{socket_port}}"
function notifyError(error) {
$.notify({
// options
message: 'Not able to communicate with model (is python code still running?)'
},{
// settings
type: 'danger',
animate: {
enter: 'animated fadeInDown',
exit: 'animated fadeOutUp'
},
placement: {
from: "bottom",
align: "right"
},
delay: 5000
});
}
try {
var origin = window.location.origin;
@ -9,7 +28,7 @@ try {
var ws = new WebSocket("ws://127.0.0.1:" + SOCKET_PORT + "/")
}
ws.onerror = function(evt) {
console.log(evt)
notifyError(evt)
};
ws.onclose = function(event) {
console.log("WebSocket is closed now.");

View File

@ -1,27 +0,0 @@
function resizeImage(base64Str) {
var img = new Image();
img.src = base64Str;
var canvas = document.createElement('canvas');
var MAX_WIDTH = 360;
var MAX_HEIGHT = 360;
var width = img.width;
var height = img.height;
if (width > height) {
if (width > MAX_WIDTH) {
height *= MAX_WIDTH / width;
width = MAX_WIDTH;
}
} else {
if (height > MAX_HEIGHT) {
width *= MAX_HEIGHT / height;
height = MAX_HEIGHT;
}
}
canvas.width = width;
canvas.height = height;
var ctx = canvas.getContext('2d');
ctx.drawImage(img, 0, 0, width, height);
return canvas.toDataURL();
}

View File

@ -22,7 +22,9 @@
</div>
<div class="panel">
<div id="output"></div>
<input type="button" class="flag" value="Flag"/>
</div>
</div>
</body>
</html>

Binary file not shown.

BIN
dist/gradio-0.3.3.tar.gz vendored Normal file

Binary file not shown.

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0
Name: gradio
Version: 0.3.2
Version: 0.3.3
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/abidlabs/gradio
Author: Abubakar Abid

View File

@ -13,28 +13,18 @@ gradio.egg-info/SOURCES.txt
gradio.egg-info/dependency_links.txt
gradio.egg-info/requires.txt
gradio.egg-info/top_level.txt
gradio/static/config.json
gradio/static/css/.DS_Store
gradio/static/css/gradio.css
gradio/static/css/style.css
gradio/static/img/logo.png
gradio/static/img/logo_inline.png
gradio/static/img/logo_mini.png
gradio/static/img/mic.png
gradio/static/img/webcam.png
gradio/static/js/all-io.js
gradio/static/js/audio-input.js
gradio/static/js/class-output.js
gradio/static/js/draw-a-digit.js
gradio/static/js/emotion-detector.js
gradio/static/js/image-upload-input.js
gradio/static/js/jquery-3.3.1.min.js
gradio/static/js/sketchpad-input.js
gradio/static/js/textbox-input.js
gradio/static/js/textbox-output.js
gradio/static/js/utils.js
gradio/static/js/webcam-input.js
gradio/templates/base_template.html
gradio/templates/image_upload_input.html
gradio/templates/label_output.html
test/test_inputs.py
test/test_interface.py
test/test_networking.py

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name='gradio',
version='0.3.2',
version='0.3.3',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',

View File

@ -16,7 +16,7 @@ class TestSketchpad(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Sketchpad()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 28, 28, 1))
self.assertEqual(array.shape, (1, 28, 28))
class TestWebcam(unittest.TestCase):