pass keras graph into thread

This commit is contained in:
Abubakar Abid 2019-04-18 23:44:19 -07:00
parent 12e05caed1
commit 38a6899d95
6 changed files with 50 additions and 19 deletions

View File

@ -27,9 +27,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# model = tf.keras.applications.inception_v3.InceptionV3()\n", "model = tf.keras.applications.inception_v3.InceptionV3()"
"def model(x):\n",
" return \"test class\""
] ]
}, },
{ {
@ -51,13 +49,12 @@
"io = gradio.Interface(inputs=inp, \n", "io = gradio.Interface(inputs=inp, \n",
" outputs=out,\n", " outputs=out,\n",
" model=model, \n", " model=model, \n",
"# model_type='keras')\n", " model_type='keras')"
" model_type='pyfunc')"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"metadata": { "metadata": {
"scrolled": false "scrolled": false
}, },
@ -66,10 +63,9 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Closing existing server...\n",
"NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n", "NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n",
"Model is running locally at: http://localhost:7860/\n", "Model is running locally at: http://localhost:7860/\n",
"Model available publicly for 8 hours at: https://e46308c0.gradio.app/\n" "Model available publicly at: https://10227.gradio.app -- may take up to a minute to setup.\n"
] ]
}, },
{ {
@ -86,7 +82,7 @@
" " " "
], ],
"text/plain": [ "text/plain": [
"<IPython.lib.display.IFrame at 0x239f4273a90>" "<IPython.lib.display.IFrame at 0x1cc2975aa90>"
] ]
}, },
"metadata": {}, "metadata": {},

View File

@ -10,10 +10,15 @@ import webbrowser
import gradio.inputs import gradio.inputs
import gradio.outputs import gradio.outputs
from gradio import networking, strings from gradio import networking, strings
from distutils.version import StrictVersion
import pkg_resources
import requests
import termcolor
LOCALHOST_IP = "127.0.0.1" LOCALHOST_IP = "127.0.0.1"
INITIAL_WEBSOCKET_PORT = 9200 INITIAL_WEBSOCKET_PORT = 9200
TRY_NUM_PORTS = 100 TRY_NUM_PORTS = 100
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
class Interface: class Interface:
@ -80,6 +85,9 @@ class Interface:
elif not (model_type.lower() in self.VALID_MODEL_TYPES): elif not (model_type.lower() in self.VALID_MODEL_TYPES):
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES)) ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
self.model_type = model_type self.model_type = model_type
if self.model_type == "keras":
import tensorflow as tf
self.graph = tf.get_default_graph()
self.verbose = verbose self.verbose = verbose
self.status = self.STATUS_TYPES["OFF"] self.status = self.STATUS_TYPES["OFF"]
self.validate_flag = False self.validate_flag = False
@ -127,12 +135,12 @@ class Interface:
if self.model_type == "sklearn": if self.model_type == "sklearn":
return self.model_obj.predict(preprocessed_input) return self.model_obj.predict(preprocessed_input)
elif self.model_type == "keras": elif self.model_type == "keras":
return self.model_obj.predict(preprocessed_input) with self.graph.as_default():
return self.model_obj.predict(preprocessed_input)
elif self.model_type == "pyfunc": elif self.model_type == "pyfunc":
return self.model_obj(preprocessed_input) return self.model_obj(preprocessed_input)
elif self.model_type == "pytorch": elif self.model_type == "pytorch":
import torch import torch
value = torch.from_numpy(preprocessed_input) value = torch.from_numpy(preprocessed_input)
value = torch.autograd.Variable(value) value = torch.autograd.Variable(value)
prediction = self.model_obj(value) prediction = self.model_obj(value)
@ -237,6 +245,12 @@ class Interface:
except NameError: except NameError:
pass pass
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, however version {latest_pkg_version} "
f"is available, please upgrade.")
print('--------')
if self.verbose: if self.verbose:
print(strings.en["BETA_MESSAGE"]) print(strings.en["BETA_MESSAGE"])
if not is_colab: if not is_colab:

View File

@ -13,6 +13,7 @@ from gradio import inputs, outputs
import json import json
from gradio.tunneling import create_tunnel from gradio.tunneling import create_tunnel
import urllib.request import urllib.request
from shutil import copyfile
INITIAL_PORT_VALUE = ( INITIAL_PORT_VALUE = (
7860 7860
@ -32,6 +33,9 @@ TEMPLATE_TEMP = "index.html"
BASE_JS_FILE = "static/js/all_io.js" BASE_JS_FILE = "static/js/all_io.js"
CONFIG_FILE = "static/config.json" CONFIG_FILE = "static/config.json"
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
def build_template(temp_dir, input_interface, output_interface): def build_template(temp_dir, input_interface, output_interface):
""" """
@ -75,6 +79,10 @@ def build_template(temp_dir, input_interface, output_interface):
f.write(str(all_io_soup)) f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP)) copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
render_template_with_tags( render_template_with_tags(
os.path.join( os.path.join(
temp_dir, temp_dir,

View File

@ -0,0 +1,11 @@
{
"applinks": {
"apps": [],
"details": [
{
"appID": "RHW8FBGSTX.app.gradio.Gradio",
"paths": ["*"]
}
]
}
}

View File

@ -5,5 +5,5 @@ en = {
"restarting python interpreter.", "restarting python interpreter.",
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.", "COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.", "PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly for 8 hours at: {}" "MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} -- may take up to a minute to setup."
} }

View File

@ -8,7 +8,7 @@ import socket
import sys import sys
import threading import threading
from io import StringIO from io import StringIO
import warnings
import paramiko import paramiko
DEBUG_MODE = False DEBUG_MODE = False
@ -68,12 +68,14 @@ def create_tunnel(payload, local_server, local_server_port):
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"])) "Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
) )
try: try:
client.connect( with warnings.catch_warnings():
hostname=payload["host"], warnings.simplefilter("ignore")
port=int(payload["port"]), client.connect(
username=payload["user"], hostname=payload["host"],
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])), port=int(payload["port"]),
) username=payload["user"],
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
)
except Exception as e: except Exception as e:
print( print(
"*** Failed to connect to %s:%d: %r" "*** Failed to connect to %s:%d: %r"