mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
added networking to support linux
This commit is contained in:
parent
ff914750b0
commit
a2903b115b
@ -2,9 +2,19 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model available publicly for 8 hours at: http://22c0b84e.ngrok.io/interface.html\n",
|
||||
"The autoreload extension is already loaded. To reload it, use:\n",
|
||||
" %reload_ext autoreload\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
@ -14,7 +24,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -24,7 +34,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
@ -43,7 +53,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -52,31 +62,42 @@
|
||||
"text": [
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
|
||||
"Model available locally at: http://localhost:7862/interface.html\n",
|
||||
"Model available publicly for 8 hours at: http://9fb08e2b.ngrok.io/interface.html\n"
|
||||
"Model available publicly for 8 hours at: https://81712345.ngrok.io/interface.html\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:11] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /css/bootstrap.min.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /css/draw-a-digit.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /js/bootstrap.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /js/bootstrap-notify.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /js/textbox-input.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:12] \"GET /js/textbox-output.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:15] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:16] \"GET /js/bootstrap.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:16] \"GET /js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:16] \"GET /css/bootstrap.min.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:16] \"GET /js/textbox-input.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:16] \"GET /css/draw-a-digit.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:16] \"GET /js/bootstrap-notify.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:17] \"GET /js/textbox-output.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:17] code 404, message File not found\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 22:39:17] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /css/bootstrap.min.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /css/draw-a-digit.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /js/bootstrap.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /js/bootstrap-notify.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /js/textbox-input.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:10] \"GET /js/textbox-output.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:28] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:29] \"GET /js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:29] \"GET /css/bootstrap.min.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:30] \"GET /css/draw-a-digit.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:30] \"GET /js/bootstrap.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:30] \"GET /js/bootstrap-notify.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:30] \"GET /js/textbox-input.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:30] \"GET /js/textbox-output.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:31] code 404, message File not found\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:05:31] \"GET /favicon.ico HTTP/1.1\" 404 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:15] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:24] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:24] \"GET /js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:24] \"GET /js/bootstrap.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:24] \"GET /css/bootstrap.min.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:24] \"GET /js/bootstrap-notify.min.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:25] \"GET /css/draw-a-digit.css HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:25] \"GET /js/textbox-input.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:25] \"GET /js/textbox-output.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:26] code 404, message File not found\n",
|
||||
"127.0.0.1 - - [24/Feb/2019 23:06:26] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
@ -1,157 +1 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import nest_asyncio
|
||||
import webbrowser
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
from gradio import inputs
|
||||
from gradio import outputs
|
||||
from gradio import networking
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
LOCALHOST_IP = '127.0.0.1'
|
||||
INITIAL_WEBSOCKET_PORT = 9200
|
||||
TRY_NUM_PORTS = 100
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/all_io.html')
|
||||
JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/')
|
||||
CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/')
|
||||
JS_PATH_TEMP = 'js/'
|
||||
CSS_PATH_TEMP = 'css/'
|
||||
TEMPLATE_TEMP = 'interface.html'
|
||||
BASE_JS_FILE = 'js/all-io.js'
|
||||
|
||||
|
||||
class Interface():
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, input, output, model, model_type, preprocessing_fn=None, postprocessing_fn=None):
|
||||
"""
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
:param model_params: additional model parameters.
|
||||
"""
|
||||
self.input_interface = inputs.registry[input](preprocessing_fn)
|
||||
self.output_interface = outputs.registry[output](postprocessing_fn)
|
||||
self.model_type = model_type
|
||||
self.model_obj = model
|
||||
|
||||
def _build_template(self, temp_dir):
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', self.input_interface._get_template_path())
|
||||
output_template_path = pkg_resources.resource_filename(
|
||||
'gradio', self.output_interface._get_template_path())
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
|
||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
input_tag = all_io_soup.find("div", {"id": "input"})
|
||||
output_tag = all_io_soup.find("div", {"id": "output"})
|
||||
|
||||
input_tag.replace_with(input_soup)
|
||||
output_tag.replace_with(output_soup)
|
||||
|
||||
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
|
||||
f.write(str(all_io_soup.prettify))
|
||||
|
||||
self._copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
|
||||
self._copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
|
||||
return
|
||||
|
||||
def _copy_files(self, src_dir, dest_dir):
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
src_files = os.listdir(src_dir)
|
||||
for file_name in src_files:
|
||||
full_file_name = os.path.join(src_dir, file_name)
|
||||
if os.path.isfile(full_file_name):
|
||||
shutil.copy(full_file_name, dest_dir)
|
||||
|
||||
def _set_socket_url_in_js(self, temp_dir, socket_url):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def _set_socket_port_in_js(self, temp_dir, socket_port):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def predict(self, array):
|
||||
if self.model_type=='sklearn':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='keras':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='func':
|
||||
return self.model_obj(array)
|
||||
else:
|
||||
raise ValueError('model_type must be one of: "sklearn" or "keras" or "func".')
|
||||
|
||||
async def communicate(self, websocket, path):
|
||||
"""
|
||||
Method that defines how this interface communicates with the websocket.
|
||||
:param websocket: a Websocket object used to communicate with the interface frontend
|
||||
:param path: ignored
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
processed_input = self.input_interface._pre_process(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
processed_output = self.output_interface._post_process(prediction)
|
||||
await websocket.send(str(processed_output))
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
def launch(self, share_link=True):
|
||||
"""
|
||||
Standard method shared by interfaces that launches a websocket at a specified IP address.
|
||||
"""
|
||||
networking.kill_processes([4040, 4041])
|
||||
output_directory = tempfile.mkdtemp()
|
||||
|
||||
server_port = networking.start_simple_server(output_directory)
|
||||
path_to_server = 'http://localhost:{}/'.format(server_port)
|
||||
self._build_template(output_directory)
|
||||
|
||||
ports_in_use = networking.get_ports_in_use()
|
||||
for i in range(TRY_NUM_PORTS):
|
||||
if not ((INITIAL_WEBSOCKET_PORT + i) in ports_in_use):
|
||||
break
|
||||
else:
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
|
||||
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS))
|
||||
|
||||
start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i)
|
||||
self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i)
|
||||
|
||||
if share_link:
|
||||
site_ngrok_url = networking.setup_ngrok(server_port)
|
||||
socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2)
|
||||
self._set_socket_url_in_js(output_directory, socket_ngrok_url)
|
||||
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
|
||||
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
|
||||
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
try:
|
||||
asyncio.get_event_loop().run_forever()
|
||||
except RuntimeError: # Runtime errors are thrown in jupyter notebooks because of async.
|
||||
pass
|
||||
|
||||
webbrowser.open(path_to_server + TEMPLATE_TEMP)
|
||||
from gradio.interface import Interface # This makes Interface importable as gradio.Interface.
|
||||
|
201
build/lib/gradio/interface.py
Normal file
201
build/lib/gradio/interface.py
Normal file
@ -0,0 +1,201 @@
|
||||
import asyncio
|
||||
import websockets
|
||||
import nest_asyncio
|
||||
import webbrowser
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
from gradio import inputs
|
||||
from gradio import outputs
|
||||
from gradio import networking
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
LOCALHOST_IP = '127.0.0.1'
|
||||
INITIAL_WEBSOCKET_PORT = 9200
|
||||
TRY_NUM_PORTS = 100
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/all_io.html')
|
||||
JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/')
|
||||
CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/')
|
||||
JS_PATH_TEMP = 'js/'
|
||||
CSS_PATH_TEMP = 'css/'
|
||||
TEMPLATE_TEMP = 'interface.html'
|
||||
BASE_JS_FILE = 'js/all-io.js'
|
||||
|
||||
|
||||
class Interface():
|
||||
"""
|
||||
"""
|
||||
|
||||
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
|
||||
|
||||
def __init__(self, input, output, model, model_type=None, preprocessing_fn=None, postprocessing_fn=None):
|
||||
"""
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
:param model_params: additional model parameters.
|
||||
"""
|
||||
self.input_interface = inputs.registry[input](preprocessing_fn)
|
||||
self.output_interface = outputs.registry[output](postprocessing_fn)
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
if model_type is None:
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
else:
|
||||
print("Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]))
|
||||
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
|
||||
def _infer_model_type(self, model):
|
||||
if callable(model):
|
||||
return 'function'
|
||||
|
||||
try:
|
||||
import sklearn
|
||||
if isinstance(model, sklearn.base.BaseEstimator):
|
||||
return 'sklearn'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
if isinstance(model, tf.keras.Model):
|
||||
return 'keras'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import keras
|
||||
if isinstance(model, keras.Model):
|
||||
return 'keras'
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _build_template(self, temp_dir):
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', self.input_interface._get_template_path())
|
||||
output_template_path = pkg_resources.resource_filename(
|
||||
'gradio', self.output_interface._get_template_path())
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
|
||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
input_tag = all_io_soup.find("div", {"id": "input"})
|
||||
output_tag = all_io_soup.find("div", {"id": "output"})
|
||||
|
||||
input_tag.replace_with(input_soup)
|
||||
output_tag.replace_with(output_soup)
|
||||
|
||||
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
|
||||
f.write(str(all_io_soup.prettify))
|
||||
|
||||
self._copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
|
||||
self._copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
|
||||
return
|
||||
|
||||
def _copy_files(self, src_dir, dest_dir):
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
src_files = os.listdir(src_dir)
|
||||
for file_name in src_files:
|
||||
full_file_name = os.path.join(src_dir, file_name)
|
||||
if os.path.isfile(full_file_name):
|
||||
shutil.copy(full_file_name, dest_dir)
|
||||
|
||||
def _set_socket_url_in_js(self, temp_dir, socket_url):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def _set_socket_port_in_js(self, temp_dir, socket_port):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def predict(self, array):
|
||||
if self.model_type=='sklearn':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='keras':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='function':
|
||||
return self.model_obj(array)
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
|
||||
async def communicate(self, websocket, path):
|
||||
"""
|
||||
Method that defines how this interface communicates with the websocket.
|
||||
:param websocket: a Websocket object used to communicate with the interface frontend
|
||||
:param path: ignored
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
processed_input = self.input_interface._pre_process(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
processed_output = self.output_interface._post_process(prediction)
|
||||
await websocket.send(str(processed_output))
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
def launch(self, share_link=False, verbose=True):
|
||||
"""
|
||||
Standard method shared by interfaces that launches a websocket at a specified IP address.
|
||||
"""
|
||||
output_directory = tempfile.mkdtemp()
|
||||
server_port = networking.start_simple_server(output_directory)
|
||||
path_to_server = 'http://localhost:{}/'.format(server_port)
|
||||
self._build_template(output_directory)
|
||||
|
||||
ports_in_use = networking.get_ports_in_use(INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
|
||||
for i in range(TRY_NUM_PORTS):
|
||||
if not ((INITIAL_WEBSOCKET_PORT + i) in ports_in_use):
|
||||
break
|
||||
else:
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
|
||||
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS))
|
||||
|
||||
start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i)
|
||||
self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i)
|
||||
if verbose:
|
||||
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
|
||||
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
|
||||
|
||||
if share_link:
|
||||
networking.kill_processes([4040, 4041])
|
||||
site_ngrok_url = networking.setup_ngrok(server_port)
|
||||
socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2)
|
||||
self._set_socket_url_in_js(output_directory, socket_ngrok_url)
|
||||
if verbose:
|
||||
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
|
||||
else:
|
||||
if verbose:
|
||||
print("To create a public link, set `share_link=True` in the argument to `launch()`")
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
try:
|
||||
asyncio.get_event_loop().run_forever()
|
||||
except RuntimeError: # Runtime errors are thrown in jupyter notebooks because of async.
|
||||
pass
|
||||
|
||||
webbrowser.open(path_to_server + TEMPLATE_TEMP)
|
@ -4,11 +4,18 @@ import zipfile
|
||||
import io
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
from psutil import process_iter, AccessDenied
|
||||
from signal import SIGTERM # or SIGKILL
|
||||
import threading
|
||||
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
import stat
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
INITIAL_PORT_VALUE = 7860
|
||||
TRY_NUM_PORTS = 100
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
LOCALHOST_PREFIX = 'localhost:'
|
||||
NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output)
|
||||
NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output)
|
||||
@ -20,28 +27,94 @@ NGROK_ZIP_URLS = {
|
||||
}
|
||||
|
||||
|
||||
def get_ports_in_use():
|
||||
def get_ports_in_use(start, stop):
|
||||
ports_in_use = []
|
||||
for proc in process_iter():
|
||||
for conns in proc.connections(kind='inet'):
|
||||
ports_in_use.append(conns.laddr.port)
|
||||
for port in range(start, stop):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.close()
|
||||
except OSError:
|
||||
ports_in_use.append(port)
|
||||
return ports_in_use
|
||||
# ports_in_use = []
|
||||
# try:
|
||||
# for proc in process_iter():
|
||||
# for conns in proc.connections(kind='inet'):
|
||||
# ports_in_use.append(conns.laddr.port)
|
||||
# except AccessDenied:
|
||||
# pass # TODO(abidlabs): somehow find a way to handle this issue?
|
||||
# return ports_in_use
|
||||
|
||||
|
||||
def serve_files_in_background(port, directory_to_serve=None):
|
||||
# class Handler(http.server.SimpleHTTPRequestHandler):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super().__init__(*args, directory=directory_to_serve, **kwargs)
|
||||
#
|
||||
# server = socketserver.ThreadingTCPServer(('localhost', port), Handler)
|
||||
# # Ensures that Ctrl-C cleanly kills all spawned threads
|
||||
# server.daemon_threads = True
|
||||
# # Quicker rebinding
|
||||
# server.allow_reuse_address = True
|
||||
#
|
||||
# # A custom signal handle to allow us to Ctrl-C out of the process
|
||||
# def signal_handler(signal, frame):
|
||||
# print('Exiting http server (Ctrl+C pressed)')
|
||||
# try:
|
||||
# if (server):
|
||||
# server.server_close()
|
||||
# finally:
|
||||
# sys.exit(0)
|
||||
#
|
||||
# # Install the keyboard interrupt handler
|
||||
# signal.signal(signal.SIGINT, signal_handler)
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
|
||||
def translate_path(self, path):
|
||||
path = SimpleHTTPRequestHandler.translate_path(self, path)
|
||||
relpath = os.path.relpath(path, os.getcwd())
|
||||
fullpath = os.path.join(self.server.base_path, relpath)
|
||||
return fullpath
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
||||
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
|
||||
self.base_path = base_path
|
||||
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
|
||||
|
||||
httpd = HTTPServer(directory_to_serve, (LOCALHOST_NAME, port))
|
||||
|
||||
# Now loop forever
|
||||
def serve_forever():
|
||||
try:
|
||||
while True:
|
||||
sys.stdout.flush()
|
||||
httpd.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
thread = threading.Thread(target=serve_forever)
|
||||
thread.start()
|
||||
|
||||
|
||||
def start_simple_server(directory_to_serve=None):
|
||||
# TODO(abidlabs): increment port number until free port is found
|
||||
ports_in_use = get_ports_in_use()
|
||||
ports_in_use = get_ports_in_use(start=INITIAL_PORT_VALUE, stop=INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
for i in range(TRY_NUM_PORTS):
|
||||
if not((INITIAL_PORT_VALUE + i) in ports_in_use):
|
||||
break
|
||||
else:
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS))
|
||||
if directory_to_serve is None:
|
||||
subprocess.Popen(['python', '-m', 'http.server', str(INITIAL_PORT_VALUE + i)])
|
||||
else:
|
||||
cmd = ' '.join(['python', '-m', 'http.server', '-d', directory_to_serve, str(INITIAL_PORT_VALUE + i)])
|
||||
subprocess.Popen(cmd, shell=True) # Doesn't seem to work if list is passed for some reason.
|
||||
serve_files_in_background(INITIAL_PORT_VALUE + i, directory_to_serve)
|
||||
# if directory_to_serve is None:
|
||||
# subprocess.Popen(['python', '-m', 'http.server', str(INITIAL_PORT_VALUE + i)])
|
||||
# else:
|
||||
# cmd = ' '.join(['python', '-m', 'http.server', '-d', directory_to_serve, str(INITIAL_PORT_VALUE + i)])
|
||||
# subprocess.Popen(cmd, shell=True) # Doesn't seem to work if list is passed for some reason.
|
||||
return INITIAL_PORT_VALUE + i
|
||||
|
||||
|
||||
@ -49,19 +122,30 @@ def download_ngrok():
|
||||
try:
|
||||
zip_file_url = NGROK_ZIP_URLS[sys.platform]
|
||||
except KeyError:
|
||||
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and we'll look into it!")
|
||||
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and "
|
||||
"we'll look into it!")
|
||||
return
|
||||
|
||||
r = requests.get(zip_file_url)
|
||||
z = zipfile.ZipFile(io.BytesIO(r.content))
|
||||
z.extractall()
|
||||
if sys.platform == 'darwin' or sys.platform == 'linux':
|
||||
st = os.stat('ngrok')
|
||||
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
|
||||
|
||||
|
||||
def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL):
|
||||
if not(os.path.isfile('ngrok.exe')):
|
||||
download_ngrok()
|
||||
subprocess.Popen(['ngrok', 'http', str(local_port)])
|
||||
r = requests.get(api_url)
|
||||
if sys.platform == 'win32':
|
||||
subprocess.Popen(['ngrok', 'http', str(local_port)])
|
||||
else:
|
||||
subprocess.Popen(['./ngrok', 'http', str(local_port)])
|
||||
session = requests.Session()
|
||||
retry = Retry(connect=3, backoff_factor=0.5)
|
||||
adapter = HTTPAdapter(max_retries=retry)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
r = session.get(api_url)
|
||||
for tunnel in r.json()['tunnels']:
|
||||
if LOCALHOST_PREFIX + str(local_port) in tunnel['config']['addr']:
|
||||
return tunnel['public_url']
|
||||
@ -75,6 +159,6 @@ def kill_processes(process_ids):
|
||||
if conns.laddr.port in process_ids:
|
||||
proc.send_signal(SIGTERM) # or SIGKILL
|
||||
except AccessDenied:
|
||||
print("Unable to kill processes, please kill manually.")
|
||||
pass
|
||||
|
||||
|
||||
|
Binary file not shown.
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: gradio
|
||||
Version: 0.1.8
|
||||
Version: 0.2.0
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/abidlabs/gradio
|
||||
Author: Abubakar Abid
|
||||
|
@ -3,6 +3,7 @@ README.md
|
||||
setup.py
|
||||
gradio/__init__.py
|
||||
gradio/inputs.py
|
||||
gradio/interface.py
|
||||
gradio/networking.py
|
||||
gradio/outputs.py
|
||||
gradio/preprocessing_utils.py
|
||||
|
@ -10,6 +10,8 @@ from signal import SIGTERM # or SIGKILL
|
||||
import threading
|
||||
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
import stat
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
INITIAL_PORT_VALUE = 7860
|
||||
TRY_NUM_PORTS = 100
|
||||
@ -119,25 +121,31 @@ def start_simple_server(directory_to_serve=None):
|
||||
def download_ngrok():
|
||||
try:
|
||||
zip_file_url = NGROK_ZIP_URLS[sys.platform]
|
||||
print(zip_file_url)
|
||||
except KeyError:
|
||||
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and we'll look into it!")
|
||||
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and "
|
||||
"we'll look into it!")
|
||||
return
|
||||
r = requests.get(zip_file_url)
|
||||
z = zipfile.ZipFile(io.BytesIO(r.content))
|
||||
z.extractall()
|
||||
if (sys.platform=='darwin' or sys.platform=='linux'):
|
||||
if sys.platform == 'darwin' or sys.platform == 'linux':
|
||||
st = os.stat('ngrok')
|
||||
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
|
||||
|
||||
|
||||
def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL):
|
||||
if not(os.path.isfile('ngrok.exe')):
|
||||
download_ngrok()
|
||||
if sys.platform=='win32':
|
||||
if sys.platform == 'win32':
|
||||
subprocess.Popen(['ngrok', 'http', str(local_port)])
|
||||
else:
|
||||
subprocess.Popen(['./ngrok', 'http', str(local_port)])
|
||||
r = requests.get(api_url)
|
||||
session = requests.Session()
|
||||
retry = Retry(connect=3, backoff_factor=0.5)
|
||||
adapter = HTTPAdapter(max_retries=retry)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
r = session.get(api_url)
|
||||
for tunnel in r.json()['tunnels']:
|
||||
if LOCALHOST_PREFIX + str(local_port) in tunnel['config']['addr']:
|
||||
return tunnel['public_url']
|
||||
|
Loading…
x
Reference in New Issue
Block a user