pass keras graph into thread

This commit is contained in:
Abubakar Abid 2019-04-18 23:44:19 -07:00
parent ef404beac0
commit e08dba3b22
6 changed files with 50 additions and 19 deletions

View File

@ -27,9 +27,7 @@
"metadata": {},
"outputs": [],
"source": [
"# model = tf.keras.applications.inception_v3.InceptionV3()\n",
"def model(x):\n",
" return \"test class\""
"model = tf.keras.applications.inception_v3.InceptionV3()"
]
},
{
@ -51,13 +49,12 @@
"io = gradio.Interface(inputs=inp, \n",
" outputs=out,\n",
" model=model, \n",
"# model_type='keras')\n",
" model_type='pyfunc')"
" model_type='keras')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {
"scrolled": false
},
@ -66,10 +63,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Closing existing server...\n",
"NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n",
"Model is running locally at: http://localhost:7860/\n",
"Model available publicly for 8 hours at: https://e46308c0.gradio.app/\n"
"Model available publicly at: https://10227.gradio.app -- may take up to a minute to setup.\n"
]
},
{
@ -86,7 +82,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x239f4273a90>"
"<IPython.lib.display.IFrame at 0x1cc2975aa90>"
]
},
"metadata": {},

View File

@ -10,10 +10,15 @@ import webbrowser
import gradio.inputs
import gradio.outputs
from gradio import networking, strings
from distutils.version import StrictVersion
import pkg_resources
import requests
import termcolor
LOCALHOST_IP = "127.0.0.1"
INITIAL_WEBSOCKET_PORT = 9200
TRY_NUM_PORTS = 100
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
class Interface:
@ -80,6 +85,9 @@ class Interface:
elif not (model_type.lower() in self.VALID_MODEL_TYPES):
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
self.model_type = model_type
if self.model_type == "keras":
import tensorflow as tf
self.graph = tf.get_default_graph()
self.verbose = verbose
self.status = self.STATUS_TYPES["OFF"]
self.validate_flag = False
@ -127,12 +135,12 @@ class Interface:
if self.model_type == "sklearn":
return self.model_obj.predict(preprocessed_input)
elif self.model_type == "keras":
return self.model_obj.predict(preprocessed_input)
with self.graph.as_default():
return self.model_obj.predict(preprocessed_input)
elif self.model_type == "pyfunc":
return self.model_obj(preprocessed_input)
elif self.model_type == "pytorch":
import torch
value = torch.from_numpy(preprocessed_input)
value = torch.autograd.Variable(value)
prediction = self.model_obj(value)
@ -237,6 +245,12 @@ class Interface:
except NameError:
pass
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, however version {latest_pkg_version} "
f"is available, please upgrade.")
print('--------')
if self.verbose:
print(strings.en["BETA_MESSAGE"])
if not is_colab:

View File

@ -13,6 +13,7 @@ from gradio import inputs, outputs
import json
from gradio.tunneling import create_tunnel
import urllib.request
from shutil import copyfile
INITIAL_PORT_VALUE = (
7860
@ -32,6 +33,9 @@ TEMPLATE_TEMP = "index.html"
BASE_JS_FILE = "static/js/all_io.js"
CONFIG_FILE = "static/config.json"
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
def build_template(temp_dir, input_interface, output_interface):
"""
@ -75,6 +79,10 @@ def build_template(temp_dir, input_interface, output_interface):
f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
render_template_with_tags(
os.path.join(
temp_dir,

View File

@ -0,0 +1,11 @@
{
"applinks": {
"apps": [],
"details": [
{
"appID": "RHW8FBGSTX.app.gradio.Gradio",
"paths": ["*"]
}
]
}
}

View File

@ -5,5 +5,5 @@ en = {
"restarting python interpreter.",
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly for 8 hours at: {}"
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} -- may take up to a minute to setup."
}

View File

@ -8,7 +8,7 @@ import socket
import sys
import threading
from io import StringIO
import warnings
import paramiko
DEBUG_MODE = False
@ -68,12 +68,14 @@ def create_tunnel(payload, local_server, local_server_port):
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
)
try:
client.connect(
hostname=payload["host"],
port=int(payload["port"]),
username=payload["user"],
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
client.connect(
hostname=payload["host"],
port=int(payload["port"]),
username=payload["user"],
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
)
except Exception as e:
print(
"*** Failed to connect to %s:%d: %r"