This commit is contained in:
Abubakar Abid 2019-03-06 12:02:37 -08:00
parent a1b694b7da
commit 92bd5f6af6
18 changed files with 1447 additions and 358 deletions

View File

@ -29,15 +29,8 @@
"outputs": [],
"source": [
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299)\n",
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=8)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=8, num_top_classes=5)\n",
"\n",
"iface = gradio.Interface(inputs=inp, \n",
" outputs=out,\n",
" model=model, \n",
@ -54,8 +47,8 @@
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7860/interface.html\n",
"Model available publicly for 8 hours at: https://efb97fa5.ngrok.io/interface.html\n"
"Model is running locally at: http://localhost:7861/interface.html\n",
"Model available publicly for 8 hours at: http://4d315e61.ngrok.io/interface.html\n"
]
},
{
@ -65,50 +58,63 @@
" <iframe\n",
" width=\"1000\"\n",
" height=\"500\"\n",
" src=\"http://localhost:7860/interface.html\"\n",
" src=\"http://localhost:7861/interface.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x118f6285eb8>"
"<IPython.lib.display.IFrame at 0x26985fbd940>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"('http://localhost:7860/interface.html', 'https://efb97fa5.ngrok.io')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"dddddddddddddddddddd cleaver,\n",
"dddddddddddddddddddd cleaver,\n"
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:01] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:02] code 404, message File not found\n",
"127.0.0.1 - - [06/Mar/2019 11:55:02] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:08] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:09] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:55:10] code 404, message File not found\n",
"127.0.0.1 - - [06/Mar/2019 11:55:10] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:07] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:07] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:07] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:07] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:07] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:31] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:43] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:44] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [06/Mar/2019 11:58:45] code 404, message File not found\n",
"127.0.0.1 - - [06/Mar/2019 11:58:45] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
]
}
],
"source": [
"iface.launch(inline=True, browser=True, share=True)"
"iface.launch(inline=True, browser=True, share=True);"
]
}
],

View File

