From 38a6899d951a2bee36dcdcc85a9157e410d3bded Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 18 Apr 2019 23:44:19 -0700 Subject: [PATCH] pass keras graph into thread --- Test Keras.ipynb | 14 +++++--------- gradio/interface.py | 18 ++++++++++++++++-- gradio/networking.py | 8 ++++++++ gradio/static/apple-app-site-association | 11 +++++++++++ gradio/strings.py | 2 +- gradio/tunneling.py | 16 +++++++++------- 6 files changed, 50 insertions(+), 19 deletions(-) create mode 100644 gradio/static/apple-app-site-association diff --git a/Test Keras.ipynb b/Test Keras.ipynb index 5555d7296d..90a4e5b24c 100644 --- a/Test Keras.ipynb +++ b/Test Keras.ipynb @@ -27,9 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "# model = tf.keras.applications.inception_v3.InceptionV3()\n", - "def model(x):\n", - " return \"test class\"" + "model = tf.keras.applications.inception_v3.InceptionV3()" ] }, { @@ -51,13 +49,12 @@ "io = gradio.Interface(inputs=inp, \n", " outputs=out,\n", " model=model, \n", - "# model_type='keras')\n", - " model_type='pyfunc')" + " model_type='keras')" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "scrolled": false }, @@ -66,10 +63,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Closing existing server...\n", "NOTE: Gradio is in beta stage, please report all bugs to: contact.gradio@gmail.com\n", "Model is running locally at: http://localhost:7860/\n", - "Model available publicly for 8 hours at: https://e46308c0.gradio.app/\n" + "Model available publicly at: https://10227.gradio.app -- may take up to a minute to setup.\n" ] }, { @@ -86,7 +82,7 @@ " " ], "text/plain": [ - "" + "" ] }, "metadata": {}, diff --git a/gradio/interface.py b/gradio/interface.py index 409f3d48b5..82719007bc 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -10,10 +10,15 @@ import webbrowser import gradio.inputs import gradio.outputs from gradio import networking, strings +from distutils.version import StrictVersion +import pkg_resources +import requests +import termcolor LOCALHOST_IP = "127.0.0.1" INITIAL_WEBSOCKET_PORT = 9200 TRY_NUM_PORTS = 100 +PKG_VERSION_URL = "https://gradio.app/api/pkg-version" class Interface: @@ -80,6 +85,9 @@ class Interface: elif not (model_type.lower() in self.VALID_MODEL_TYPES): ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES)) self.model_type = model_type + if self.model_type == "keras": + import tensorflow as tf + self.graph = tf.get_default_graph() self.verbose = verbose self.status = self.STATUS_TYPES["OFF"] self.validate_flag = False @@ -127,12 +135,12 @@ class Interface: if self.model_type == "sklearn": return self.model_obj.predict(preprocessed_input) elif self.model_type == "keras": - return self.model_obj.predict(preprocessed_input) + with self.graph.as_default(): + return self.model_obj.predict(preprocessed_input) elif self.model_type == "pyfunc": return self.model_obj(preprocessed_input) elif self.model_type == "pytorch": import torch - value = torch.from_numpy(preprocessed_input) value = torch.autograd.Variable(value) prediction = self.model_obj(value) @@ -237,6 +245,12 @@ class Interface: except NameError: pass + current_pkg_version = pkg_resources.require("gradio")[0].version + latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"] + if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version): + print(f"IMPORTANT: You are using gradio version {current_pkg_version}, however version {latest_pkg_version} " + f"is available, please upgrade.") + print('--------') if self.verbose: print(strings.en["BETA_MESSAGE"]) if not is_colab: diff --git a/gradio/networking.py b/gradio/networking.py index d455c62a30..624772780e 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -13,6 +13,7 @@ from gradio import inputs, outputs import json from gradio.tunneling import create_tunnel import urllib.request +from shutil import copyfile INITIAL_PORT_VALUE = ( 7860 @@ -32,6 +33,9 @@ TEMPLATE_TEMP = "index.html" BASE_JS_FILE = "static/js/all_io.js" CONFIG_FILE = "static/config.json" +ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association" +ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association" + def build_template(temp_dir, input_interface, output_interface): """ @@ -75,6 +79,10 @@ def build_template(temp_dir, input_interface, output_interface): f.write(str(all_io_soup)) copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP)) + # Move association file to root of temporary directory. + copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC), + os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT)) + render_template_with_tags( os.path.join( temp_dir, diff --git a/gradio/static/apple-app-site-association b/gradio/static/apple-app-site-association new file mode 100644 index 0000000000..36c39e3174 --- /dev/null +++ b/gradio/static/apple-app-site-association @@ -0,0 +1,11 @@ +{ + "applinks": { + "apps": [], + "details": [ + { + "appID": "RHW8FBGSTX.app.gradio.Gradio", + "paths": ["*"] + } + ] + } +} \ No newline at end of file diff --git a/gradio/strings.py b/gradio/strings.py index 4f95d4b752..81b291a0f6 100644 --- a/gradio/strings.py +++ b/gradio/strings.py @@ -5,5 +5,5 @@ en = { "restarting python interpreter.", "COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.", "PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.", - "MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly for 8 hours at: {}" + "MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} -- may take up to a minute to setup." } diff --git a/gradio/tunneling.py b/gradio/tunneling.py index 415e8d4e70..990eee3d0b 100644 --- a/gradio/tunneling.py +++ b/gradio/tunneling.py @@ -8,7 +8,7 @@ import socket import sys import threading from io import StringIO - +import warnings import paramiko DEBUG_MODE = False @@ -68,12 +68,14 @@ def create_tunnel(payload, local_server, local_server_port): "Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"])) ) try: - client.connect( - hostname=payload["host"], - port=int(payload["port"]), - username=payload["user"], - pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])), - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client.connect( + hostname=payload["host"], + port=int(payload["port"]), + username=payload["user"], + pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])), + ) except Exception as e: print( "*** Failed to connect to %s:%d: %r"