mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
renamed func to function and added an exception for psutil access denied
This commit is contained in:
parent
3e474d7a36
commit
166221cd22
@ -2,26 +2,28 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The autoreload extension is already loaded. To reload it, use:\n",
|
||||
" %reload_ext autoreload\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%autoreload 2\n",
|
||||
"\n",
|
||||
"import gradio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -31,7 +33,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model type not explicitly identified, inferred to be: python function\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"iface = gradio.Interface(input=\"textbox\", output=\"textbox\", model=test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -39,13 +60,13 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
|
||||
"Model available locally at: http://localhost:7862/interface.html\n",
|
||||
"Model available publicly for 8 hours at: http://bf3590aa.ngrok.io/interface.html\n"
|
||||
"Model available locally at: http://localhost:7871/interface.html\n",
|
||||
"To create a public link, set `share_link=True` in the argument to `launch()`\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"gradio.Interface(input=\"textbox\", output=\"textbox\", model_type=\"func\", model=test).launch()"
|
||||
"iface.launch()"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -1 +1 @@
|
||||
from interface import Interface # This makes Interface importable as gradio.Interface.
|
||||
from gradio.interface import Interface # This makes Interface importable as gradio.Interface.
|
||||
|
@ -30,7 +30,10 @@ class Interface():
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, input, output, model, model_type, preprocessing_fn=None, postprocessing_fn=None):
|
||||
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
|
||||
|
||||
def __init__(self, input, output, model, model_type=None, preprocessing_fn=None, postprocessing_fn=None):
|
||||
"""
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
@ -38,8 +41,44 @@ class Interface():
|
||||
"""
|
||||
self.input_interface = inputs.registry[input](preprocessing_fn)
|
||||
self.output_interface = outputs.registry[output](postprocessing_fn)
|
||||
self.model_type = model_type
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
if model_type is None:
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
else:
|
||||
print("Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]))
|
||||
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
|
||||
def _infer_model_type(self, model):
|
||||
if callable(model):
|
||||
return 'function'
|
||||
|
||||
try:
|
||||
import sklearn
|
||||
if isinstance(model, sklearn.base.BaseEstimator):
|
||||
return 'sklearn'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
if isinstance(model, tf.keras.Model):
|
||||
return 'keras'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import keras
|
||||
if isinstance(model, keras.Model):
|
||||
return 'keras'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _build_template(self, temp_dir):
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
@ -98,10 +137,10 @@ class Interface():
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='keras':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='func':
|
||||
elif self.model_type=='function':
|
||||
return self.model_obj(array)
|
||||
else:
|
||||
raise ValueError('model_type must be one of: "sklearn" or "keras" or "func".')
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
|
||||
async def communicate(self, websocket, path):
|
||||
"""
|
||||
@ -119,13 +158,11 @@ class Interface():
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
def launch(self, share_link=True):
|
||||
def launch(self, share_link=False, verbose=True):
|
||||
"""
|
||||
Standard method shared by interfaces that launches a websocket at a specified IP address.
|
||||
"""
|
||||
networking.kill_processes([4040, 4041])
|
||||
output_directory = tempfile.mkdtemp()
|
||||
|
||||
server_port = networking.start_simple_server(output_directory)
|
||||
path_to_server = 'http://localhost:{}/'.format(server_port)
|
||||
self._build_template(output_directory)
|
||||
@ -140,14 +177,20 @@ class Interface():
|
||||
|
||||
start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i)
|
||||
self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i)
|
||||
if verbose:
|
||||
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
|
||||
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
|
||||
|
||||
if share_link:
|
||||
networking.kill_processes([4040, 4041])
|
||||
site_ngrok_url = networking.setup_ngrok(server_port)
|
||||
socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2)
|
||||
self._set_socket_url_in_js(output_directory, socket_ngrok_url)
|
||||
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
|
||||
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
|
||||
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
|
||||
if verbose:
|
||||
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
|
||||
else:
|
||||
if verbose:
|
||||
print("To create a public link, set `share_link=True` in the argument to `launch()`")
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
try:
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
@ -22,9 +22,12 @@ NGROK_ZIP_URLS = {
|
||||
|
||||
def get_ports_in_use():
|
||||
ports_in_use = []
|
||||
for proc in process_iter():
|
||||
for conns in proc.connections(kind='inet'):
|
||||
ports_in_use.append(conns.laddr.port)
|
||||
try:
|
||||
for proc in process_iter():
|
||||
for conns in proc.connections(kind='inet'):
|
||||
ports_in_use.append(conns.laddr.port)
|
||||
except AccessDenied:
|
||||
pass # TODO(abidlabs): somehow find a way to handle this issue?
|
||||
return ports_in_use
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user