mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Native reverse forward tunneling + Gradio server API integration
This commit is contained in:
parent
bdcbd5e8b4
commit
ef404beac0
6
.gitignore
vendored
6
.gitignore
vendored
@ -10,4 +10,8 @@ models/*
|
||||
gradio_files/*
|
||||
ngrok*
|
||||
examples/ngrok*
|
||||
gradio-flagged/*
|
||||
gradio-flagged/*
|
||||
.DS_Store
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
@ -1,19 +1,17 @@
|
||||
'''
|
||||
"""
|
||||
This is the core file in the `gradio` package, and defines the Interface class, including methods for constructing the
|
||||
interface using the input and output types.
|
||||
'''
|
||||
"""
|
||||
|
||||
import webbrowser
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
from gradio import networking, strings, inputs
|
||||
import tempfile
|
||||
import traceback
|
||||
import urllib
|
||||
import tensorflow as tf
|
||||
import webbrowser
|
||||
|
||||
LOCALHOST_IP = '127.0.0.1'
|
||||
SHARE_LINK_FORMAT = 'https://{}.gradio.app/'
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
from gradio import networking, strings
|
||||
|
||||
LOCALHOST_IP = "127.0.0.1"
|
||||
INITIAL_WEBSOCKET_PORT = 9200
|
||||
TRY_NUM_PORTS = 100
|
||||
|
||||
@ -25,12 +23,24 @@ class Interface:
|
||||
"""
|
||||
|
||||
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'pyfunc': 'python function',
|
||||
'pytorch': 'PyTorch model'}
|
||||
STATUS_TYPES = {'OFF': 'off', 'RUNNING': 'running'}
|
||||
VALID_MODEL_TYPES = {
|
||||
"sklearn": "sklearn model",
|
||||
"keras": "Keras model",
|
||||
"pyfunc": "python function",
|
||||
"pytorch": "PyTorch model",
|
||||
}
|
||||
STATUS_TYPES = {"OFF": "off", "RUNNING": "running"}
|
||||
|
||||
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
|
||||
verbose=True):
|
||||
def __init__(
|
||||
self,
|
||||
inputs,
|
||||
outputs,
|
||||
model,
|
||||
model_type=None,
|
||||
preprocessing_fns=None,
|
||||
postprocessing_fns=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
:param outputs: a string or `AbstractOutput` representing the output interface.
|
||||
@ -41,80 +51,94 @@ class Interface:
|
||||
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
|
||||
"""
|
||||
if isinstance(inputs, str):
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](
|
||||
preprocessing_fns
|
||||
)
|
||||
elif isinstance(inputs, gradio.inputs.AbstractInput):
|
||||
self.input_interface = inputs
|
||||
else:
|
||||
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
|
||||
raise ValueError("Input interface must be of type `str` or `AbstractInput`")
|
||||
if isinstance(outputs, str):
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](
|
||||
postprocessing_fns
|
||||
)
|
||||
elif isinstance(outputs, gradio.outputs.AbstractOutput):
|
||||
self.output_interface = outputs
|
||||
else:
|
||||
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
|
||||
raise ValueError(
|
||||
"Output interface must be of type `str` or `AbstractOutput`"
|
||||
)
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
if verbose:
|
||||
print("Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]))
|
||||
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
print(
|
||||
"Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]
|
||||
)
|
||||
)
|
||||
elif not (model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
self.verbose = verbose
|
||||
self.status = self.STATUS_TYPES['OFF']
|
||||
self.status = self.STATUS_TYPES["OFF"]
|
||||
self.validate_flag = False
|
||||
self.simple_server = None
|
||||
self.ngrok_api_ports = None
|
||||
|
||||
@staticmethod
|
||||
def _infer_model_type(model):
|
||||
""" Helper method that attempts to identify the type of trained ML model."""
|
||||
try:
|
||||
import sklearn
|
||||
|
||||
if isinstance(model, sklearn.base.BaseEstimator):
|
||||
return 'sklearn'
|
||||
return "sklearn"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
if isinstance(model, tf.keras.Model):
|
||||
return 'keras'
|
||||
return "keras"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import keras
|
||||
|
||||
if isinstance(model, keras.Model):
|
||||
return 'keras'
|
||||
return "keras"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if callable(model):
|
||||
return 'pyfunc'
|
||||
return "pyfunc"
|
||||
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
raise ValueError(
|
||||
"model_type could not be inferred, please specify parameter `model_type`"
|
||||
)
|
||||
|
||||
def predict(self, preprocessed_input):
|
||||
"""
|
||||
Method that calls the relevant method of the model object to make a prediction.
|
||||
:param preprocessed_input: the preprocessed input returned by the input interface
|
||||
"""
|
||||
if self.model_type=='sklearn':
|
||||
if self.model_type == "sklearn":
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='keras':
|
||||
elif self.model_type == "keras":
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='pyfunc':
|
||||
elif self.model_type == "pyfunc":
|
||||
return self.model_obj(preprocessed_input)
|
||||
elif self.model_type=='pytorch':
|
||||
elif self.model_type == "pytorch":
|
||||
import torch
|
||||
|
||||
value = torch.from_numpy(preprocessed_input)
|
||||
value = torch.autograd.Variable(value)
|
||||
prediction = self.model_obj(value)
|
||||
return prediction.data.numpy()
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
@ -126,18 +150,28 @@ class Interface:
|
||||
if n == 0:
|
||||
self.validate_flag = True
|
||||
if self.verbose:
|
||||
print("No validation samples for this interface... skipping validation.")
|
||||
print(
|
||||
"No validation samples for this interface... skipping validation."
|
||||
)
|
||||
return
|
||||
for m, msg in enumerate(validation_inputs):
|
||||
if self.verbose:
|
||||
print(f"Validating samples: {m+1}/{n} [" + "="*(m+1) + "."*(n-m-1) + "]", end='\r')
|
||||
print(
|
||||
f"Validating samples: {m+1}/{n} ["
|
||||
+ "=" * (m + 1)
|
||||
+ "." * (n - m - 1)
|
||||
+ "]",
|
||||
end="\r",
|
||||
)
|
||||
try:
|
||||
processed_input = self.input_interface.preprocess(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print("\n----------")
|
||||
print("Validation failed, likely due to incompatible pre-processing and model input. See below:\n")
|
||||
print(
|
||||
"Validation failed, likely due to incompatible pre-processing and model input. See below:\n"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
break
|
||||
try:
|
||||
@ -145,8 +179,10 @@ class Interface:
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print("\n----------")
|
||||
print("Validation failed, likely due to incompatible model output and post-processing."
|
||||
"See below:\n")
|
||||
print(
|
||||
"Validation failed, likely due to incompatible model output and post-processing."
|
||||
"See below:\n"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
break
|
||||
else: # This means if a break was not explicitly called
|
||||
@ -168,7 +204,7 @@ class Interface:
|
||||
self.validate()
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == self.STATUS_TYPES['RUNNING']:
|
||||
if self.status == self.STATUS_TYPES["RUNNING"]:
|
||||
if self.verbose:
|
||||
print("Closing existing server...")
|
||||
if self.simple_server is not None:
|
||||
@ -180,19 +216,23 @@ class Interface:
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd = networking.start_simple_server(self, output_directory)
|
||||
path_to_local_server = 'http://localhost:{}/'.format(server_port)
|
||||
networking.build_template(output_directory, self.input_interface, self.output_interface)
|
||||
path_to_local_server = "http://localhost:{}/".format(server_port)
|
||||
networking.build_template(
|
||||
output_directory, self.input_interface, self.output_interface
|
||||
)
|
||||
|
||||
networking.set_interface_types_in_config_file(output_directory,
|
||||
self.input_interface.__class__.__name__.lower(),
|
||||
self.output_interface.__class__.__name__.lower())
|
||||
self.status = self.STATUS_TYPES['RUNNING']
|
||||
networking.set_interface_types_in_config_file(
|
||||
output_directory,
|
||||
self.input_interface.__class__.__name__.lower(),
|
||||
self.output_interface.__class__.__name__.lower(),
|
||||
)
|
||||
self.status = self.STATUS_TYPES["RUNNING"]
|
||||
self.simple_server = httpd
|
||||
|
||||
is_colab = False
|
||||
try: # Check if running interactively using ipython.
|
||||
from_ipynb = get_ipython()
|
||||
if 'google.colab' in str(from_ipynb):
|
||||
if "google.colab" in str(from_ipynb):
|
||||
is_colab = True
|
||||
except NameError:
|
||||
pass
|
||||
@ -203,31 +243,26 @@ class Interface:
|
||||
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
|
||||
if share:
|
||||
try:
|
||||
path_to_ngrok_server, ngrok_api_ports = networking.setup_ngrok(
|
||||
server_port, output_directory, self.ngrok_api_ports)
|
||||
self.ngrok_api_ports = ngrok_api_ports
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
except RuntimeError:
|
||||
path_to_ngrok_server = None
|
||||
share_url = None
|
||||
if self.verbose:
|
||||
print(strings.en["NGROK_NO_INTERNET"])
|
||||
else:
|
||||
if is_colab: # For a colab notebook, create a public link even if share is False.
|
||||
path_to_ngrok_server, ngrok_api_ports = networking.setup_ngrok(
|
||||
server_port, output_directory, self.ngrok_api_ports)
|
||||
self.ngrok_api_ports = ngrok_api_ports
|
||||
if (
|
||||
is_colab
|
||||
): # For a colab notebook, create a public link even if share is False.
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
if self.verbose:
|
||||
print(strings.en["COLAB_NO_LOCAL"])
|
||||
else: # If it's not a colab notebook and share=False, print a message telling them about the share option.
|
||||
if self.verbose:
|
||||
print(strings.en["PUBLIC_SHARE_TRUE"])
|
||||
path_to_ngrok_server = None
|
||||
share_url = None
|
||||
|
||||
if path_to_ngrok_server is not None:
|
||||
url = urllib.parse.urlparse(path_to_ngrok_server)
|
||||
subdomain = url.hostname.split('.')[0]
|
||||
path_to_ngrok_interface_page = SHARE_LINK_FORMAT.format(subdomain)
|
||||
if share_url is not None:
|
||||
if self.verbose:
|
||||
print(strings.en["MODEL_PUBLICLY_AVAILABLE_URL"].format(path_to_ngrok_interface_page))
|
||||
print(strings.en["MODEL_PUBLICLY_AVAILABLE_URL"].format(share_url))
|
||||
|
||||
if inline is None:
|
||||
try: # Check if running interactively using ipython.
|
||||
@ -244,12 +279,17 @@ class Interface:
|
||||
inbrowser = False
|
||||
|
||||
if inbrowser and not is_colab:
|
||||
webbrowser.open(path_to_local_server) # Open a browser tab with the interface.
|
||||
webbrowser.open(
|
||||
path_to_local_server
|
||||
) # Open a browser tab with the interface.
|
||||
if inline:
|
||||
from IPython.display import IFrame
|
||||
if is_colab: # Embed the remote interface page if on google colab; otherwise, embed the local page.
|
||||
display(IFrame(path_to_ngrok_interface_page, width=1000, height=500))
|
||||
|
||||
if (
|
||||
is_colab
|
||||
): # Embed the remote interface page if on google colab; otherwise, embed the local page.
|
||||
display(IFrame(share_url, width=1000, height=500))
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
||||
return httpd, path_to_local_server, path_to_ngrok_server
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
@ -1,46 +1,37 @@
|
||||
'''
|
||||
"""
|
||||
Defines helper methods useful for setting up ports, launching servers, and handling `ngrok`
|
||||
'''
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import requests
|
||||
import zipfile
|
||||
import io
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
from psutil import process_iter, AccessDenied, NoSuchProcess
|
||||
from signal import SIGTERM # or SIGKILL
|
||||
import threading
|
||||
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
import stat
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
from distutils import dir_util
|
||||
from gradio import inputs, outputs
|
||||
import time
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
from gradio.tunneling import create_tunnel
|
||||
import urllib.request
|
||||
|
||||
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
NGROK_TUNNEL_API_URL = "http://{}/api/tunnels"
|
||||
INITIAL_PORT_VALUE = (
|
||||
7860
|
||||
) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = (
|
||||
100
|
||||
) # Number of ports to try before giving up and throwing an exception.
|
||||
LOCALHOST_NAME = "localhost"
|
||||
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename('gradio', 'static/')
|
||||
STATIC_PATH_TEMP = 'static/'
|
||||
TEMPLATE_TEMP = 'index.html'
|
||||
BASE_JS_FILE = 'static/js/all_io.js'
|
||||
CONFIG_FILE = 'static/config.json'
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename(
|
||||
"gradio", "templates/base_template.html"
|
||||
)
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
|
||||
STATIC_PATH_TEMP = "static/"
|
||||
TEMPLATE_TEMP = "index.html"
|
||||
BASE_JS_FILE = "static/js/all_io.js"
|
||||
CONFIG_FILE = "static/config.json"
|
||||
|
||||
NGROK_ZIP_URLS = {
|
||||
"linux": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip",
|
||||
"darwin": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-darwin-amd64.zip",
|
||||
"win32": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-windows-amd64.zip",
|
||||
}
|
||||
|
||||
def build_template(temp_dir, input_interface, output_interface):
|
||||
"""
|
||||
@ -50,16 +41,27 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
:param output_interface: an AbstractInput object which includes is used to get the input template
|
||||
"""
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
|
||||
"gradio",
|
||||
inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()),
|
||||
)
|
||||
output_template_path = pkg_resources.resource_filename(
|
||||
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
|
||||
"gradio",
|
||||
outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()),
|
||||
)
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(render_string_or_list_with_tags(
|
||||
input_page.read(), input_interface.get_template_context()), features="html.parser")
|
||||
input_soup = BeautifulSoup(
|
||||
render_string_or_list_with_tags(
|
||||
input_page.read(), input_interface.get_template_context()
|
||||
),
|
||||
features="html.parser",
|
||||
)
|
||||
output_soup = BeautifulSoup(
|
||||
render_string_or_list_with_tags(
|
||||
output_page.read(), output_interface.get_template_context()), features="html.parser")
|
||||
output_page.read(), output_interface.get_template_context()
|
||||
),
|
||||
features="html.parser",
|
||||
)
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
@ -73,12 +75,20 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
f.write(str(all_io_soup))
|
||||
|
||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
render_template_with_tags(os.path.join(
|
||||
temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
||||
input_interface.get_js_context())
|
||||
render_template_with_tags(os.path.join(
|
||||
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
|
||||
output_interface.get_js_context())
|
||||
render_template_with_tags(
|
||||
os.path.join(
|
||||
temp_dir,
|
||||
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name()),
|
||||
),
|
||||
input_interface.get_js_context(),
|
||||
)
|
||||
render_template_with_tags(
|
||||
os.path.join(
|
||||
temp_dir,
|
||||
outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name()),
|
||||
),
|
||||
output_interface.get_js_context(),
|
||||
)
|
||||
|
||||
|
||||
def copy_files(src_dir, dest_dir):
|
||||
@ -100,7 +110,7 @@ def render_template_with_tags(template_path, context):
|
||||
with open(template_path) as fin:
|
||||
old_lines = fin.readlines()
|
||||
new_lines = render_string_or_list_with_tags(old_lines, context)
|
||||
with open(template_path, 'w') as fout:
|
||||
with open(template_path, "w") as fout:
|
||||
for line in new_lines:
|
||||
fout.write(line)
|
||||
|
||||
@ -109,36 +119,27 @@ def render_string_or_list_with_tags(old_lines, context):
|
||||
# Handle string case
|
||||
if isinstance(old_lines, str):
|
||||
for key, value in context.items():
|
||||
old_lines = old_lines.replace(r'{{' + key + r'}}', str(value))
|
||||
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
|
||||
return old_lines
|
||||
|
||||
# Handle list case
|
||||
new_lines = []
|
||||
for line in old_lines:
|
||||
for key, value in context.items():
|
||||
line = line.replace(r'{{' + key + r'}}', str(value))
|
||||
line = line.replace(r"{{" + key + r"}}", str(value))
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
|
||||
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
|
||||
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
|
||||
js_file = os.path.join(temp_dir, BASE_JS_FILE)
|
||||
render_template_with_tags(js_file, {'ngrok_socket_url': ngrok_socket_url})
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
render_template_with_tags(config_file, {'ngrok_socket_url': ngrok_socket_url})
|
||||
|
||||
|
||||
def set_socket_port_in_js(temp_dir, socket_port):
|
||||
js_file = os.path.join(temp_dir, BASE_JS_FILE)
|
||||
render_template_with_tags(js_file, {'socket_port': str(socket_port)})
|
||||
|
||||
|
||||
def set_interface_types_in_config_file(temp_dir, input_interface, output_interface):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
render_template_with_tags(config_file, {'input_interface_type': input_interface,
|
||||
'output_interface_type': output_interface})
|
||||
render_template_with_tags(
|
||||
config_file,
|
||||
{
|
||||
"input_interface_type": input_interface,
|
||||
"output_interface_type": output_interface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
@ -156,15 +157,20 @@ def get_first_available_port(initial, final):
|
||||
return port
|
||||
except OSError:
|
||||
pass
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(initial, final))
|
||||
raise OSError(
|
||||
"All ports from {} to {} are in use. Please close a port.".format(
|
||||
initial, final
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
|
||||
def _set_headers(self):
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
def translate_path(self, path):
|
||||
@ -181,17 +187,14 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
if self.path == "/api/predict/":
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers['Content-Length']))
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
|
||||
# Make the prediction.
|
||||
msg = json.loads(data_string)
|
||||
processed_input = interface.input_interface.preprocess(msg['data'])
|
||||
processed_input = interface.input_interface.preprocess(msg["data"])
|
||||
prediction = interface.predict(processed_input)
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {
|
||||
'action': 'output',
|
||||
'data': processed_output,
|
||||
}
|
||||
output = {"action": "output", "data": processed_output}
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps(output).encode())
|
||||
@ -223,7 +226,9 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
|
||||
|
||||
def start_simple_server(interface, directory_to_serve=None):
|
||||
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
httpd = serve_files_in_background(interface, port, directory_to_serve)
|
||||
return port, httpd
|
||||
|
||||
@ -233,66 +238,23 @@ def close_server(server):
|
||||
server.server_close()
|
||||
|
||||
|
||||
def download_ngrok():
|
||||
def url_request(url):
|
||||
try:
|
||||
zip_file_url = NGROK_ZIP_URLS[sys.platform]
|
||||
except KeyError:
|
||||
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and "
|
||||
"we'll look into it!")
|
||||
return
|
||||
r = requests.get(zip_file_url)
|
||||
z = zipfile.ZipFile(io.BytesIO(r.content))
|
||||
z.extractall()
|
||||
if sys.platform == 'darwin' or sys.platform == 'linux':
|
||||
st = os.stat('ngrok')
|
||||
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
|
||||
req = urllib.request.Request(
|
||||
url=url, headers={"content-type": "application/json"}
|
||||
)
|
||||
res = urllib.request.urlopen(req, timeout=10)
|
||||
return res
|
||||
except Exception as e:
|
||||
raise RuntimeError(str(e))
|
||||
|
||||
|
||||
def create_ngrok_tunnel(local_port, log_file):
|
||||
if not(os.path.isfile('ngrok.exe') or os.path.isfile('ngrok')):
|
||||
download_ngrok()
|
||||
if sys.platform == 'win32':
|
||||
subprocess.Popen(['ngrok', 'http', str(local_port), '--log', log_file, '--log-format', 'json'])
|
||||
else:
|
||||
subprocess.Popen(['./ngrok', 'http', str(local_port), '--log', log_file, '--log-format', 'json'])
|
||||
time.sleep(1.5) # Let ngrok write to the log file TODO(abidlabs): a better way to do this.
|
||||
session = requests.Session()
|
||||
retry = Retry(connect=3, backoff_factor=0.5)
|
||||
adapter = HTTPAdapter(max_retries=retry)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
|
||||
api_url = None
|
||||
with open(log_file) as f:
|
||||
for line in f:
|
||||
log = json.loads(line)
|
||||
if log["msg"] == "starting web service":
|
||||
api_url = log["addr"]
|
||||
api_port = urlparse(api_url).port
|
||||
break
|
||||
|
||||
if api_url is None:
|
||||
raise RuntimeError("Tunnel information not available in log file")
|
||||
|
||||
r = session.get(NGROK_TUNNEL_API_URL.format(api_url))
|
||||
for tunnel in r.json()['tunnels']:
|
||||
if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr'] and tunnel['proto'] == 'https':
|
||||
return tunnel['public_url'], api_port
|
||||
raise RuntimeError("Not able to retrieve ngrok public URL")
|
||||
|
||||
|
||||
def setup_ngrok(server_port, output_directory, existing_ports):
|
||||
if not(existing_ports is None):
|
||||
kill_processes(existing_ports)
|
||||
site_ngrok_url, port1 = create_ngrok_tunnel(server_port, os.path.join(output_directory, 'ngrok1.log'))
|
||||
return site_ngrok_url, [port1]
|
||||
|
||||
|
||||
def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill
|
||||
for proc in process_iter():
|
||||
def setup_tunnel(local_server_port):
|
||||
response = url_request(GRADIO_API_SERVER)
|
||||
if response and response.code == 200:
|
||||
try:
|
||||
for conns in proc.connections(kind='inet'):
|
||||
if conns.laddr.port in process_ids:
|
||||
proc.send_signal(SIGTERM) # or SIGKILL
|
||||
except (AccessDenied, NoSuchProcess):
|
||||
pass
|
||||
payload = json.loads(response.read().decode("utf-8"))[0]
|
||||
return create_tunnel(payload, LOCALHOST_NAME, local_server_port)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(str(e))
|
||||
|
101
gradio/tunneling.py
Normal file
101
gradio/tunneling.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""
|
||||
This file provides remote port forwarding functionality using paramiko package,
|
||||
Inspired by: https://github.com/paramiko/paramiko/blob/master/demos/rforward.py
|
||||
"""
|
||||
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from io import StringIO
|
||||
|
||||
import paramiko
|
||||
|
||||
DEBUG_MODE = False
|
||||
|
||||
|
||||
def handler(chan, host, port):
|
||||
sock = socket.socket()
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
except Exception as e:
|
||||
verbose("Forwarding request to %s:%d failed: %r" % (host, port, e))
|
||||
return
|
||||
|
||||
verbose(
|
||||
"Connected! Tunnel open %r -> %r -> %r"
|
||||
% (chan.origin_addr, chan.getpeername(), (host, port))
|
||||
)
|
||||
while True:
|
||||
r, w, x = select.select([sock, chan], [], [])
|
||||
if sock in r:
|
||||
data = sock.recv(1024)
|
||||
if len(data) == 0:
|
||||
break
|
||||
chan.send(data)
|
||||
if chan in r:
|
||||
data = chan.recv(1024)
|
||||
if len(data) == 0:
|
||||
break
|
||||
sock.send(data)
|
||||
chan.close()
|
||||
sock.close()
|
||||
verbose("Tunnel closed from %r" % (chan.origin_addr,))
|
||||
|
||||
|
||||
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
|
||||
transport.request_port_forward("", server_port)
|
||||
while True:
|
||||
chan = transport.accept(1000)
|
||||
if chan is None:
|
||||
continue
|
||||
thr = threading.Thread(target=handler, args=(chan, remote_host, remote_port))
|
||||
thr.setDaemon(True)
|
||||
thr.start()
|
||||
|
||||
|
||||
def verbose(s):
|
||||
if DEBUG_MODE:
|
||||
print(s)
|
||||
|
||||
|
||||
def create_tunnel(payload, local_server, local_server_port):
|
||||
client = paramiko.SSHClient()
|
||||
# client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
|
||||
verbose(
|
||||
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
||||
)
|
||||
try:
|
||||
client.connect(
|
||||
hostname=payload["host"],
|
||||
port=int(payload["port"]),
|
||||
username=payload["user"],
|
||||
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"*** Failed to connect to %s:%d: %r"
|
||||
% (payload["host"], int(payload["port"]), e)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
verbose(
|
||||
"Now forwarding remote port %d to %s:%d ..."
|
||||
% (int(payload["remote_port"]), local_server, local_server_port)
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=reverse_forward_tunnel,
|
||||
args=(
|
||||
int(payload["remote_port"]),
|
||||
local_server,
|
||||
local_server_port,
|
||||
client.get_transport(),
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return payload["share_url"]
|
Loading…
Reference in New Issue
Block a user