This commit is contained in:
Abubakar Abid 2019-03-25 12:06:56 -07:00
parent ef5ac8a2ad
commit aee18085b1
53 changed files with 16587 additions and 89 deletions

View File

@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@ -25,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -34,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -49,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 4,
"metadata": {
"scrolled": false
},
@ -59,7 +50,7 @@
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7899/interface.html\n",
"Model is running locally at: http://localhost:7860/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
}

View File

@ -6,11 +6,16 @@ automatically added to a registry, which allows them to be easily referenced in
from abc import ABC, abstractmethod
import base64
from gradio import preprocessing_utils
from gradio import preprocessing_utils, validation_data
from io import BytesIO
import numpy as np
from PIL import Image, ImageOps
# Where to find the static resources associated with each template.
BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'templates/input/{}.html'
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
class AbstractInput(ABC):
"""
An abstract class for defining the methods that all gradio inputs should have.
@ -27,10 +32,29 @@ class AbstractInput(ABC):
self.preprocess = preprocessing_fn
super().__init__()
@abstractmethod
def get_template_path(self):
def get_validation_inputs(self):
"""
All interfaces should define a method that returns the path to its template.
An interface can optionally implement a method that returns a list of examples inputs that it should be able to
accept and preprocess for validation purposes.
"""
return []
def get_js_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
@abstractmethod
def get_name(self):
"""
All interfaces should define a method that returns a name used for identifying the related static resources.
"""
pass
@ -50,8 +74,8 @@ class Sketchpad(AbstractInput):
self.invert_colors = invert_colors
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/input/sketchpad.html'
def get_name(self):
return 'sketchpad'
def preprocess(self, inp):
"""
@ -74,8 +98,11 @@ class Webcam(AbstractInput):
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/input/webcam.html'
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_name(self):
return 'webcam'
def preprocess(self, inp):
"""
@ -90,9 +117,11 @@ class Webcam(AbstractInput):
class Textbox(AbstractInput):
def get_validation_inputs(self):
return validation_data.ENGLISH_TEXTS
def get_template_path(self):
return 'templates/input/textbox.html'
def get_name(self):
return 'textbox'
def preprocess(self, inp):
"""
@ -103,17 +132,24 @@ class Textbox(AbstractInput):
class ImageUpload(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB',
scale = 1/127.5, shift = -1):
scale=1/127.5, shift=-1, aspect_ratio="false"):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
self.image_mode = image_mode
self.scale = scale
self.shift = shift
self.aspect_ratio = aspect_ratio
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/input/image_upload.html'
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_name(self):
return 'image_upload'
def get_js_context(self):
return {'aspect_ratio': self.aspect_ratio}
def preprocess(self, inp):
"""
@ -132,5 +168,18 @@ class ImageUpload(AbstractInput):
return array
class CSV(AbstractInput):
def get_name(self):
# return 'templates/input/csv.html'
return 'csv'
def preprocess(self, inp):
"""
By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this)
"""
return inp
# Automatically adds all subclasses of AbstractInput into a dictionary (keyed by class name) for easy referencing.
registry = {cls.__name__.lower(): cls for cls in AbstractInput.__subclasses__()}

View File

@ -12,6 +12,7 @@ import gradio.outputs
from gradio import networking
import tempfile
import threading
import traceback
nest_asyncio.apply()
@ -29,13 +30,14 @@ class Interface:
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function',
'pytorch': 'PyTorch model'}
STATUS_TYPES = {'OFF': 'off', 'RUNNING': 'running'}
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
"""
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model_obj: the model object, such as a sklearn classifier or keras model.
:param model: the model object, such as a sklearn classifier or keras model.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
@ -63,6 +65,9 @@ class Interface:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
self.model_type = model_type
self.verbose = verbose
self.status = self.STATUS_TYPES['OFF']
self.validate_flag = False
self.simple_server = None
@staticmethod
def _infer_model_type(model):
@ -133,16 +138,72 @@ class Interface:
else:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
def launch(self, inline=None, browser=None, share=False):
def validate(self):
if self.validate_flag:
if self.verbose:
print("Interface already validated")
return
validation_inputs = self.input_interface.get_validation_inputs()
n = len(validation_inputs)
if n == 0:
self.validate_flag = True
if self.verbose:
print("No validation samples for this interface... skipping validation.")
return
for m, msg in enumerate(validation_inputs):
if self.verbose:
print(f"Validating samples: {m+1}/{n} [" + "="*(m+1) + "."*(n-m-1) + "]", end='\r')
try:
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
except Exception as e:
if self.verbose:
print("\n----------")
print("Validation failed, likely due to incompatible pre-processing and model input. See below:\n")
print(traceback.format_exc())
break
try:
_ = self.output_interface.postprocess(prediction)
except Exception as e:
if self.verbose:
print("\n----------")
print("Validation failed, likely due to incompatible model output and post-processing."
"See below:\n")
print(traceback.format_exc())
break
else: # This means if a break was not explicitly called
self.validate_flag = True
if self.verbose:
print("\n\nValidation passed successfully!")
return
raise RuntimeError("Validation did not pass")
def launch(self, inline=None, inbrowser=None, share=False, validate=True):
"""
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
:param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook)
:param inbrowser: boolean. If True, then a new browser window opens with the gradio interface.
:param share: boolean. If True, then a share link is generated using ngrok is displayed to the user.
:param validate: boolean. If True, then the validation is run if the interface has not already been validated.
"""
output_directory = tempfile.mkdtemp()
if validate and not self.validate_flag:
self.validate()
# If an existing interface is running with this instance, close it.
if self.status == self.STATUS_TYPES['RUNNING']:
if self.verbose:
print("Closing existing server...")
if self.simple_server is not None:
try:
networking.close_server(self.simple_server)
except OSError:
pass
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(output_directory)
path_to_server = 'http://localhost:{}/'.format(server_port)
path_to_local_server = 'http://localhost:{}/'.format(server_port)
path_to_local_interface_page = path_to_local_server + networking.TEMPLATE_TEMP
networking.build_template(output_directory, self.input_interface, self.output_interface)
# Set up a port to serve a websocket that sets up the communication between the front-end and model.
@ -153,6 +214,9 @@ class Interface:
networking.set_interface_types_in_config_file(output_directory,
self.input_interface.__class__.__name__.lower(),
self.output_interface.__class__.__name__.lower())
self.status = self.STATUS_TYPES['RUNNING']
self.simple_server = httpd
is_colab = False
try: # Check if running interactively using ipython.
from_ipynb = get_ipython()
@ -164,24 +228,27 @@ class Interface:
if self.verbose:
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
if not is_colab:
print("Model is running locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP))
print(f"Model is running locally at: {path_to_local_interface_page}")
if share:
try:
site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory)
path_to_ngrok_server = networking.setup_ngrok(server_port, websocket_port, output_directory)
path_to_ngrok_interface_page = path_to_ngrok_server + '/' + networking.TEMPLATE_TEMP
if self.verbose:
print("Model available publicly for 8 hours at: {}".format(
site_ngrok_url + '/' + networking.TEMPLATE_TEMP))
print(f"Model available publicly for 8 hours at: {path_to_ngrok_interface_page}")
except RuntimeError:
site_ngrok_url = None
path_to_ngrok_server = None
if self.verbose:
print("Unable to create public link for interface, please check internet connection.")
else:
if self.verbose:
print("To create a public link, set `share=True` in the argument to `launch()`")
site_ngrok_url = None
if is_colab:
site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory)
path_to_ngrok_server = None
if is_colab: # for a colab notebook, create a public link even if share is False.
path_to_ngrok_server = networking.setup_ngrok(server_port, websocket_port, output_directory)
path_to_ngrok_interface_page = path_to_ngrok_server + '/' + networking.TEMPLATE_TEMP
print(f"Cannot display local interface on google colab, public link created at:"
f"{path_to_ngrok_interface_page} and displayed below.")
# Keep the server running in the background.
asyncio.get_event_loop().run_until_complete(start_server)
try:
@ -192,26 +259,24 @@ class Interface:
if inline is None:
try: # Check if running interactively using ipython.
from_ipynb = get_ipython()
get_ipython()
inline = True
if browser is None:
browser = False
if inbrowser is None:
inbrowser = False
except NameError:
inline = False
if browser is None:
browser = True
if inbrowser is None:
inbrowser = True
else:
if browser is None:
browser = False
if browser and not is_colab:
webbrowser.open(path_to_server + networking.TEMPLATE_TEMP) # Open a browser tab with the interface.
if inbrowser is None:
inbrowser = False
if inbrowser and not is_colab:
webbrowser.open(path_to_local_interface_page) # Open a browser tab with the interface.
if inline:
from IPython.display import IFrame
if is_colab:
print("Cannot display Interface inline on google colab, public link created at: {} and displayed below.".format(
site_ngrok_url + '/' + networking.TEMPLATE_TEMP))
display(IFrame(site_ngrok_url + '/' + networking.TEMPLATE_TEMP, width=1000, height=500))
if is_colab: # Embed the remote interface page if on google colab; otherwise, embed the local page.
display(IFrame(path_to_ngrok_interface_page, width=1000, height=500))
else:
display(IFrame(path_to_server + networking.TEMPLATE_TEMP, width=1000, height=500))
display(IFrame(path_to_local_interface_page, width=1000, height=500))
return httpd, path_to_server + networking.TEMPLATE_TEMP, site_ngrok_url
return httpd, path_to_local_server, path_to_ngrok_server