@ -1 +1 @@
from gradio.interface import Interface # This makes Interface importable as gradio.Interface.
from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,15 @@
"""
This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
`AbstractInput`, and each class must define a path to its template. All of the subclasses of `AbstractInput` are
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
from abc import ABC, abstractmethod
import base64
from PIL import Image
from gradio import preprocessing_utils
from io import BytesIO
import numpy as np
from gradio import preprocessing_utils
from PIL import Image
class AbstractInput(ABC):
"""
@ -12,84 +18,115 @@ class AbstractInput(ABC):
"""
def __init__(self, preprocessing_fn=None):
"""
:param preprocessing_fn: an optional preprocessing function that overrides the default
"""
if preprocessing_fn is not None:
self._pre_process = preprocessing_fn
if not callable(preprocessing_fn):
raise ValueError('`preprocessing_fn` must be a callable function')
self.preprocess = preprocessing_fn
super().__init__()
@abstractmethod
def _get_template_path(self):
def get_template_path(self):
"""
All interfaces should define a method that returns the path to its template.
"""
pass
@abstractmethod
def _pre_process(self):
def preprocess(self, inp):
"""
All interfaces should define a method that returns the path to its template.
All interfaces should define a default preprocessing method
"""
pass
class Sketchpad(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28):
self.image_width = image_width
self.image_height = image_height
super().__init__(preprocessing_fn=preprocessing_fn)
def _get_template_path(self):
def get_template_path(self):
return 'templates/sketchpad_input.html'
def _pre_process(self, imgstring):
def preprocess(self, inp):
"""
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
"""
content = imgstring.split(';')[1]
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
body = base64.decodebytes(image_encoded.encode('utf-8'))
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (28, 28))
array = np.array(im).flatten().reshape(1, 28, 28, 1)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1)
return array
class Webcam(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def _get_template_path(self):
def get_template_path(self):
return 'templates/webcam_input.html'
def _pre_process(self, imgstring):
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
content = imgstring.split(';')[1]
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
body = base64.decodebytes(image_encoded.encode('utf-8'))
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (48, 48))
array = np.array(im).flatten().reshape(1, 48, 48, 1)
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array
class Textbox(AbstractInput):
def _get_template_path(self):
def get_template_path(self):
return 'templates/textbox_input.html'
def _pre_process(self, text):
def preprocess(self, inp):
"""
By default, no pre-processing is applied to text.
"""
return text
return inp
class ImageUpload(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB',
scale = 1/127.5, shift = -1):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
self.image_mode = image_mode
self.scale = scale
self.shift = shift
super().__init__(preprocessing_fn=preprocessing_fn)
def _get_template_path(self):
def get_template_path(self):
return 'templates/image_upload_input.html'
def _pre_process(self, imgstring):
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
content = imgstring.split(';')[1]
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
body = base64.decodebytes(image_encoded.encode('utf-8'))
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (48, 48))
array = np.array(im).flatten().reshape(1, 48, 48, 1)
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
im = np.array(im).flatten()
im = im * self.scale + self.shift
if self.num_channels is None:
array = im.reshape(1, self.image_width, self.image_height)
else:
array = im.reshape(1, self.image_width, self.image_height, self.num_channels)
return array
# Automatically adds all subclasses of AbstractInput into a dictionary (keyed by class name) for easy referencing.
registry = {cls.__name__.lower(): cls for cls in AbstractInput.__subclasses__()}

View File

@ -1,14 +1,15 @@
'''
This is the core file in the `gradio` package, and defines the Interface class, including methods for constructing the
interface using the input and output types.
'''
import asyncio
import websockets
import nest_asyncio
import webbrowser
import pkg_resources
from bs4 import BeautifulSoup
from gradio import inputs
from gradio import outputs
import gradio.inputs
import gradio.outputs
from gradio import networking
import os
import shutil
import tempfile
nest_asyncio.apply()
@ -17,46 +18,53 @@ LOCALHOST_IP = '127.0.0.1'
INITIAL_WEBSOCKET_PORT = 9200
TRY_NUM_PORTS = 100
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/')
CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/')
JS_PATH_TEMP = 'js/'
CSS_PATH_TEMP = 'css/'
TEMPLATE_TEMP = 'interface.html'
BASE_JS_FILE = 'js/all-io.js'
class Interface():
class Interface:
"""
The Interface class represents a general input/output interface for a machine learning model. During construction,
the appropriate inputs and outputs
"""
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
def __init__(self, input, output, model, model_type=None, preprocessing_fn=None, postprocessing_fn=None):
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
"""
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model_obj: the model object, such as a sklearn classifier or keras model.
:param model_params: additional model parameters.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
"""
self.input_interface = inputs.registry[input](preprocessing_fn)
self.output_interface = outputs.registry[output](postprocessing_fn)
if isinstance(inputs, str):
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
elif isinstance(inputs, gradio.inputs.AbstractInput):
self.input_interface = inputs
else:
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
if isinstance(outputs, str):
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
elif isinstance(outputs, gradio.outputs.AbstractOutput):
self.output_interface = outputs
else:
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
self.model_obj = model
if model_type is None:
model_type = self._infer_model_type(model)
if model_type is None:
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
else:
if verbose:
print("Model type not explicitly identified, inferred to be: {}".format(
self.VALID_MODEL_TYPES[model_type]))
self.VALID_MODEL_TYPES[model_type]))
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
self.model_type = model_type
self.verbose = verbose
def _infer_model_type(self, model):
if callable(model):
return 'function'
@staticmethod
def _infer_model_type(model):
""" Helper method that attempts to identify the type of trained ML model."""
try:
import sklearn
if isinstance(model, sklearn.base.BaseEstimator):
@ -78,124 +86,100 @@ class Interface():
except ImportError:
pass
return None
if callable(model):
return 'function'
def _build_template(self, temp_dir):
input_template_path = pkg_resources.resource_filename(
'gradio', self.input_interface._get_template_path())
output_template_path = pkg_resources.resource_filename(
'gradio', self.output_interface._get_template_path())
input_page = open(input_template_path)
output_page = open(output_template_path)
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
all_io_page = open(BASE_TEMPLATE)
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
input_tag = all_io_soup.find("div", {"id": "input"})
output_tag = all_io_soup.find("div", {"id": "output"})
input_tag.replace_with(input_soup)
output_tag.replace_with(output_soup)
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
f.write(str(all_io_soup.prettify))
self._copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
self._copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
return
def _copy_files(self, src_dir, dest_dir):
if not os.path.exists(dest_dir):
os.makedirs(dest_dir)
src_files = os.listdir(src_dir)
for file_name in src_files:
full_file_name = os.path.join(src_dir, file_name)
if os.path.isfile(full_file_name):
shutil.copy(full_file_name, dest_dir)
def _set_socket_url_in_js(self, temp_dir, socket_url):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
fout.write(line)
def _set_socket_port_in_js(self, temp_dir, socket_port):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
fout.write(line)
def predict(self, array):
if self.model_type=='sklearn':
return self.model_obj.predict(array)
elif self.model_type=='keras':
return self.model_obj.predict(array)
elif self.model_type=='function':
return self.model_obj(array)
else:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
async def communicate(self, websocket, path):
"""
Method that defines how this interface communicates with the websocket.
:param websocket: a Websocket object used to communicate with the interface frontend
:param path: ignored
Method that defines how this interface should communicates with the websocket. (1) When an input is received by
the websocket, it is passed into the input interface and preprocssed. (2) Then the model is called to make a
prediction. (3) Finally, the prediction is postprocessed to get something to be displayed by the output.
:param websocket: a Websocket server used to communicate with the interface frontend
:param path: not used, but required for compliance with websocket library
"""
while True:
try:
msg = await websocket.recv()
processed_input = self.input_interface._pre_process(msg)
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
processed_output = self.output_interface._post_process(prediction)
processed_output = self.output_interface.postprocess(prediction)
await websocket.send(str(processed_output))
except websockets.exceptions.ConnectionClosed:
pass
# except Exception as e:
# print(e)
def launch(self, share_link=False, verbose=True):
def predict(self, preprocessed_input):
"""
Standard method shared by interfaces that launches a websocket at a specified IP address.
Method that calls the relevant method of the model object to make a prediction.
:param preprocessed_input: the preprocessed input returned by the input interface
"""
if self.model_type=='sklearn':
return self.model_obj.predict(preprocessed_input)
elif self.model_type=='keras':
return self.model_obj.predict(preprocessed_input)
elif self.model_type=='function':
return self.model_obj(preprocessed_input)
else:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
def launch(self, inline=None, browser=None, share=False):
"""
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
:param share: boolean. If True, then a share link is generated using ngrok is displayed to the user.
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port = networking.start_simple_server(output_directory)
path_to_server = 'http://localhost:{}/'.format(server_port)
self._build_template(output_directory)
networking.build_template(output_directory, self.input_interface, self.output_interface)
ports_in_use = networking.get_ports_in_use(INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
for i in range(TRY_NUM_PORTS):
if not ((INITIAL_WEBSOCKET_PORT + i) in ports_in_use):
break
else:
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS))
start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i)
self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i)
if verbose:
# Set up a port to serve a websocket that sets up the communication between the front-end and model.
websocket_port = networking.get_first_available_port(
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port)
networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file.
if self.verbose:
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
print("Model is running locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP))
if share_link:
networking.kill_processes([4040, 4041])
site_ngrok_url = networking.setup_ngrok(server_port)
socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2)
self._set_socket_url_in_js(output_directory, socket_ngrok_url)
if verbose:
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
if share:
site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory)
if self.verbose:
print("Model available publicly for 8 hours at: {}".format(
site_ngrok_url + '/' + networking.TEMPLATE_TEMP))
else:
if verbose:
print("To create a public link, set `share_link=True` in the argument to `launch()`")
if self.verbose:
print("To create a public link, set `share=True` in the argument to `launch()`")
site_ngrok_url = None
# Keep the server running in the background.
asyncio.get_event_loop().run_until_complete(start_server)
try:
asyncio.get_event_loop().run_forever()
except RuntimeError: # Runtime errors are thrown in jupyter notebooks because of async.
pass
webbrowser.open(path_to_server + TEMPLATE_TEMP)
if inline is None:
try: # Check if running interactively using ipython.
_ = get_ipython()
inline = True
if browser is None:
browser = False
except NameError:
inline = False
if browser is None:
browser = True
else:
if browser is None:
browser = False
if browser:
webbrowser.open(path_to_server + networking.TEMPLATE_TEMP) # Open a browser tab with the interface.
if inline:
from IPython.display import IFrame
display(IFrame(path_to_server + networking.TEMPLATE_TEMP, width=1000, height=500))
return path_to_server + networking.TEMPLATE_TEMP, site_ngrok_url

