renamed func to function and added an exception for psutil access denied

This commit is contained in:
Abubakar Abid 2019-02-24 20:50:23 -08:00
parent 3e474d7a36
commit 166221cd22
4 changed files with 97 additions and 30 deletions

View File

@ -2,26 +2,28 @@
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 15,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"%autoreload 2\n",
"\n",
"import gradio"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@ -31,7 +33,26 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model type not explicitly identified, inferred to be: python function\n"
]
}
],
"source": [
"iface = gradio.Interface(input=\"textbox\", output=\"textbox\", model=test)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
@ -39,13 +60,13 @@
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model available locally at: http://localhost:7862/interface.html\n",
"Model available publicly for 8 hours at: http://bf3590aa.ngrok.io/interface.html\n"
"Model available locally at: http://localhost:7871/interface.html\n",
"To create a public link, set `share_link=True` in the argument to `launch()`\n"
]
}
],
"source": [
"gradio.Interface(input=\"textbox\", output=\"textbox\", model_type=\"func\", model=test).launch()"
"iface.launch()"
]
}
],

View File

@ -1 +1 @@
from interface import Interface # This makes Interface importable as gradio.Interface.
from gradio.interface import Interface # This makes Interface importable as gradio.Interface.

View File

@ -30,7 +30,10 @@ class Interface():
"""
"""
def __init__(self, input, output, model, model_type, preprocessing_fn=None, postprocessing_fn=None):
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
def __init__(self, input, output, model, model_type=None, preprocessing_fn=None, postprocessing_fn=None):
"""
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
:param model_obj: the model object, such as a sklearn classifier or keras model.
@ -38,8 +41,44 @@ class Interface():
"""
self.input_interface = inputs.registry[input](preprocessing_fn)
self.output_interface = outputs.registry[output](postprocessing_fn)
self.model_type = model_type
self.model_obj = model
if model_type is None:
model_type = self._infer_model_type(model)
if model_type is None:
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
else:
print("Model type not explicitly identified, inferred to be: {}".format(
self.VALID_MODEL_TYPES[model_type]))
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
self.model_type = model_type
def _infer_model_type(self, model):
if callable(model):
return 'function'
try:
import sklearn
if isinstance(model, sklearn.base.BaseEstimator):
return 'sklearn'
except ImportError:
pass
try:
import tensorflow as tf
if isinstance(model, tf.keras.Model):
return 'keras'
except ImportError:
pass
try:
import keras
if isinstance(model, keras.Model):
return 'keras'
except ImportError:
pass
return None
def _build_template(self, temp_dir):
input_template_path = pkg_resources.resource_filename(
@ -98,10 +137,10 @@ class Interface():
return self.model_obj.predict(array)
elif self.model_type=='keras':
return self.model_obj.predict(array)
elif self.model_type=='func':
elif self.model_type=='function':
return self.model_obj(array)
else:
raise ValueError('model_type must be one of: "sklearn" or "keras" or "func".')
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
async def communicate(self, websocket, path):
"""
@ -119,13 +158,11 @@ class Interface():
except websockets.exceptions.ConnectionClosed:
pass
def launch(self, share_link=True):
def launch(self, share_link=False, verbose=True):
"""
Standard method shared by interfaces that launches a websocket at a specified IP address.
"""
networking.kill_processes([4040, 4041])
output_directory = tempfile.mkdtemp()
server_port = networking.start_simple_server(output_directory)
path_to_server = 'http://localhost:{}/'.format(server_port)
self._build_template(output_directory)
@ -140,14 +177,20 @@ class Interface():
start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i)
self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i)
if verbose:
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
if share_link:
networking.kill_processes([4040, 4041])
site_ngrok_url = networking.setup_ngrok(server_port)
socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2)
self._set_socket_url_in_js(output_directory, socket_ngrok_url)
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
if verbose:
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
else:
if verbose:
print("To create a public link, set `share_link=True` in the argument to `launch()`")
asyncio.get_event_loop().run_until_complete(start_server)
try:
asyncio.get_event_loop().run_forever()

View File

@ -22,9 +22,12 @@ NGROK_ZIP_URLS = {
def get_ports_in_use():
ports_in_use = []
for proc in process_iter():
for conns in proc.connections(kind='inet'):
ports_in_use.append(conns.laddr.port)
try:
for proc in process_iter():
for conns in proc.connections(kind='inet'):
ports_in_use.append(conns.laddr.port)
except AccessDenied:
pass # TODO(abidlabs): somehow find a way to handle this issue?
return ports_in_use