mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
pass keras graph into thread
This commit is contained in:
parent
12e05caed1
commit
38a6899d95
@ -27,9 +27,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# model = tf.keras.applications.inception_v3.InceptionV3()\n",
|
||||
"def model(x):\n",
|
||||
" return \"test class\""
|
||||
"model = tf.keras.applications.inception_v3.InceptionV3()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -51,13 +49,12 @@
|
||||
"io = gradio.Interface(inputs=inp, \n",
|
||||
" outputs=out,\n",
|
||||
" model=model, \n",
|
||||
"# model_type='keras')\n",
|
||||
" model_type='pyfunc')"
|
||||
" model_type='keras')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
@ -66,10 +63,9 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Closing existing server...\n",
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n",
|
||||
"Model is running locally at: http://localhost:7860/\n",
|
||||
"Model available publicly for 8 hours at: https://e46308c0.gradio.app/\n"
|
||||
"Model available publicly at: https://10227.gradio.app -- may take up to a minute to setup.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -86,7 +82,7 @@
|
||||
" "
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.lib.display.IFrame at 0x239f4273a90>"
|
||||
"<IPython.lib.display.IFrame at 0x1cc2975aa90>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
@ -10,10 +10,15 @@ import webbrowser
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
from gradio import networking, strings
|
||||
from distutils.version import StrictVersion
|
||||
import pkg_resources
|
||||
import requests
|
||||
import termcolor
|
||||
|
||||
LOCALHOST_IP = "127.0.0.1"
|
||||
INITIAL_WEBSOCKET_PORT = 9200
|
||||
TRY_NUM_PORTS = 100
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
|
||||
|
||||
class Interface:
|
||||
@ -80,6 +85,9 @@ class Interface:
|
||||
elif not (model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
if self.model_type == "keras":
|
||||
import tensorflow as tf
|
||||
self.graph = tf.get_default_graph()
|
||||
self.verbose = verbose
|
||||
self.status = self.STATUS_TYPES["OFF"]
|
||||
self.validate_flag = False
|
||||
@ -127,12 +135,12 @@ class Interface:
|
||||
if self.model_type == "sklearn":
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type == "keras":
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
with self.graph.as_default():
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type == "pyfunc":
|
||||
return self.model_obj(preprocessed_input)
|
||||
elif self.model_type == "pytorch":
|
||||
import torch
|
||||
|
||||
value = torch.from_numpy(preprocessed_input)
|
||||
value = torch.autograd.Variable(value)
|
||||
prediction = self.model_obj(value)
|
||||
@ -237,6 +245,12 @@ class Interface:
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, however version {latest_pkg_version} "
|
||||
f"is available, please upgrade.")
|
||||
print('--------')
|
||||
if self.verbose:
|
||||
print(strings.en["BETA_MESSAGE"])
|
||||
if not is_colab:
|
||||
|
@ -13,6 +13,7 @@ from gradio import inputs, outputs
|
||||
import json
|
||||
from gradio.tunneling import create_tunnel
|
||||
import urllib.request
|
||||
from shutil import copyfile
|
||||
|
||||
INITIAL_PORT_VALUE = (
|
||||
7860
|
||||
@ -32,6 +33,9 @@ TEMPLATE_TEMP = "index.html"
|
||||
BASE_JS_FILE = "static/js/all_io.js"
|
||||
CONFIG_FILE = "static/config.json"
|
||||
|
||||
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||
|
||||
|
||||
def build_template(temp_dir, input_interface, output_interface):
|
||||
"""
|
||||
@ -75,6 +79,10 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
f.write(str(all_io_soup))
|
||||
|
||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
# Move association file to root of temporary directory.
|
||||
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
|
||||
|
||||
render_template_with_tags(
|
||||
os.path.join(
|
||||
temp_dir,
|
||||
|
11
gradio/static/apple-app-site-association
Normal file
11
gradio/static/apple-app-site-association
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"applinks": {
|
||||
"apps": [],
|
||||
"details": [
|
||||
{
|
||||
"appID": "RHW8FBGSTX.app.gradio.Gradio",
|
||||
"paths": ["*"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -5,5 +5,5 @@ en = {
|
||||
"restarting python interpreter.",
|
||||
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
||||
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
|
||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly for 8 hours at: {}"
|
||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} -- may take up to a minute to setup."
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ import socket
|
||||
import sys
|
||||
import threading
|
||||
from io import StringIO
|
||||
|
||||
import warnings
|
||||
import paramiko
|
||||
|
||||
DEBUG_MODE = False
|
||||
@ -68,12 +68,14 @@ def create_tunnel(payload, local_server, local_server_port):
|
||||
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
||||
)
|
||||
try:
|
||||
client.connect(
|
||||
hostname=payload["host"],
|
||||
port=int(payload["port"]),
|
||||
username=payload["user"],
|
||||
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
client.connect(
|
||||
hostname=payload["host"],
|
||||
port=int(payload["port"]),
|
||||
username=payload["user"],
|
||||
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"*** Failed to connect to %s:%d: %r"
|
||||
|
Loading…
Reference in New Issue
Block a user