This commit is contained in:
aliabd 2020-06-18 09:39:32 -07:00
commit 7dd16a7d19
8 changed files with 41 additions and 25 deletions

View File

@ -17,7 +17,7 @@ import random
import time import time
from IPython import get_ipython from IPython import get_ipython
LOCALHOST_IP = "127.0.0.1" LOCALHOST_IP = "0.0.0.0"
TRY_NUM_PORTS = 100 TRY_NUM_PORTS = 100
PKG_VERSION_URL = "https://gradio.app/api/pkg-version" PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
@ -29,7 +29,8 @@ class Interface:
""" """
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
live=False, show_input=True, show_output=True): live=False, show_input=True, show_output=True,
load_fn=None, server_name=LOCALHOST_IP):
""" """
:param fn: a function that will process the input panel data from the interface and return the output panel data. :param fn: a function that will process the input panel data from the interface and return the output panel data.
:param inputs: a string or `AbstractInput` representing the input interface. :param inputs: a string or `AbstractInput` representing the input interface.
@ -63,6 +64,8 @@ class Interface:
fn = [fn] fn = [fn]
self.output_interfaces *= len(fn) self.output_interfaces *= len(fn)
self.predict = fn self.predict = fn
self.load_fn = load_fn
self.context = None
self.verbose = verbose self.verbose = verbose
self.status = "OFF" self.status = "OFF"
self.saliency = saliency self.saliency = saliency
@ -70,6 +73,7 @@ class Interface:
self.show_input = show_input self.show_input = show_input
self.show_output = show_output self.show_output = show_output
self.flag_hash = random.getrandbits(32) self.flag_hash = random.getrandbits(32)
self.server_name = server_name
def update_config_file(self, output_directory): def update_config_file(self, output_directory):
config = { config = {
@ -148,6 +152,8 @@ class Interface:
""" """
# if validate and not self.validate_flag: # if validate and not self.validate_flag:
# self.validate() # self.validate()
context = self.load_fn() if self.load_fn else None
self.context = context
# If an existing interface is running with this instance, close it. # If an existing interface is running with this instance, close it.
if self.status == "RUNNING": if self.status == "RUNNING":
@ -161,8 +167,8 @@ class Interface:
output_directory = tempfile.mkdtemp() output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface. # Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory) server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://localhost:{}/".format(server_port) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory) networking.build_template(output_directory)
self.update_config_file(output_directory) self.update_config_file(output_directory)

View File

@ -20,7 +20,7 @@ INITIAL_PORT_VALUE = (
TRY_NUM_PORTS = ( TRY_NUM_PORTS = (
100 100
) # Number of ports to try before giving up and throwing an exception. ) # Number of ports to try before giving up and throwing an exception.
LOCALHOST_NAME = "localhost" LOCALHOST_NAME = "0.0.0.0"
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request" GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/") STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
@ -109,7 +109,7 @@ def get_first_available_port(initial, final):
) )
def serve_files_in_background(interface, port, directory_to_serve=None): def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
class HTTPHandler(SimpleHTTPRequestHandler): class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()""" """This handler uses server.base_path instead of always using os.getcwd()"""
@ -139,7 +139,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)] processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
predictions = [] predictions = []
for predict_fn in interface.predict: for predict_fn in interface.predict:
prediction = predict_fn(*processed_input) if interface.context:
prediction = predict_fn(*processed_input,
interface.context)
else:
prediction = predict_fn(*processed_input)
if len(interface.output_interfaces) / len(interface.predict) == 1: if len(interface.output_interfaces) / len(interface.predict) == 1:
prediction = [prediction] prediction = [prediction]
predictions.extend(prediction) predictions.extend(prediction)
@ -260,7 +264,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
self.base_path = base_path self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass) BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
httpd = HTTPServer(directory_to_serve, (LOCALHOST_NAME, port)) httpd = HTTPServer(directory_to_serve, (server_name, port))
# Now loop forever # Now loop forever
def serve_forever(): def serve_forever():
@ -277,11 +281,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
return httpd return httpd
def start_simple_server(interface, directory_to_serve=None): def start_simple_server(interface, directory_to_serve=None, server_name=None):
port = get_first_available_port( port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
) )
httpd = serve_files_in_background(interface, port, directory_to_serve) httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd return port, httpd

Binary file not shown.

Binary file not shown.

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0 Metadata-Version: 1.0
Name: gradio Name: gradio
Version: 0.9.1 Version: 0.9.2
Summary: Python library for easily interacting with trained machine learning models Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/abidlabs/gradio Home-page: https://github.com/abidlabs/gradio
Author: Abubakar Abid Author: Abubakar Abid

View File

@ -18,7 +18,7 @@ import time
from IPython import get_ipython from IPython import get_ipython
import tensorflow as tf import tensorflow as tf
LOCALHOST_IP = "127.0.0.1" LOCALHOST_IP = "0.0.0.0"
TRY_NUM_PORTS = 100 TRY_NUM_PORTS = 100
PKG_VERSION_URL = "https://gradio.app/api/pkg-version" PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
@ -30,8 +30,9 @@ class Interface:
""" """
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
live=False, show_input=True, show_output=True, live=False, show_input=True, show_output=True,
load_fn=None, capture_session=False): load_fn=None, capture_session=False,
server_name=LOCALHOST_IP):
""" """
:param fn: a function that will process the input panel data from the interface and return the output panel data. :param fn: a function that will process the input panel data from the interface and return the output panel data.
:param inputs: a string or `AbstractInput` representing the input interface. :param inputs: a string or `AbstractInput` representing the input interface.
@ -43,7 +44,9 @@ class Interface:
elif isinstance(iface, gradio.inputs.AbstractInput): elif isinstance(iface, gradio.inputs.AbstractInput):
return iface return iface
else: else:
raise ValueError("Input interface must be of type `str` or `AbstractInput`") raise ValueError("Input interface must be of type `str` or "
"`AbstractInput`")
def get_output_instance(iface): def get_output_instance(iface):
if isinstance(iface, str): if isinstance(iface, str):
return gradio.outputs.shortcuts[iface] return gradio.outputs.shortcuts[iface]
@ -51,7 +54,8 @@ class Interface:
return iface return iface
else: else:
raise ValueError( raise ValueError(
"Output interface must be of type `str` or `AbstractOutput`" "Output interface must be of type `str` or "
"`AbstractOutput`"
) )
if isinstance(inputs, list): if isinstance(inputs, list):
self.input_interfaces = [get_input_instance(i) for i in inputs] self.input_interfaces = [get_input_instance(i) for i in inputs]
@ -76,6 +80,7 @@ class Interface:
self.flag_hash = random.getrandbits(32) self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session self.capture_session = capture_session
self.session = None self.session = None
self.server_name = server_name
def update_config_file(self, output_directory): def update_config_file(self, output_directory):
config = { config = {
@ -173,8 +178,8 @@ class Interface:
output_directory = tempfile.mkdtemp() output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface. # Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory) server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://localhost:{}/".format(server_port) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory) networking.build_template(output_directory)
self.update_config_file(output_directory) self.update_config_file(output_directory)
@ -247,7 +252,8 @@ class Interface:
if ( if (
is_colab is_colab
): # Embed the remote interface page if on google colab; otherwise, embed the local page. ): # Embed the remote interface page if on google colab;
# otherwise, embed the local page.
display(IFrame(share_url, width=1000, height=500)) display(IFrame(share_url, width=1000, height=500))
else: else:
display(IFrame(path_to_local_server, width=1000, height=500)) display(IFrame(path_to_local_server, width=1000, height=500))

View File

@ -20,7 +20,7 @@ INITIAL_PORT_VALUE = (
TRY_NUM_PORTS = ( TRY_NUM_PORTS = (
100 100
) # Number of ports to try before giving up and throwing an exception. ) # Number of ports to try before giving up and throwing an exception.
LOCALHOST_NAME = "localhost" LOCALHOST_NAME = "0.0.0.0"
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request" GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/") STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
@ -109,7 +109,7 @@ def get_first_available_port(initial, final):
) )
def serve_files_in_background(interface, port, directory_to_serve=None): def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
class HTTPHandler(SimpleHTTPRequestHandler): class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()""" """This handler uses server.base_path instead of always using os.getcwd()"""
@ -278,7 +278,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
self.base_path = base_path self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass) BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
httpd = HTTPServer(directory_to_serve, (LOCALHOST_NAME, port)) httpd = HTTPServer(directory_to_serve, (server_name, port))
# Now loop forever # Now loop forever
def serve_forever(): def serve_forever():
@ -295,11 +295,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None):
return httpd return httpd
def start_simple_server(interface, directory_to_serve=None): def start_simple_server(interface, directory_to_serve=None, server_name=None):
port = get_first_available_port( port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
) )
httpd = serve_files_in_background(interface, port, directory_to_serve) httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd return port, httpd

View File

@ -5,7 +5,7 @@ except ImportError:
setup( setup(
name='gradio', name='gradio',
version='0.9.1', version='0.9.2',
include_package_data=True, include_package_data=True,
description='Python library for easily interacting with trained machine learning models', description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid', author='Abubakar Abid',