mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
pass keras graph into thread
This commit is contained in:
parent
12e05caed1
commit
38a6899d95
@ -27,9 +27,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# model = tf.keras.applications.inception_v3.InceptionV3()\n",
|
"model = tf.keras.applications.inception_v3.InceptionV3()"
|
||||||
"def model(x):\n",
|
|
||||||
" return \"test class\""
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -51,13 +49,12 @@
|
|||||||
"io = gradio.Interface(inputs=inp, \n",
|
"io = gradio.Interface(inputs=inp, \n",
|
||||||
" outputs=out,\n",
|
" outputs=out,\n",
|
||||||
" model=model, \n",
|
" model=model, \n",
|
||||||
"# model_type='keras')\n",
|
" model_type='keras')"
|
||||||
" model_type='pyfunc')"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 4,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"scrolled": false
|
"scrolled": false
|
||||||
},
|
},
|
||||||
@ -66,10 +63,9 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Closing existing server...\n",
|
|
||||||
"NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n",
|
"NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n",
|
||||||
"Model is running locally at: http://localhost:7860/\n",
|
"Model is running locally at: http://localhost:7860/\n",
|
||||||
"Model available publicly for 8 hours at: https://e46308c0.gradio.app/\n"
|
"Model available publicly at: https://10227.gradio.app -- may take up to a minute to setup.\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -86,7 +82,7 @@
|
|||||||
" "
|
" "
|
||||||
],
|
],
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<IPython.lib.display.IFrame at 0x239f4273a90>"
|
"<IPython.lib.display.IFrame at 0x1cc2975aa90>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -10,10 +10,15 @@ import webbrowser
|
|||||||
import gradio.inputs
|
import gradio.inputs
|
||||||
import gradio.outputs
|
import gradio.outputs
|
||||||
from gradio import networking, strings
|
from gradio import networking, strings
|
||||||
|
from distutils.version import StrictVersion
|
||||||
|
import pkg_resources
|
||||||
|
import requests
|
||||||
|
import termcolor
|
||||||
|
|
||||||
LOCALHOST_IP = "127.0.0.1"
|
LOCALHOST_IP = "127.0.0.1"
|
||||||
INITIAL_WEBSOCKET_PORT = 9200
|
INITIAL_WEBSOCKET_PORT = 9200
|
||||||
TRY_NUM_PORTS = 100
|
TRY_NUM_PORTS = 100
|
||||||
|
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||||
|
|
||||||
|
|
||||||
class Interface:
|
class Interface:
|
||||||
@ -80,6 +85,9 @@ class Interface:
|
|||||||
elif not (model_type.lower() in self.VALID_MODEL_TYPES):
|
elif not (model_type.lower() in self.VALID_MODEL_TYPES):
|
||||||
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
|
if self.model_type == "keras":
|
||||||
|
import tensorflow as tf
|
||||||
|
self.graph = tf.get_default_graph()
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.status = self.STATUS_TYPES["OFF"]
|
self.status = self.STATUS_TYPES["OFF"]
|
||||||
self.validate_flag = False
|
self.validate_flag = False
|
||||||
@ -127,12 +135,12 @@ class Interface:
|
|||||||
if self.model_type == "sklearn":
|
if self.model_type == "sklearn":
|
||||||
return self.model_obj.predict(preprocessed_input)
|
return self.model_obj.predict(preprocessed_input)
|
||||||
elif self.model_type == "keras":
|
elif self.model_type == "keras":
|
||||||
|
with self.graph.as_default():
|
||||||
return self.model_obj.predict(preprocessed_input)
|
return self.model_obj.predict(preprocessed_input)
|
||||||
elif self.model_type == "pyfunc":
|
elif self.model_type == "pyfunc":
|
||||||
return self.model_obj(preprocessed_input)
|
return self.model_obj(preprocessed_input)
|
||||||
elif self.model_type == "pytorch":
|
elif self.model_type == "pytorch":
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
value = torch.from_numpy(preprocessed_input)
|
value = torch.from_numpy(preprocessed_input)
|
||||||
value = torch.autograd.Variable(value)
|
value = torch.autograd.Variable(value)
|
||||||
prediction = self.model_obj(value)
|
prediction = self.model_obj(value)
|
||||||
@ -237,6 +245,12 @@ class Interface:
|
|||||||
except NameError:
|
except NameError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||||
|
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||||
|
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||||
|
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, however version {latest_pkg_version} "
|
||||||
|
f"is available, please upgrade.")
|
||||||
|
print('--------')
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(strings.en["BETA_MESSAGE"])
|
print(strings.en["BETA_MESSAGE"])
|
||||||
if not is_colab:
|
if not is_colab:
|
||||||
|
@ -13,6 +13,7 @@ from gradio import inputs, outputs
|
|||||||
import json
|
import json
|
||||||
from gradio.tunneling import create_tunnel
|
from gradio.tunneling import create_tunnel
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from shutil import copyfile
|
||||||
|
|
||||||
INITIAL_PORT_VALUE = (
|
INITIAL_PORT_VALUE = (
|
||||||
7860
|
7860
|
||||||
@ -32,6 +33,9 @@ TEMPLATE_TEMP = "index.html"
|
|||||||
BASE_JS_FILE = "static/js/all_io.js"
|
BASE_JS_FILE = "static/js/all_io.js"
|
||||||
CONFIG_FILE = "static/config.json"
|
CONFIG_FILE = "static/config.json"
|
||||||
|
|
||||||
|
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||||
|
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||||
|
|
||||||
|
|
||||||
def build_template(temp_dir, input_interface, output_interface):
|
def build_template(temp_dir, input_interface, output_interface):
|
||||||
"""
|
"""
|
||||||
@ -75,6 +79,10 @@ def build_template(temp_dir, input_interface, output_interface):
|
|||||||
f.write(str(all_io_soup))
|
f.write(str(all_io_soup))
|
||||||
|
|
||||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||||
|
# Move association file to root of temporary directory.
|
||||||
|
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||||
|
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
|
||||||
|
|
||||||
render_template_with_tags(
|
render_template_with_tags(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
temp_dir,
|
temp_dir,
|
||||||
|
11
gradio/static/apple-app-site-association
Normal file
11
gradio/static/apple-app-site-association
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"applinks": {
|
||||||
|
"apps": [],
|
||||||
|
"details": [
|
||||||
|
{
|
||||||
|
"appID": "RHW8FBGSTX.app.gradio.Gradio",
|
||||||
|
"paths": ["*"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
@ -5,5 +5,5 @@ en = {
|
|||||||
"restarting python interpreter.",
|
"restarting python interpreter.",
|
||||||
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
||||||
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
|
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
|
||||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly for 8 hours at: {}"
|
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} -- may take up to a minute to setup."
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import socket
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
import warnings
|
||||||
import paramiko
|
import paramiko
|
||||||
|
|
||||||
DEBUG_MODE = False
|
DEBUG_MODE = False
|
||||||
@ -68,6 +68,8 @@ def create_tunnel(payload, local_server, local_server_port):
|
|||||||
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
client.connect(
|
client.connect(
|
||||||
hostname=payload["host"],
|
hostname=payload["host"],
|
||||||
port=int(payload["port"]),
|
port=int(payload["port"]),
|
||||||
|
Loading…
Reference in New Issue
Block a user