View File

@ -14,12 +14,12 @@ from signal import SIGTERM # or SIGKILL
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import stat
import time
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import pkg_resources
from bs4 import BeautifulSoup
from distutils import dir_util
from gradio import inputs, outputs
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
@ -50,12 +50,17 @@ def build_template(temp_dir, input_interface, output_interface):
:param input_interface: an AbstractInput object which includes is used to get the input template
:param output_interface: an AbstractInput object which includes is used to get the input template
"""
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
input_template_path = pkg_resources.resource_filename(
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
output_template_path = pkg_resources.resource_filename(
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
input_page = open(input_template_path)
output_page = open(output_template_path)
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
input_soup = BeautifulSoup(render_string_or_list_with_tags(
input_page.read(), input_interface.get_template_context()), features="html.parser")
output_soup = BeautifulSoup(
render_string_or_list_with_tags(
output_page.read(), output_interface.get_template_context()), features="html.parser")
all_io_page = open(BASE_TEMPLATE)
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
@ -69,6 +74,12 @@ def build_template(temp_dir, input_interface, output_interface):
f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
render_template_with_tags(os.path.join(
temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
input_interface.get_js_context())
render_template_with_tags(os.path.join(
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
output_interface.get_js_context())
def copy_files(src_dir, dest_dir):
@ -89,16 +100,28 @@ def render_template_with_tags(template_path, context):
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', value)
new_lines.append(line)
new_lines = render_string_or_list_with_tags(old_lines, context)
with open(template_path, 'w') as fout:
for line in new_lines:
fout.write(line)
def render_string_or_list_with_tags(old_lines, context):
# Handle string case
if isinstance(old_lines, str):
for key, value in context.items():
old_lines = old_lines.replace(r'{{' + key + r'}}', str(value))
return old_lines
# Handle list case
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', str(value))
new_lines.append(line)
return new_lines
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
@ -157,7 +180,6 @@ def serve_files_in_background(port, directory_to_serve=None):
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
httpd = HTTPServer(directory_to_serve, (LOCALHOST_NAME, port))
# Now loop forever
@ -166,7 +188,7 @@ def serve_files_in_background(port, directory_to_serve=None):
while True:
# sys.stdout.flush()
httpd.serve_forever()
except KeyboardInterrupt:
except (KeyboardInterrupt, OSError):
httpd.server_close()
thread = threading.Thread(target=serve_forever, daemon=True)
@ -181,6 +203,11 @@ def start_simple_server(directory_to_serve=None):
return port, httpd
def close_server(server):
server.shutdown()
server.server_close()
def download_ngrok():
try:
zip_file_url = NGROK_ZIP_URLS[sys.platform]

View File

@ -9,6 +9,11 @@ import numpy as np
import json
from gradio import imagenet_class_labels
# Where to find the static resources associated with each template.
BASE_OUTPUT_INTERFACE_TEMPLATE_PATH = 'templates/output/{}.html'
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
class AbstractOutput(ABC):
"""
An abstract class for defining the methods that all gradio inputs should have.
@ -23,10 +28,22 @@ class AbstractOutput(ABC):
self.postprocess = postprocessing_fn
super().__init__()
@abstractmethod
def get_template_path(self):
def get_js_context(self):
"""
All interfaces should define a method that returns the path to its template.
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
@abstractmethod
def get_name(self):
"""
All outputs should define a method that returns a name used for identifying the related static resources.
"""
pass
@ -51,10 +68,13 @@ class Label(AbstractOutput):
self.max_label_length = max_label_length
super().__init__(postprocessing_fn=postprocessing_fn)
def get_name(self):
return 'label'
def get_label_name(self, label):
if self.label_names is None:
name = label
elif self.label_names == 'imagenet1000':
elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
name = imagenet_class_labels.NAMES1000[label]
else: # if list or dictionary
name = self.label_names[label]
@ -62,9 +82,6 @@ class Label(AbstractOutput):
name = name[:self.max_label_length]
return name
def get_template_path(self):
return 'templates/output/label.html'
def postprocess(self, prediction):
"""
"""
@ -93,8 +110,19 @@ class Label(AbstractOutput):
class Textbox(AbstractOutput):
def get_template_path(self):
return 'templates/output/textbox.html'
def get_name(self):
return 'textbox'
def postprocess(self, prediction):
"""
"""
return prediction
class Image(AbstractOutput):
def get_name(self):
return 'image'
def postprocess(self, prediction):
"""

File diff suppressed because one or more lines are too long

View File

