This commit is contained in:
Ali Abid 2020-08-10 16:51:40 -07:00
parent f684b3ba44
commit 3f4b5d0b26
5 changed files with 77 additions and 51 deletions

View File

@ -9,6 +9,8 @@ def get_params(func):
params_doc = []
documented_params = {"self"}
for param_line in params_str.split("\n")[1:]:
if param_line.strip() == "Returns":
break
space_index = param_line.index(" ")
colon_index = param_line.index(":")
name = param_line[:space_index]

View File

@ -21,7 +21,6 @@ import weakref
import analytics
import os
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
@ -48,7 +47,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged"):
"""
@ -69,6 +68,7 @@ class Interface:
allow_flagging (bool): if False, users will not see a button to flag an input and output.
flagging_dir (str): what to name the dir where flagged data is stored.
"""
def get_input_instance(iface):
if isinstance(iface, str):
shortcut = InputComponent.get_all_shortcut_implementations()[iface]
@ -90,6 +90,7 @@ class Interface:
"Output interface must be of type `str` or "
"`OutputComponent`"
)
if isinstance(inputs, list):
self.input_interfaces = [get_input_instance(i) for i in inputs]
else:
@ -135,7 +136,7 @@ class Interface:
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
tf.keras.backend.get_session()
except (ImportError, AttributeError):
# If they are using TF >= 2.0 or don't have TF,
# just ignore this.
@ -151,7 +152,7 @@ class Interface:
"_{}".format(index)):
index += 1
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
"_{}".format(index)
"_{}".format(index)
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
@ -188,8 +189,8 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
return config
return config
def process(self, raw_input):
"""
@ -208,7 +209,7 @@ class Interface:
durations = []
for predict_fn in self.predict:
start = time.time()
if self.capture_session and not(self.session is None):
if self.capture_session and not (self.session is None):
graph, sess = self.session
with graph.as_default():
with sess.as_default():
@ -238,31 +239,35 @@ class Interface:
return processed_output, durations
def close(self):
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)
def run_until_interrupted(self, thread, path_to_local_server):
try:
while 1:
pass
except (KeyboardInterrupt, OSError):
print("Keyboard interruption in main thread... closing server.")
thread.keep_running = False
networking.url_ok(path_to_local_server)
def launch(self, inline=None, inbrowser=None, share=False, debug=False):
"""
Parameters
inline (bool): whether to display in the interface inline on python
notebooks.
inbrowser (bool): whether to automatically launch the interface in a
new tab on the default browser.
share (bool): whether to create a publicly shareable link from
your computer for the interface.
debug (bool): if True, and the interface was launched from Google
Colab, prints the errors in the cell output.
:returns
inline (bool): whether to display in the interface inline on python notebooks.
inbrowser (bool): whether to automatically launch the interface in a new tab on the default browser.
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
Returns
httpd (str): HTTPServer object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name,
server_port=self.server_port)
server_port, httpd, thread = networking.start_simple_server(
self, output_directory, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
@ -277,7 +282,7 @@ class Interface:
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
@ -370,6 +375,11 @@ class Interface:
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
is_in_interactive_mode = bool(getattr(sys, 'ps1', sys.flags.interactive))
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return httpd, path_to_local_server, share_url

View File

@ -9,6 +9,7 @@ from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import pkg_resources
from distutils import dir_util
from gradio import inputs, outputs
import time
import json
from gradio.tunneling import create_tunnel
import urllib.request
@ -16,6 +17,8 @@ from shutil import copyfile
import requests
import sys
import analytics
import csv
INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -183,18 +186,32 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
os.makedirs(interface.flagging_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild(
interface.flagging_dir, msg['data']['input_data']) for i
interface.flagging_dir, msg['data']['input_data'][i]) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild(
interface.flagging_dir, msg['data']['output_data']) for i
interface.flagging_dir, msg['data']['output_data'][i])
for i
in range(len(interface.output_interfaces))]}
with open("{}/log.txt".format(interface.flagging_dir),
'a+') as f:
f.write(json.dumps(output))
f.write("\n")
log_fp = "{}/log.csv".format(interface.flagging_dir)
is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len(
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
if is_new:
writer.writeheader()
writer.writerow(
dict(zip(headers, output["inputs"] +
output["outputs"]))
)
else:
self.send_error(404, 'Path not found: {}'.format(self.path))
@ -205,22 +222,21 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
class QuittableHTTPThread(threading.Thread):
def __init__(self, httpd):
super().__init__(daemon=False)
self.httpd = httpd
self.keep_running =True
def run(self):
while self.keep_running:
self.httpd.handle_request()
httpd = HTTPServer(directory_to_serve, (server_name, port))
# Now loop forever
def serve_forever():
try:
while True:
sys.stdout.flush()
httpd.serve_forever()
except (KeyboardInterrupt, OSError):
httpd.shutdown()
httpd.server_close()
thread = threading.Thread(target=serve_forever, daemon=False)
thread = QuittableHTTPThread(httpd=httpd)
thread.start()
return httpd
return httpd, thread
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
@ -229,8 +245,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None, se
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd
httpd, thread = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd, thread
def close_server(server):

View File

@ -9,6 +9,8 @@ def get_params(func):
params_doc = []
documented_params = {"self"}
for param_line in params_str.split("\n")[1:]:
if param_line.strip() == "Returns":
break
space_index = param_line.index(" ")
colon_index = param_line.index(":")
name = param_line[:space_index]

View File

@ -255,15 +255,11 @@ class Interface:
def launch(self, inline=None, inbrowser=None, share=False, debug=False):
"""
Parameters
inline (bool): whether to display in the interface inline on python
notebooks.
inbrowser (bool): whether to automatically launch the interface in a
new tab on the default browser.
share (bool): whether to create a publicly shareable link from
your computer for the interface.
debug (bool): if True, and the interface was launched from Google
Colab, prints the errors in the cell output.
:returns
inline (bool): whether to display in the interface inline on python notebooks.
inbrowser (bool): whether to automatically launch the interface in a new tab on the default browser.
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
Returns
httpd (str): HTTPServer object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)