View File

@ -1,3 +1,7 @@
'''
Defines helper methods useful for setting up ports, launching servers, and handling `ngrok`
'''
import subprocess
import requests
import zipfile
@ -5,21 +9,31 @@ import io
import sys
import os
import socket
from psutil import process_iter, AccessDenied
from psutil import process_iter, AccessDenied, NoSuchProcess
from signal import SIGTERM # or SIGKILL
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import stat
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import pkg_resources
from bs4 import BeautifulSoup
from distutils import dir_util
INITIAL_PORT_VALUE = 7860
TRY_NUM_PORTS = 100
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
LOCALHOST_NAME = 'localhost'
LOCALHOST_PREFIX = 'localhost:'
NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output)
NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output)
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
STATIC_PATH_LIB = pkg_resources.resource_filename('gradio', 'static/')
STATIC_PATH_TEMP = 'static/'
TEMPLATE_TEMP = 'interface.html'
BASE_JS_FILE = 'static/js/all-io.js'
NGROK_ZIP_URLS = {
"linux": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip",
"darwin": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-darwin-amd64.zip",
@ -27,48 +41,83 @@ NGROK_ZIP_URLS = {
}
def get_ports_in_use(start, stop):
ports_in_use = []
for port in range(start, stop):
def build_template(temp_dir, input_interface, output_interface):
"""
Builds a complete HTML template with supporting JS and CSS files in a given directory.
:param temp_dir: string with path to temp directory in which the template should be built
:param input_interface: an AbstractInput object which includes is used to get the input template
:param output_interface: an AbstractInput object which includes is used to get the input template
"""
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
input_page = open(input_template_path)
output_page = open(output_template_path)
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
all_io_page = open(BASE_TEMPLATE)
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
input_tag = all_io_soup.find("div", {"id": "input"})
output_tag = all_io_soup.find("div", {"id": "output"})
input_tag.replace_with(input_soup)
output_tag.replace_with(output_soup)
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
def copy_files(src_dir, dest_dir):
"""
Copies all the files from one directory to another
:param src_dir: string path to source directory
:param dest_dir: string path to destination directory
"""
dir_util.copy_tree(src_dir, dest_dir)
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_socket_url_in_js(temp_dir, socket_url):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
fout.write(line)
def set_socket_port_in_js(temp_dir, socket_port):
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
lines = fin.readlines()
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
for line in lines:
fout.write(line)
def get_first_available_port(initial, final):
"""
Gets the first open port in a specified range of port numbers
:param initial: the initial value in the range of port numbers
:param final: final (exclusive) value in the range of port numbers, should be greater than `initial`
:return:
"""
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.bind((LOCALHOST_NAME, port)) # Bind to the port
s.close()
return port
except OSError:
ports_in_use.append(port)
return ports_in_use
# ports_in_use = []
# try:
# for proc in process_iter():
# for conns in proc.connections(kind='inet'):
# ports_in_use.append(conns.laddr.port)
# except AccessDenied:
# pass # TODO(abidlabs): somehow find a way to handle this issue?
# return ports_in_use
pass
raise OSError("All ports from {} to {} are in use. Please close a port.".format(initial, final))
def serve_files_in_background(port, directory_to_serve=None):
# class Handler(http.server.SimpleHTTPRequestHandler):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, directory=directory_to_serve, **kwargs)
#
# server = socketserver.ThreadingTCPServer(('localhost', port), Handler)
# # Ensures that Ctrl-C cleanly kills all spawned threads
# server.daemon_threads = True
# # Quicker rebinding
# server.allow_reuse_address = True
#
# # A custom signal handle to allow us to Ctrl-C out of the process
# def signal_handler(signal, frame):
# print('Exiting http server (Ctrl+C pressed)')
# try:
# if (server):
# server.server_close()
# finally:
# sys.exit(0)
#
# # Install the keyboard interrupt handler
# signal.signal(signal.SIGINT, signal_handler)
class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()"""
@ -101,21 +150,9 @@ def serve_files_in_background(port, directory_to_serve=None):
def start_simple_server(directory_to_serve=None):
# TODO(abidlabs): increment port number until free port is found
ports_in_use = get_ports_in_use(start=INITIAL_PORT_VALUE, stop=INITIAL_PORT_VALUE + TRY_NUM_PORTS)
for i in range(TRY_NUM_PORTS):
if not((INITIAL_PORT_VALUE + i) in ports_in_use):
break
else:
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS))
serve_files_in_background(INITIAL_PORT_VALUE + i, directory_to_serve)
# if directory_to_serve is None:
# subprocess.Popen(['python', '-m', 'http.server', str(INITIAL_PORT_VALUE + i)])
# else:
# cmd = ' '.join(['python', '-m', 'http.server', '-d', directory_to_serve, str(INITIAL_PORT_VALUE + i)])
# subprocess.Popen(cmd, shell=True) # Doesn't seem to work if list is passed for some reason.
return INITIAL_PORT_VALUE + i
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
serve_files_in_background(port, directory_to_serve)
return port
def download_ngrok():
@ -133,7 +170,7 @@ def download_ngrok():
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL):
def create_ngrok_tunnel(local_port, api_url):
if not(os.path.isfile('ngrok.exe') or os.path.isfile('ngrok')):
download_ngrok()
if sys.platform == 'win32':
@ -147,18 +184,26 @@ def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL):
session.mount('https://', adapter)
r = session.get(api_url)
for tunnel in r.json()['tunnels']:
if LOCALHOST_PREFIX + str(local_port) in tunnel['config']['addr']:
if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr']:
return tunnel['public_url']
raise RuntimeError("Not able to retrieve ngrok public URL")
def kill_processes(process_ids):
def setup_ngrok(server_port, websocket_port, output_directory):
kill_processes([4040, 4041]) #TODO(abidlabs): better way to do this
site_ngrok_url = create_ngrok_tunnel(server_port, NGROK_TUNNELS_API_URL)
socket_ngrok_url = create_ngrok_tunnel(websocket_port, NGROK_TUNNELS_API_URL2)
set_socket_url_in_js(output_directory, socket_ngrok_url)
return site_ngrok_url
def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill
for proc in process_iter():
try:
for conns in proc.connections(kind='inet'):
if conns.laddr.port in process_ids:
proc.send_signal(SIGTERM) # or SIGKILL
except AccessDenied:
except (AccessDenied, NoSuchProcess):
pass

View File

@ -1,5 +1,13 @@
"""
This module defines various classes that can serve as the `output` to an interface. Each class must inherit from
`AbstractOutput`, and each class must define a path to its template. All of the subclasses of `AbstractOutput` are
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
from abc import ABC, abstractmethod
import numpy as np
import json
from gradio import imagenet_class_labels
class AbstractOutput(ABC):
"""
@ -9,52 +17,86 @@ class AbstractOutput(ABC):
def __init__(self, postprocessing_fn=None):
"""
:param postprocessing_fn: an optional postprocessing function that overrides the default
"""
if postprocessing_fn is not None:
self._post_process = postprocessing_fn
self.postprocess = postprocessing_fn
super().__init__()
@abstractmethod
def _get_template_path(self):
def get_template_path(self):
"""
All interfaces should define a method that returns the path to its template.
"""
pass
@abstractmethod
def _post_process(self):
def postprocess(self, prediction):
"""
All interfaces should define a method that returns the path to its template.
All interfaces should define a default postprocessing method
"""
pass
class Class(AbstractOutput):
class Label(AbstractOutput):
LABEL_KEY = 'label'
CONFIDENCES_KEY = 'confidences'
CONFIDENCE_KEY = 'confidence'
def _get_template_path(self):
return 'templates/class_output.html'
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
             max_label_length=None):
    """
    :param postprocessing_fn: optional override for the default postprocessing
    :param num_top_classes: how many of the highest-scoring classes to report
    :param show_confidences: whether to include per-class confidence values
    :param label_names: None, the string 'imagenet1000', or a list/dict
        mapping raw labels to display names
    :param max_label_length: if set, display names are truncated to this length
    """
    # Record display configuration, then delegate to the base class.
    self.show_confidences = show_confidences
    self.num_top_classes = num_top_classes
    self.max_label_length = max_label_length
    self.label_names = label_names
    super().__init__(postprocessing_fn=postprocessing_fn)
def _post_process(self, prediction):
def get_label_name(self, label):
    """Map a raw label to its display name, truncated if configured."""
    if self.label_names is None:
        display_name = label
    elif self.label_names == 'imagenet1000':
        # Built-in lookup table for the standard 1000 ImageNet classes.
        display_name = imagenet_class_labels.NAMES1000[label]
    else:  # a user-supplied list or dictionary
        display_name = self.label_names[label]
    if self.max_label_length is None:
        return display_name
    return display_name[:self.max_label_length]
def get_template_path(self):
    # HTML template used to render this output component in the interface.
    return 'templates/label_output.html'
def postprocess(self, prediction):
    """Convert a model prediction into the JSON payload the Label UI expects.

    :param prediction: a numpy array of class scores (possibly a single
        value), or a plain string label.
    :return: a JSON string with a 'label' key and, when confidences are
        enabled for a 1-D score vector, a 'confidences' list of
        {label, confidence} entries ordered best-first.
    :raises ValueError: if the prediction is neither an ndarray nor a string.
    """
    response = dict()
    # TODO(abidlabs): check if list, if so convert to numpy array
    if isinstance(prediction, np.ndarray):
        prediction = prediction.squeeze()
        if prediction.size == 1:  # if it's single value
            # .item() replaces np.asscalar, which was removed from numpy.
            response[Label.LABEL_KEY] = self.get_label_name(prediction.item())
        elif len(prediction.shape) == 1:  # if a 1D vector of class scores
            response[Label.LABEL_KEY] = self.get_label_name(int(prediction.argmax()))
            if self.show_confidences:
                # Work on a copy: repeatedly zeroing the running maximum
                # must not mutate the caller's array.
                scores = prediction.copy()
                response[Label.CONFIDENCES_KEY] = []
                # Clamp so we never report more classes than exist.
                for _ in range(min(self.num_top_classes, scores.size)):
                    response[Label.CONFIDENCES_KEY].append({
                        Label.LABEL_KEY: self.get_label_name(int(scores.argmax())),
                        Label.CONFIDENCE_KEY: float(scores.max()),
                    })
                    scores[scores.argmax()] = 0
        # NOTE(review): arrays with >1 dims after squeeze fall through and
        # yield an empty response, matching the original behavior.
    elif isinstance(prediction, str):
        response[Label.LABEL_KEY] = prediction
    else:
        raise ValueError("Unable to post-process model prediction.")
    return json.dumps(response)
class Textbox(AbstractOutput):
    """Output interface that displays the prediction as plain text."""

    def get_template_path(self):
        # HTML template used to render this output component in the interface.
        return 'templates/textbox_output.html'

    def postprocess(self, prediction):
        """Return the prediction unchanged; the textbox renders raw text."""
        return prediction

View File

@ -0,0 +1,28 @@
<html lang="en">
<head>
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<title>Gradio</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<link rel="stylesheet" href="../static/css/style.css">
<link rel="stylesheet" href="../static/css/gradio.css">
<script src="../static/js/utils.js"></script>
<script src="../static/js/all-io.js"></script>
<script src="https://code.jquery.com/jquery-3.2.1.slim.min.js" integrity="sha384-KJ3o2DKtIkvYIK3UENzmM7KCkRr/rE9/Qpg6aAZGJwFDMVNA/GpGFF93hXpG5KkN" crossorigin="anonymous"></script>
</head>
<body>
<nav>
<img src="../static/img/logo_inline.png" />
</nav>
<div id="panels">
<div class="panel">
<div id="input"></div>
<input class="submit" type="submit" value="Submit"/><!--DO NOT DELETE
--><input class="clear" type="reset" value="Clear">
</div>
<div class="panel">
<div id="output"></div>
</div>
</div>
</body>
</html>

View File

@ -1,31 +1,11 @@
<!-- <link rel="stylesheet" href="../css/dropzone.css">
<div class="col-6">
<h5>Image Upload Input:</h5>
<input type="file" onchange="previewFile()"><br>
<img src="" height="200" alt="Image preview...">
<div class="btn-group" role="group" aria-label="Basic example">
<button type="button" class="btn btn-primary" id="submit-button">Submit</button>
<button type="button" class="btn btn-secondary" id="clear-button">Clear</button>
<div class="gradio input image_file">
<div class="role">Input</div>
<div class="input_image drop_mode">
<div class="input_caption">Drop Image Here<br>- or -<br>Click to Upload</div>
<img />
</div>
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
</div>
<script src="../js/image-upload-input.js"></script>
-->
<div class="col-md-6">
<div> <h5>Image Upload Input:</h5> </div>
<div class="uploader" style="text-align: center; vertical-align: middle" onclick="$('#filePhoto').click()">
<br>Click here <br> or <br> drag and drop <br>an image
<img/>
<input type="file" name="userprofile_picture" id="filePhoto" />
</div>
<div class="btn-group" role="group" aria-label="Basic example">
<button type="button" class="btn btn-primary" id="submit-button">Submit</button>
<button type="button" class="btn btn-secondary" id="clear-button">Clear</button>
</div>
</div>
<script src="../js/image-upload-input.js"></script>
<link rel="stylesheet" href="https://fengyuanchen.github.io/cropper/css/cropper.css">
<script src="https://fengyuanchen.github.io/cropper/js/cropper.js"></script>
<script src="../static/js/image-upload-input.js"></script>

View File

@ -0,0 +1,7 @@
<div class="gradio output classifier">
<div class="role">Output</div>
<div class="output_class"></div>
<div class="confidence_intervals">
</div>
</div>
<script src="../static/js/class-output.js"></script>

BIN
dist/gradio-0.3.1.tar.gz vendored Normal file

Binary file not shown.

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0
Name: gradio
Version: 0.2.1
Version: 0.3.1
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/abidlabs/gradio
Author: Abubakar Abid

View File

@ -2,6 +2,7 @@ MANIFEST.in
README.md
setup.py
gradio/__init__.py
gradio/imagenet_class_labels.py
gradio/inputs.py
gradio/interface.py
gradio/networking.py
@ -12,50 +13,10 @@ gradio.egg-info/SOURCES.txt
gradio.egg-info/dependency_links.txt
gradio.egg-info/requires.txt
gradio.egg-info/top_level.txt
gradio/css/.DS_Store
gradio/css/bootstrap-grid.css
gradio/css/bootstrap-grid.css.map
gradio/css/bootstrap-grid.min.css
gradio/css/bootstrap-grid.min.css.map
gradio/css/bootstrap-reboot.css
gradio/css/bootstrap-reboot.css.map
gradio/css/bootstrap-reboot.min.css
gradio/css/bootstrap-reboot.min.css.map
gradio/css/bootstrap.css
gradio/css/bootstrap.css.map
gradio/css/bootstrap.min.css
gradio/css/bootstrap.min.css.map
gradio/css/draw-a-digit.css
gradio/css/dropzone.css
gradio/css/index.css
gradio/js/all-io.js
gradio/js/audio-input.js
gradio/js/bootstrap-notify.min.js
gradio/js/bootstrap.bundle.js
gradio/js/bootstrap.bundle.js.map
gradio/js/bootstrap.bundle.min.js
gradio/js/bootstrap.bundle.min.js.map
gradio/js/bootstrap.js
gradio/js/bootstrap.js.map
gradio/js/bootstrap.min.js
gradio/js/bootstrap.min.js.map
gradio/js/class-output.js
gradio/js/draw-a-digit.js
gradio/js/dropzone.js
gradio/js/emotion-detector.js
gradio/js/image-upload-input.js
gradio/js/jquery-3.3.1.min.js
gradio/js/sketchpad-input.js
gradio/js/textbox-input.js
gradio/js/textbox-output.js
gradio/js/webcam-input.js
gradio/templates/all_io.html
gradio/templates/audio_input.html
gradio/templates/class_output.html
gradio/templates/draw_a_digit.html
gradio/templates/emotion_detector.html
gradio/templates/base_template.html
gradio/templates/image_upload_input.html
gradio/templates/sketchpad_input.html
gradio/templates/textbox_input.html
gradio/templates/textbox_output.html
gradio/templates/webcam_input.html
gradio/templates/label_output.html
test/test_inputs.py
test/test_interface.py
test/test_networking.py
test/test_outputs.py

View File

@ -69,7 +69,6 @@ class Label(AbstractOutput):
"""
"""
response = dict()
print('dddddddddddddddddddd', self.get_label_name(499))
# TODO(abidlabs): check if list, if so convert to numpy array
if isinstance(prediction, np.ndarray):
prediction = prediction.squeeze()

View File

@ -27,19 +27,19 @@ try {
sleep(300).then(() => {
// $(".output_class").text(event.data);
var data = JSON.parse(event.data)
data = {
label: "happy",
confidences : [
{
label : "happy",
confidence: 0.7
},
{
label : "sad",
confidence: 0.3
},
]
}
// data = {
// label: "happy",
// confidences : [
// {
// label : "happy",
// confidence: 0.7
// },
// {
// label : "sad",
// confidence: 0.3
// },
// ]
// }
$(".output_class").text(data["label"])
$(".confidence_intervals").empty()
if ("confidences" in data) {

View File

@ -1,4 +1,4 @@
var cropper;
// var cropper;
$('body').on('click', ".input_image.drop_mode", function (e) {
$(this).parent().find(".hidden_upload").click();
@ -17,10 +17,10 @@ function loadPreviewFromFiles(files) {
$(".input_image").removeClass("drop_mode")
var image = $(".input_image img")
image.attr("src", this.result)
image.cropper({aspectRatio : 1.0});
if (!cropper) {
cropper = image.data('cropper');
}
// image.cropper({aspectRatio : 1.0});
// if (!cropper) {
// cropper = image.data('cropper');
// }
}
}
@ -50,10 +50,10 @@ $('body').on('click', '.submit', function(e) {
})
$('body').on('click', '.clear', function(e) {
if (cropper) {
cropper.destroy();
cropper = null
}
// if (cropper) {
// cropper.destroy();
// cropper = null
// }
$(".input_caption").show()
$(".input_image img").removeAttr("src");
$(".input_image").addClass("drop_mode")

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name='gradio',
version='0.2.1',
version='0.3.1',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',