@ -21,6 +21,16 @@
display: flex;
flex-flow: column;
}
.loading {
margin-left: auto;
}
.loading img {
display: none;
height: 22px;
}
.panel_head {
display: flex
}
.role {
text-transform: uppercase;
font-family: Arial;
@ -70,15 +80,15 @@
height: 100%;
padding: 0;
}
.input_image, .input_audio, .input_snapshot, .input_mic, .output_class,
.output_image {
.input_image, .input_audio, .input_snapshot, .input_mic, .input_csv, .output_class,
.output_image, .output_csv {
flex: 1 1 auto;
display: flex;
justify-content: center;
align-items: center;
text-align: center;
}
.input_image, .input_audio, .input_snapshot, .input_mic {
.input_image, .input_audio, .input_snapshot, .input_mic, .input_csv {
font-weight: bold;
font-size: 24px;
color: #BBB;
@ -182,6 +192,41 @@ canvas {
border: solid 1px black;
}
textarea {
.text textarea {
resize: none;
background-color: white;
border: none;
box-sizing: border-box;
padding: 4px;
}
.output_image img {
display: none
}
.table_holder {
max-width: 100%;
max-height: 100%;
overflow: scroll;
display: none;
}
.csv_preview {
background-color: white;
max-width: 100%;
max-height: 100%;
font-size: 12px;
font-family: monospace;
}
.csv_preview tr {
border-bottom: solid 1px black;
}
.csv_preview tr.header td {
background-color: #EEA45D;
font-weight: bold;
}
.csv_preview td {
padding: 2px 4px;
}
.csv_preview td:nth-child(even) {
background-color: #EEEEEE;
}

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@ -45,3 +45,24 @@ try {
const sleep = (milliseconds) => {
return new Promise(resolve => setTimeout(resolve, milliseconds))
}
$('body').on('click', '.flag', function(e) {
if ($(".flag").hasClass("flagged")) {
$(".flag").removeClass("flagged").attr("value", "flag");
} else {
$(".flag").addClass("flagged").attr("value", "flagged");
}
})
var start_time;
function loadStart() {
$(".loading img").show();
$(".loading_time").text("");
start_time = new Date().getTime()
}
function loadEnd() {
$(".loading img").hide();
end_time = new Date().getTime()
$(".loading_time").text(((end_time - start_time) / 1000).toFixed(2) + "s");
}

File diff suppressed because one or more lines are too long

View File

@ -2,12 +2,14 @@
<head>
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<title>Gradio</title>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<link rel="stylesheet" href="../static/css/font-awesome.min.css">
<link rel="stylesheet" href="../static/css/style.css">
<link rel="stylesheet" href="../static/css/gradio.css">
<link rel="stylesheet" type="text/css" href="../static/css/loading.css"/>
<link rel="shortcut icon" type="image/png" href="../static/img/logo_mini.png"/>
<script src="../static/js/jquery-3.2.1.slim.min.js"></script>
<script src="../static/js/utils.js"></script>
<script src="../static/js/all-io.js"></script>
<script src="https://code.jquery.com/jquery-3.2.1.slim.min.js" integrity="sha384-KJ3o2DKtIkvYIK3UENzmM7KCkRr/rE9/Qpg6aAZGJwFDMVNA/GpGFF93hXpG5KkN" crossorigin="anonymous"></script>
</head>
<body>

View File

@ -0,0 +1,11 @@
<div class="gradio input csv">
<div class="role">Input</div>
<div class="input_csv drop_mode">
<div class="input_caption">Drop CSV File Here<br>- or -<br>Click to Upload</div>
</div>
<div class="table_holder"><table class="csv_preview"></table></div>
<input class="hidden_upload" type="file" accept=".csv" />
</div>
<script src="../static/js/vendor/papaparse.min.js"></script>
<script src="../static/js/interfaces/input/csv.js"></script>

View File

@ -7,6 +7,6 @@
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
</div>
<link rel="stylesheet" href="https://fengyuanchen.github.io/cropper/css/cropper.css">
<script src="https://fengyuanchen.github.io/cropper/js/cropper.js"></script>
<link rel="stylesheet" href="../static/css/vendor/cropper.css">
<script src="../static/js/vendor/cropper.js"></script>
<script src="../static/js/interfaces/input/image_upload.js"></script>

View File

@ -10,5 +10,5 @@
</div>
</div>
<script src="http://yiom.github.io/sketchpad/javascripts/sketchpad.js"></script>
<script src="../static/js/vendor/sketchpad.js"></script>
<script src="../static/js/interfaces/input/sketchpad.js"></script>

View File

@ -0,0 +1,12 @@
<div class="gradio output image">
<div class="panel_head">
<div class="role">Output</div>
<div class="loading">
<img src="../static/img/logo_mini.png" class="ld ld-skew"/>
<span class="loading_time"></span>
</div>
</div>
<div class="output_image"><img /></div>
</div>
<script src="../static/js/interfaces/output/image.js"></script>

View File

@ -1,5 +1,11 @@
<div class="gradio output classifier">
<div class="role">Output</div>
<div class="panel_head">
<div class="role">Output</div>
<div class="loading">
<img src="../static/img/logo_mini.png" class="ld ld-skew"/>
<span class="loading_time"></span>
</div>
</div>
<div class="output_class"></div>
<div class="confidence_intervals">
</div>

View File

@ -1,5 +1,11 @@
<div class="gradio output text">
<div class="role">Output</div>
<div class="panel_head">
<div class="role">Output</div>
<div class="loading">
<img src="../static/img/logo_mini.png" class="ld ld-skew"/>
<span class="loading_time"></span>
</div>
</div>
<textarea readonly class="output_text"></textarea>
</div>

File diff suppressed because one or more lines are too long

Binary file not shown.

7
gradio-0.3.5/MANIFEST.in Normal file
View File

@ -0,0 +1,7 @@
include gradio/static/*
include gradio/static/css/*
include gradio/static/js/*
include gradio/static/img/*
include gradio/templates/*
include gradio/templates/input/*
include gradio/templates/output/*

73
gradio-0.3.5/README.md Normal file
View File

@ -0,0 +1,73 @@
# Gradiome / Gradio
`Gradio` is a python library that allows you to place input and output interfaces over trained models to make it easy for you to "play around" with your model. Gradio runs entirely locally using your browser.
To get a sense of `gradio`, take a look at the python notebooks in the `examples` folder, or read on below!
## Installation
```
pip install gradio
```
(you may need to replace `pip` with `pip3` if you're running `python3`).
## Usage
Gradio is very easy to use with your existing code. The general way it's used is something like this:
```python
import tensorflow as tf
import gradio
mdl = tf.keras.models.Sequential()
# ... define and train the model as you would normally
iface = gradio.Interface(input=“sketchpad”, output=“class”, model_type=“keras”, model=mdl)
iface.launch()
```
Changing the `input` and `output` parameters in the `Interface` face object allow you to create different interfaces, depending on the needs of your model. Take a look at the python notebooks for more examples. The currently supported interfaces are as follows:
**Input interfaces**:
* Sketchpad
* ImageUplaod
* Webcam
* Textbox
**Output interfaces**:
* Class
* Textbox
## Screenshots
Here are a few screenshots that show examples of gradio interfaces
#### MNIST Digit Recognition (Input: Sketchpad, Output: Class)
```python
iface = gradio.Interface(input='sketchpad', output='class', model=model, model_type='keras')
iface.launch()
```
![alt text](https://raw.githubusercontent.com/abidlabs/gradio/master/screenshots/mnist4.png)
#### Facial Emotion Detector (Input: Webcam, Output: Class)
```python
iface = gradio.Interface(input='webcam', output='class', model=model, model_type='keras')
iface.launch()
```
![alt text](https://raw.githubusercontent.com/abidlabs/gradio/master/screenshots/webcam_happy.png)
#### Sentiment Analysis (Input: Textbox, Output: Class)
```python
iface = gradio.Interface(input='textbox', output='class', model=model, model_type='keras')
iface.launch()
```
![alt text](https://raw.githubusercontent.com/abidlabs/gradio/master/screenshots/sentiment_positive.png)

View File

@ -0,0 +1,11 @@
Metadata-Version: 1.0
Name: gradio
Version: 0.3.5
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/abidlabs/gradio
Author: Abubakar Abid
Author-email: a12d@stanford.edu
License: UNKNOWN
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN

View File

@ -0,0 +1,43 @@
MANIFEST.in
README.md
setup.py
gradio/__init__.py
gradio/imagenet_class_labels.py
gradio/inputs.py
gradio/interface.py
gradio/networking.py
gradio/outputs.py
gradio/preprocessing_utils.py
gradio/validation_data.py
gradio.egg-info/PKG-INFO
gradio.egg-info/SOURCES.txt
gradio.egg-info/dependency_links.txt
gradio.egg-info/requires.txt
gradio.egg-info/top_level.txt
gradio/static/config.json
gradio/static/css/.DS_Store
gradio/static/css/font-awesome.min.css
gradio/static/css/gradio.css
gradio/static/css/loading.css
gradio/static/css/style.css
gradio/static/img/logo.png
gradio/static/img/logo_inline.png
gradio/static/img/logo_mini.png
gradio/static/img/mic.png
gradio/static/img/table.png
gradio/static/img/webcam.png
gradio/static/js/all-io.js
gradio/static/js/jquery-3.2.1.slim.min.js
gradio/static/js/utils.js
gradio/templates/base_template.html
gradio/templates/input/csv.html
gradio/templates/input/image_upload.html
gradio/templates/input/sketchpad.html
gradio/templates/input/textbox.html
gradio/templates/output/image.html
gradio/templates/output/label.html
gradio/templates/output/textbox.html
test/test_inputs.py
test/test_interface.py
test/test_networking.py
test/test_outputs.py

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,7 @@
numpy
websockets
nest_asyncio
beautifulsoup4
Pillow
requests
psutil

View File

@ -0,0 +1 @@
gradio

View File

@ -0,0 +1 @@
from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,185 @@
"""
This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
`AbstractInput`, and each class must define a path to its template. All of the subclasses of `AbstractInput` are
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
from abc import ABC, abstractmethod
import base64
from gradio import preprocessing_utils, validation_data
from io import BytesIO
import numpy as np
from PIL import Image, ImageOps
# Where to find the static resources associated with each template.
BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'templates/input/{}.html'
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
class AbstractInput(ABC):
"""
An abstract class for defining the methods that all gradio inputs should have.
When this is subclassed, it is automatically added to the registry
"""
def __init__(self, preprocessing_fn=None):
"""
:param preprocessing_fn: an optional preprocessing function that overrides the default
"""
if preprocessing_fn is not None:
if not callable(preprocessing_fn):
raise ValueError('`preprocessing_fn` must be a callable function')
self.preprocess = preprocessing_fn
super().__init__()
def get_validation_inputs(self):
"""
An interface can optionally implement a method that returns a list of examples inputs that it should be able to
accept and preprocess for validation purposes.
"""
return []
def get_js_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
@abstractmethod
def get_name(self):
"""
All interfaces should define a method that returns a name used for identifying the related static resources.
"""
pass
@abstractmethod
def preprocess(self, inp):
"""
All interfaces should define a default preprocessing method
"""
pass
class Sketchpad(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28,
invert_colors=True):
self.image_width = image_width
self.image_height = image_height
self.invert_colors = invert_colors
super().__init__(preprocessing_fn=preprocessing_fn)
def get_name(self):
return 'sketchpad'
def preprocess(self, inp):
"""
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
if self.invert_colors:
im = ImageOps.invert(im)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height)
return array
class Webcam(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_name(self):
return 'webcam'
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array
class Textbox(AbstractInput):
def get_validation_inputs(self):
return validation_data.ENGLISH_TEXTS
def get_name(self):
return 'textbox'
def preprocess(self, inp):
"""
By default, no pre-processing is applied to text.
"""
return inp
class ImageUpload(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB',
scale=1/127.5, shift=-1, aspect_ratio="false"):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
self.image_mode = image_mode
self.scale = scale
self.shift = shift
self.aspect_ratio = aspect_ratio
super().__init__(preprocessing_fn=preprocessing_fn)
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_name(self):
return 'image_upload'
def get_js_context(self):
return {'aspect_ratio': self.aspect_ratio}
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
im = np.array(im).flatten()
im = im * self.scale + self.shift
if self.num_channels is None:
array = im.reshape(1, self.image_width, self.image_height)
else:
array = im.reshape(1, self.image_width, self.image_height, self.num_channels)
return array
class CSV(AbstractInput):
def get_name(self):
# return 'templates/input/csv.html'
return 'csv'
def preprocess(self, inp):
"""
By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this)
"""
return inp
# Automatically adds all subclasses of AbstractInput into a dictionary (keyed by class name) for easy referencing.
registry = {cls.__name__.lower(): cls for cls in AbstractInput.__subclasses__()}

View File

@ -0,0 +1,282 @@
'''
This is the core file in the `gradio` package, and defines the Interface class, including methods for constructing the
interface using the input and output types.
'''
import asyncio
import websockets
import nest_asyncio
import webbrowser
import gradio.inputs
import gradio.outputs
from gradio import networking
import tempfile
import threading
import traceback
nest_asyncio.apply()
LOCALHOST_IP = '127.0.0.1'
INITIAL_WEBSOCKET_PORT = 9200
TRY_NUM_PORTS = 100
class Interface:
"""
The Interface class represents a general input/output interface for a machine learning model. During construction,
the appropriate inputs and outputs
"""
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'function': 'python function',
'pytorch': 'PyTorch model'}
STATUS_TYPES = {'OFF': 'off', 'RUNNING': 'running'}
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
"""
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model: the model object, such as a sklearn classifier or keras model.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
"""
if isinstance(inputs, str):
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
elif isinstance(inputs, gradio.inputs.AbstractInput):
self.input_interface = inputs
else:
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
if isinstance(outputs, str):
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
elif isinstance(outputs, gradio.outputs.AbstractOutput):
self.output_interface = outputs
else:
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
self.model_obj = model
if model_type is None:
model_type = self._infer_model_type(model)
if verbose:
print("Model type not explicitly identified, inferred to be: {}".format(
self.VALID_MODEL_TYPES[model_type]))
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
self.model_type = model_type
self.verbose = verbose
self.status = self.STATUS_TYPES['OFF']
self.validate_flag = False
self.simple_server = None
@staticmethod
def _infer_model_type(model):
""" Helper method that attempts to identify the type of trained ML model."""
try:
import sklearn
if isinstance(model, sklearn.base.BaseEstimator):
return 'sklearn'
except ImportError:
pass
try:
import tensorflow as tf
if isinstance(model, tf.keras.Model):
return 'keras'
except ImportError:
pass
try:
import keras
if isinstance(model, keras.Model):
return 'keras'
except ImportError:
pass
if callable(model):
return 'function'
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
async def communicate(self, websocket, path):
"""
Method that defines how this interface should communicates with the websocket. (1) When an input is received by
the websocket, it is passed into the input interface and preprocssed. (2) Then the model is called to make a
prediction. (3) Finally, the prediction is postprocessed to get something to be displayed by the output.
:param websocket: a Websocket server used to communicate with the interface frontend
:param path: not used, but required for compliance with websocket library
"""
while True:
try:
msg = await websocket.recv()
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
processed_output = self.output_interface.postprocess(prediction)
await websocket.send(str(processed_output))
except websockets.exceptions.ConnectionClosed:
pass
# except Exception as e:
# print(e)
def predict(self, preprocessed_input):
"""
Method that calls the relevant method of the model object to make a prediction.
:param preprocessed_input: the preprocessed input returned by the input interface
"""
if self.model_type=='sklearn':
return self.model_obj.predict(preprocessed_input)
elif self.model_type=='keras':
return self.model_obj.predict(preprocessed_input)
elif self.model_type=='function':
return self.model_obj(preprocessed_input)
elif self.model_type=='pytorch':
import torch
value = torch.from_numpy(preprocessed_input)
value = torch.autograd.Variable(value)
prediction = self.model_obj(value)
return prediction.data.numpy()
else:
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
def validate(self):
if self.validate_flag:
if self.verbose:
print("Interface already validated")
return
validation_inputs = self.input_interface.get_validation_inputs()
n = len(validation_inputs)
if n == 0:
self.validate_flag = True
if self.verbose:
print("No validation samples for this interface... skipping validation.")
return
for m, msg in enumerate(validation_inputs):
if self.verbose:
print(f"Validating samples: {m+1}/{n} [" + "="*(m+1) + "."*(n-m-1) + "]", end='\r')
try:
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
except Exception as e:
if self.verbose:
print("\n----------")
print("Validation failed, likely due to incompatible pre-processing and model input. See below:\n")
print(traceback.format_exc())
break
try:
_ = self.output_interface.postprocess(prediction)
except Exception as e:
if self.verbose:
print("\n----------")
print("Validation failed, likely due to incompatible model output and post-processing."
"See below:\n")
print(traceback.format_exc())
break
else: # This means if a break was not explicitly called
self.validate_flag = True
if self.verbose:
print("\n\nValidation passed successfully!")
return
raise RuntimeError("Validation did not pass")
def launch(self, inline=None, inbrowser=None, share=False, validate=True):
"""
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
:param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook)
:param inbrowser: boolean. If True, then a new browser window opens with the gradio interface.
:param share: boolean. If True, then a share link is generated using ngrok is displayed to the user.
:param validate: boolean. If True, then the validation is run if the interface has not already been validated.
"""
if validate and not self.validate_flag:
self.validate()
# If an existing interface is running with this instance, close it.
if self.status == self.STATUS_TYPES['RUNNING']:
if self.verbose:
print("Closing existing server...")
if self.simple_server is not None:
try:
networking.close_server(self.simple_server)
except OSError:
pass
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(output_directory)
path_to_local_server = 'http://localhost:{}/'.format(server_port)
path_to_local_interface_page = path_to_local_server + networking.TEMPLATE_TEMP
networking.build_template(output_directory, self.input_interface, self.output_interface)
# Set up a port to serve a websocket that sets up the communication between the front-end and model.
websocket_port = networking.get_first_available_port(
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port)
networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file.
networking.set_interface_types_in_config_file(output_directory,
self.input_interface.__class__.__name__.lower(),
self.output_interface.__class__.__name__.lower())
self.status = self.STATUS_TYPES['RUNNING']
self.simple_server = httpd
is_colab = False
try: # Check if running interactively using ipython.
from_ipynb = get_ipython()
if 'google.colab' in str(from_ipynb):
is_colab = True
except NameError:
pass
if self.verbose:
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
if not is_colab:
print(f"Model is running locally at: {path_to_local_interface_page}")
if share:
try:
path_to_ngrok_server = networking.setup_ngrok(server_port, websocket_port, output_directory)
path_to_ngrok_interface_page = path_to_ngrok_server + '/' + networking.TEMPLATE_TEMP
if self.verbose:
print(f"Model available publicly for 8 hours at: {path_to_ngrok_interface_page}")
except RuntimeError:
path_to_ngrok_server = None
if self.verbose:
print("Unable to create public link for interface, please check internet connection.")
else:
if self.verbose:
print("To create a public link, set `share=True` in the argument to `launch()`")
path_to_ngrok_server = None
if is_colab: # for a colab notebook, create a public link even if share is False.
path_to_ngrok_server = networking.setup_ngrok(server_port, websocket_port, output_directory)
path_to_ngrok_interface_page = path_to_ngrok_server + '/' + networking.TEMPLATE_TEMP
print(f"Cannot display local interface on google colab, public link created at:"
f"{path_to_ngrok_interface_page} and displayed below.")
# Keep the server running in the background.
asyncio.get_event_loop().run_until_complete(start_server)
try:
_ = get_ipython()
except NameError: # Runtime errors are thrown in jupyter notebooks because of async.
t = threading.Thread(target=asyncio.get_event_loop().run_forever, daemon=True)
t.start()
if inline is None:
try: # Check if running interactively using ipython.
get_ipython()
inline = True
if inbrowser is None:
inbrowser = False
except NameError:
inline = False
if inbrowser is None:
inbrowser = True
else:
if inbrowser is None:
inbrowser = False
if inbrowser and not is_colab:
webbrowser.open(path_to_local_interface_page) # Open a browser tab with the interface.
if inline:
from IPython.display import IFrame
if is_colab: # Embed the remote interface page if on google colab; otherwise, embed the local page.
display(IFrame(path_to_ngrok_interface_page, width=1000, height=500))
else:
display(IFrame(path_to_local_interface_page, width=1000, height=500))
return httpd, path_to_local_server, path_to_ngrok_server

View File

@ -0,0 +1,262 @@
'''
Defines helper methods useful for setting up ports, launching servers, and handling `ngrok`
'''
import subprocess
import requests
import zipfile
import io
import sys
import os
import socket
from psutil import process_iter, AccessDenied, NoSuchProcess
from signal import SIGTERM # or SIGKILL
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import stat
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import pkg_resources
from bs4 import BeautifulSoup
from distutils import dir_util
from gradio import inputs, outputs
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
LOCALHOST_NAME = 'localhost'
NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output)
NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output)
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
STATIC_PATH_LIB = pkg_resources.resource_filename('gradio', 'static/')
STATIC_PATH_TEMP = 'static/'
TEMPLATE_TEMP = 'interface.html'
BASE_JS_FILE = 'static/js/all-io.js'
CONFIG_FILE = 'static/config.json'
NGROK_ZIP_URLS = {
"linux": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip",
"darwin": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-darwin-amd64.zip",
"win32": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-windows-amd64.zip",
}
def build_template(temp_dir, input_interface, output_interface):
"""
Builds a complete HTML template with supporting JS and CSS files in a given directory.
:param temp_dir: string with path to temp directory in which the template should be built
:param input_interface: an AbstractInput object which includes is used to get the input template
:param output_interface: an AbstractInput object which includes is used to get the input template
"""
input_template_path = pkg_resources.resource_filename(
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
output_template_path = pkg_resources.resource_filename(
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
input_page = open(input_template_path)
output_page = open(output_template_path)
input_soup = BeautifulSoup(render_string_or_list_with_tags(
input_page.read(), input_interface.get_template_context()), features="html.parser")
output_soup = BeautifulSoup(
render_string_or_list_with_tags(
output_page.read(), output_interface.get_template_context()), features="html.parser")
all_io_page = open(BASE_TEMPLATE)
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
input_tag = all_io_soup.find("div", {"id": "input"})
output_tag = all_io_soup.find("div", {"id": "output"})
input_tag.replace_with(input_soup)
output_tag.replace_with(output_soup)
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
render_template_with_tags(os.path.join(
temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
input_interface.get_js_context())
render_template_with_tags(os.path.join(
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
output_interface.get_js_context())
def copy_files(src_dir, dest_dir):
"""
Copies all the files from one directory to another
:param src_dir: string path to source directory
:param dest_dir: string path to destination directory
"""
dir_util.copy_tree(src_dir, dest_dir)
def render_template_with_tags(template_path, context):
"""
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
in double curly braces) with corresponding values.
:param template_path: a string with the path to the template file
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = render_string_or_list_with_tags(old_lines, context)
with open(template_path, 'w') as fout:
for line in new_lines:
fout.write(line)
def render_string_or_list_with_tags(old_lines, context):
# Handle string case
if isinstance(old_lines, str):
for key, value in context.items():
old_lines = old_lines.replace(r'{{' + key + r'}}', str(value))
return old_lines
# Handle list case
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', str(value))
new_lines.append(line)
return new_lines
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
js_file = os.path.join(temp_dir, BASE_JS_FILE)
render_template_with_tags(js_file, {'ngrok_socket_url': ngrok_socket_url})
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(config_file, {'ngrok_socket_url': ngrok_socket_url})
def set_socket_port_in_js(temp_dir, socket_port):
js_file = os.path.join(temp_dir, BASE_JS_FILE)
render_template_with_tags(js_file, {'socket_port': str(socket_port)})
def set_interface_types_in_config_file(temp_dir, input_interface, output_interface):
config_file = os.path.join(temp_dir, CONFIG_FILE)
render_template_with_tags(config_file, {'input_interface_type': input_interface,
'output_interface_type': output_interface})
def get_first_available_port(initial, final):
"""
Gets the first open port in a specified range of port numbers
:param initial: the initial value in the range of port numbers
:param final: final (exclusive) value in the range of port numbers, should be greater than `initial`
:return:
"""
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.bind((LOCALHOST_NAME, port)) # Bind to the port
s.close()
return port
except OSError:
pass
raise OSError("All ports from {} to {} are in use. Please close a port.".format(initial, final))
def serve_files_in_background(port, directory_to_serve=None):
class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()"""
def translate_path(self, path):
path = SimpleHTTPRequestHandler.translate_path(self, path)
relpath = os.path.relpath(path, os.getcwd())
fullpath = os.path.join(self.server.base_path, relpath)
return fullpath
def log_message(self, format, *args):
return
class HTTPServer(BaseHTTPServer):
"""The main server, you pass in base_path which is the path you want to serve requests from"""
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
httpd = HTTPServer(directory_to_serve, (LOCALHOST_NAME, port))
# Now loop forever
def serve_forever():
try:
while True:
# sys.stdout.flush()
httpd.serve_forever()
except (KeyboardInterrupt, OSError):
httpd.server_close()
thread = threading.Thread(target=serve_forever, daemon=True)
thread.start()
return httpd
def start_simple_server(directory_to_serve=None):
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
httpd = serve_files_in_background(port, directory_to_serve)
return port, httpd
def close_server(server):
server.shutdown()
server.server_close()
def download_ngrok():
try:
zip_file_url = NGROK_ZIP_URLS[sys.platform]
except KeyError:
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and "
"we'll look into it!")
return
r = requests.get(zip_file_url)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall()
if sys.platform == 'darwin' or sys.platform == 'linux':
st = os.stat('ngrok')
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
def create_ngrok_tunnel(local_port, api_url):
if not(os.path.isfile('ngrok.exe') or os.path.isfile('ngrok')):
download_ngrok()
if sys.platform == 'win32':
subprocess.Popen(['ngrok', 'http', str(local_port)])
else:
subprocess.Popen(['./ngrok', 'http', str(local_port)])
session = requests.Session()
retry = Retry(connect=3, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)
r = session.get(api_url)
for tunnel in r.json()['tunnels']:
if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr']:
return tunnel['public_url']
raise RuntimeError("Not able to retrieve ngrok public URL")
def setup_ngrok(server_port, websocket_port, output_directory):
kill_processes([4040, 4041]) #TODO(abidlabs): better way to do this
site_ngrok_url = create_ngrok_tunnel(server_port, NGROK_TUNNELS_API_URL)
socket_ngrok_url = create_ngrok_tunnel(websocket_port, NGROK_TUNNELS_API_URL2)
set_ngrok_url_in_js(output_directory, socket_ngrok_url)
return site_ngrok_url
def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill
for proc in process_iter():
try:
for conns in proc.connections(kind='inet'):
if conns.laddr.port in process_ids:
proc.send_signal(SIGTERM) # or SIGKILL
except (AccessDenied, NoSuchProcess):
pass

View File

@ -0,0 +1,133 @@
"""
This module defines various classes that can serve as the `output` to an interface. Each class must inherit from
`AbstractOutput`, and each class must define a path to its template. All of the subclasses of `AbstractOutput` are
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
from abc import ABC, abstractmethod
import numpy as np
import json
from gradio import imagenet_class_labels
# Where to find the static resources associated with each template.
BASE_OUTPUT_INTERFACE_TEMPLATE_PATH = 'templates/output/{}.html'
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
class AbstractOutput(ABC):
"""
An abstract class for defining the methods that all gradio inputs should have.
When this is subclassed, it is automatically added to the registry
"""
def __init__(self, postprocessing_fn=None):
"""
:param postprocessing_fn: an optional postprocessing function that overrides the default
"""
if postprocessing_fn is not None:
self.postprocess = postprocessing_fn
super().__init__()
def get_js_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
@abstractmethod
def get_name(self):
"""
All outputs should define a method that returns a name used for identifying the related static resources.
"""
pass
@abstractmethod
def postprocess(self, prediction):
"""
All interfaces should define a default postprocessing method
"""
pass
class Label(AbstractOutput):
LABEL_KEY = 'label'
CONFIDENCES_KEY = 'confidences'
CONFIDENCE_KEY = 'confidence'
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
max_label_length=None):
self.num_top_classes = num_top_classes
self.show_confidences = show_confidences
self.label_names = label_names
self.max_label_length = max_label_length
super().__init__(postprocessing_fn=postprocessing_fn)
def get_name(self):
return 'label'
def get_label_name(self, label):
if self.label_names is None:
name = label
elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
name = imagenet_class_labels.NAMES1000[label]
else: # if list or dictionary
name = self.label_names[label]
if self.max_label_length is not None:
name = name[:self.max_label_length]
return name
def postprocess(self, prediction):
"""
"""
response = dict()
# TODO(abidlabs): check if list, if so convert to numpy array
if isinstance(prediction, np.ndarray):
prediction = prediction.squeeze()
if prediction.size == 1: # if it's single value
response[Label.LABEL_KEY] = self.get_label_name(np.asscalar(prediction))
elif len(prediction.shape) == 1: # if a 1D
response[Label.LABEL_KEY] = self.get_label_name(int(prediction.argmax()))
if self.show_confidences:
response[Label.CONFIDENCES_KEY] = []
for i in range(self.num_top_classes):
response[Label.CONFIDENCES_KEY].append({
Label.LABEL_KEY: self.get_label_name(int(prediction.argmax())),
Label.CONFIDENCE_KEY: float(prediction.max()),
})
prediction[prediction.argmax()] = 0
elif isinstance(prediction, str):
response[Label.LABEL_KEY] = prediction
else:
raise ValueError("Unable to post-process model prediction.")
return json.dumps(response)
class Textbox(AbstractOutput):
def get_name(self):
return 'textbox'
def postprocess(self, prediction):
"""
"""
return prediction
class Image(AbstractOutput):
def get_name(self):
return 'image'
def postprocess(self, prediction):
"""
"""
return prediction
registry = {cls.__name__.lower(): cls for cls in AbstractOutput.__subclasses__()}

View File

@ -0,0 +1,53 @@
from PIL import Image
def resize_and_crop(img, size, crop_type='top'):
"""
Resize and crop an image to fit the specified size.
args:
img_path: path for the image to resize.
modified_path: path to store the modified image.
size: `(width, height)` tuple.
crop_type: can be 'top', 'middle' or 'bottom', depending on this
value, the image will cropped getting the 'top/left', 'midle' or
'bottom/rigth' of the image to fit the size.
raises:
Exception: if can not open the file in img_path of there is problems
to save the image.
ValueError: if an invalid `crop_type` is provided.
"""
# Get current and desired ratio for the images
img_ratio = img.size[0] // float(img.size[1])
ratio = size[0] // float(size[1])
# The image is scaled//cropped vertically or horizontally depending on the ratio
if ratio > img_ratio:
img = img.resize((size[0], size[0] * img.size[1] // img.size[0]),
Image.ANTIALIAS)
# Crop in the top, middle or bottom
if crop_type == 'top':
box = (0, 0, img.size[0], size[1])
elif crop_type == 'middle':
box = (0, (img.size[1] - size[1]) // 2, img.size[0], (img.size[1] + size[1]) // 2)
elif crop_type == 'bottom':
box = (0, img.size[1] - size[1], img.size[0], img.size[1])
else:
raise ValueError('ERROR: invalid value for crop_type')
img = img.crop(box)
elif ratio < img_ratio:
img = img.resize((size[1] * img.size[0] // img.size[1], size[1]),
Image.ANTIALIAS)
# Crop in the top, middle or bottom
if crop_type == 'top':
box = (0, 0, size[0], img.size[1])
elif crop_type == 'middle':
box = ((img.size[0] - size[0]) // 2, 0, (img.size[0] + size[0]) // 2, img.size[1])
elif crop_type == 'bottom':
box = (img.size[0] - size[0], 0, img.size[0], img.size[1])
else:
raise ValueError('ERROR: invalid value for crop_type')
img = img.crop(box)
else:
img = img.resize((size[0], size[1]),
Image.ANTIALIAS)
# If the scale is the same, we do not need to crop
return img

View File

@ -0,0 +1,5 @@
{
"input_interface_type": "{{input_interface_type}}",
"output_interface_type": "{{output_interface_type}}",
"ngrok_socket_url": "{{ngrok_socket_url}}"
}

BIN
gradio-0.3.5/gradio/static/css/.DS_Store vendored Normal file

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,232 @@
.panel {
display: inline-block;
max-width: 45%;
min-width: 300px;
box-sizing: border-box;
vertical-align: top;
}
.panel {
margin: 0 14px 14px;
}
.instructions {
margin-bottom: 10px;
}
.input, .output {
width: 100%;
height: 360px;
background-color: #F6F6F6;
margin-bottom: 16px;
box-sizing: border-box;
padding: 6px;
display: flex;
flex-flow: column;
}
.loading {
margin-left: auto;
}
.loading img {
display: none;
height: 22px;
}
.panel_head {
display: flex
}
.role {
text-transform: uppercase;
font-family: Arial;
color: #BBB;
margin-bottom: 6px;
font-size: 14px;
font-weight: bold;
}
.input.text .role, .output.text .role {
margin-left: 1px;
}
.submit, .clear, .flag, .message, .send-message {
background-color: #F6F6F6 !important;
padding: 8px !important;
box-sizing: border-box;
width: calc(50% - 8px);
text-transform: uppercase;
font-weight: bold;
border: 0 none;
}
.clear {
background-color: #F6F6F6 !important;
}
.submit {
background-color: #EEA45D !important;
color: white !important;
}
.submit {
margin-right: 8px;
}
.clear {
margin-left: 8px;
}
/*.flag:focus {
background-color: pink !important;
}
*/
.input_text, .output_text {
background: transparent;
resize: none
border: 0 none;
width: 100%;
font-size: 18px;
outline: none;
height: 100%;
padding: 0;
}
.input_image, .input_audio, .input_snapshot, .input_mic, .input_csv, .output_class,
.output_image, .output_csv {
flex: 1 1 auto;
display: flex;
justify-content: center;
align-items: center;
text-align: center;
}
.input_image, .input_audio, .input_snapshot, .input_mic, .input_csv {
font-weight: bold;
font-size: 24px;
color: #BBB;
cursor: pointer;
}
.input_image img {
max-height: 100%;
max-width: 100%;
}
.hidden_upload {
display: none;
}
.output_class {
font-weight: bold;
font-size: 36px;
}
.drop_mode {
border: dashed 8px #DDD;
}
.input_image, .input_audio {
line-height: 1.5em;
}
.input_snapshot, .input_mic {
flex-direction: column;
}
.input_snapshot .webcam, .input_mic .mic {
height: 80px;
}
.output_image img {
width: 100%; /* or any custom size */
height: 100%;
object-fit: contain;
}
.confidence_intervals {
font-size: 16px;
}
.confidence {
padding: 3px;
display: flex;
}
.level, .label {
display: inline-block;
}
.label {
width: 60px;
}
.confidence_intervals .level {
font-size: 14px;
margin-left: 8px;
margin-right: 8px;
background-color: #AAA;
padding: 2px 4px;
text-align: right;
font-family: monospace;
color: white;
font-weight: bold;
}
.confidence_intervals > * {
vertical-align: bottom;
}
.flag.flagged {
background-color: pink !important;
}
.sketchpad canvas {
background-color: white;
}
.sketch_tools {
flex: 0 1 auto;
display: flex;
align-items: center;
justify-content: center;
margin-bottom: 16px;
}
.brush {
border-radius: 50%;
background-color: #AAA;
margin: 0px 20px;
cursor: pointer;
}
.brush.selected, .brush:hover {
background-color: black;
}
#brush_1 {
height: 8px;
width: 8px;
}
#brush_2 {
height: 16px;
width: 16px;
}
#brush_3 {
height: 24px;
width: 24px;
}
.canvas_holder {
flex: 1 1 auto;
text-align: center;
}
canvas {
border: solid 1px black;
}
.text textarea {
resize: none;
background-color: white;
border: none;
box-sizing: border-box;
padding: 4px;
}
.output_image img {
display: none
}
.table_holder {
max-width: 100%;
max-height: 100%;
overflow: scroll;
display: none;
}
.csv_preview {
background-color: white;
max-width: 100%;
max-height: 100%;
font-size: 12px;
font-family: monospace;
}
.csv_preview tr {
border-bottom: solid 1px black;
}
.csv_preview tr.header td {
background-color: #EEA45D;
font-weight: bold;
}
.csv_preview td {
padding: 2px 4px;
}
.csv_preview td:nth-child(even) {
background-color: #EEEEEE;
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,45 @@
body {
font-family: 'Open Sans', sans-serif;
font-size: 12px;
margin: 0;
}
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {
background: none;
color: inherit;
border: none;
padding: 0;
font: inherit;
cursor: pointer;
outline: inherit;
-webkit-appearance: none;
border-radius: 0;
}
input[type="text"] {
cursor: text;
text-transform: none;
}
body > div, body > nav {
margin-left: 60px;
margin-right: 60px;
}
nav {
text-align: center;
padding: 16px 0 8px;
}
nav img {
margin-right: auto;
height: 32px;
}
#panels {
display: flex;
flex-flow: row;
flex-wrap: wrap;
justify-content: center;
}
.panel {
min-width: 300px;
margin: 40px 0 0;
flex-grow: 1;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

View File

@ -0,0 +1,68 @@
var NGROK_URL = "{{ngrok_socket_url}}"
var SOCKET_PORT = "{{socket_port}}"
function notifyError(error) {
$.notify({
// options
message: 'Not able to communicate with model (is python code still running?)'
},{
// settings
type: 'danger',
animate: {
enter: 'animated fadeInDown',
exit: 'animated fadeOutUp'
},
placement: {
from: "bottom",
align: "right"
},
delay: 5000
});
}
try {
var origin = window.location.origin;
if (origin.includes("ngrok")){
var ws = new WebSocket(NGROK_URL)
} else {
var ws = new WebSocket("ws://127.0.0.1:" + SOCKET_PORT + "/")
}
ws.onerror = function(evt) {
notifyError(evt)
};
ws.onclose = function(event) {
console.log("WebSocket is closed now.");
var model_status = $('#model-status')
model_status.html('Model: closed');
model_status.css('color', '#e23e44');
$('#overlay').css('visibility','visible')
};
} catch (e) {
console.log(e)
}
const sleep = (milliseconds) => {
return new Promise(resolve => setTimeout(resolve, milliseconds))
}
$('body').on('click', '.flag', function(e) {
if ($(".flag").hasClass("flagged")) {
$(".flag").removeClass("flagged").attr("value", "flag");
} else {
$(".flag").addClass("flagged").attr("value", "flagged");
}
})
var start_time;
function loadStart() {
$(".loading img").show();
$(".loading_time").text("");
start_time = new Date().getTime()
}
function loadEnd() {
$(".loading img").hide();
end_time = new Date().getTime()
$(".loading_time").text(((end_time - start_time) / 1000).toFixed(2) + "s");
}

File diff suppressed because one or more lines are too long

25
gradio-0.3.5/setup.py Normal file
View File

@ -0,0 +1,25 @@
try:
from setuptools import setup
except ImportError:
from distutils.core import setup
setup(
name='gradio',
version='0.3.5',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',
author_email='a12d@stanford.edu',
url='https://github.com/abidlabs/gradio',
packages=['gradio'],
keywords=['machine learning', 'visualization', 'reproducibility'],
install_requires=[
'numpy',
'websockets',
'nest_asyncio',
'beautifulsoup4',
'Pillow',
'requests',
'psutil',
],
)

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0
Name: gradio
Version: 0.3.5
Version: 0.4.0
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/abidlabs/gradio
Author: Abubakar Abid

View File

@ -8,6 +8,7 @@ gradio/interface.py
gradio/networking.py
gradio/outputs.py
gradio/preprocessing_utils.py
gradio/validation_data.py
gradio.egg-info/PKG-INFO
gradio.egg-info/SOURCES.txt
gradio.egg-info/dependency_links.txt
@ -15,19 +16,25 @@ gradio.egg-info/requires.txt
gradio.egg-info/top_level.txt
gradio/static/config.json
gradio/static/css/.DS_Store
gradio/static/css/font-awesome.min.css
gradio/static/css/gradio.css
gradio/static/css/loading.css
gradio/static/css/style.css
gradio/static/img/logo.png
gradio/static/img/logo_inline.png
gradio/static/img/logo_mini.png
gradio/static/img/mic.png
gradio/static/img/table.png
gradio/static/img/webcam.png
gradio/static/js/all-io.js
gradio/static/js/jquery-3.2.1.slim.min.js
gradio/static/js/utils.js
gradio/templates/base_template.html
gradio/templates/input/csv.html
gradio/templates/input/image_upload.html
gradio/templates/input/sketchpad.html
gradio/templates/input/textbox.html
gradio/templates/output/image.html
gradio/templates/output/label.html
gradio/templates/output/textbox.html
test/test_inputs.py

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name='gradio',
version='0.3.5',
version='0.4.0',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',