added ability to close all servers

This commit is contained in:
Abubakar Abid 2020-07-06 23:31:21 -05:00
parent 6bc9d0204b
commit 9bc6aa44f0
3 changed files with 29 additions and 29 deletions

View File

@ -1 +1 @@
from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.

View File

@ -18,6 +18,7 @@ import time
import inspect
from IPython import get_ipython
import sys
import weakref
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
@ -28,6 +29,7 @@ class Interface:
The Interface class represents a general input/output interface for a machine learning model. During construction,
the appropriate inputs and outputs
"""
instances = weakref.WeakSet()
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
@ -83,6 +85,8 @@ class Interface:
self.description = description
self.thumbnail = thumbnail
self.examples = examples
self.server_port = None
Interface.instances.add(self)
def get_config_file(self):
config = {
@ -205,6 +209,11 @@ class Interface:
return
raise RuntimeError("Validation did not pass")
def close(self):
if self.server_port:
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)
def launch(self, inline=None, inbrowser=None, share=False, validate=True, debug=False):
"""
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
@ -224,22 +233,13 @@ class Interface:
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
pass
# If an existing interface is running with this instance, close it.
if self.status == "RUNNING":
if self.verbose:
print("Closing existing server...")
if self.simple_server is not None:
try:
networking.close_server(self.simple_server)
except OSError:
pass
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
@ -344,3 +344,12 @@ class Interface:
time.sleep(0.1)
return httpd, path_to_local_server, share_url
@classmethod
def get_instances(cls):
return list(Interface.instances) #Returns list of all current instances
def reset_all():
for io in Interface.get_instances():
io.close()

View File

@ -114,16 +114,9 @@ def get_first_available_port(initial, final):
)
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME, stdout=None):
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()"""
# def __init__(self, *args):
# if not(stdout is None):
# sys.stdout = stdout
# else:
# print('out is None')
# super().__init__(*args)
def _set_headers(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
@ -203,13 +196,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
# Now loop forever
def serve_forever():
# try:
sys.stdout = stdout
while True:
sys.stdout.flush()
httpd.serve_forever()
# except (KeyboardInterrupt, OSError):
# httpd.server_close()
try:
while True:
sys.stdout.flush()
httpd.serve_forever()
except (KeyboardInterrupt, OSError):
httpd.shutdown()
httpd.server_close()
thread = threading.Thread(target=serve_forever, daemon=False)
thread.start()
@ -221,13 +214,11 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
port = get_first_available_port(
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
)
httpd = serve_files_in_background(
interface, port, directory_to_serve, server_name, stdout=sys.stdout)
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd
def close_server(server):
server.shutdown()
server.server_close()