mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
added ability to close all servers
This commit is contained in:
parent
6bc9d0204b
commit
9bc6aa44f0
@ -1 +1 @@
|
||||
from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
|
@ -18,6 +18,7 @@ import time
|
||||
import inspect
|
||||
from IPython import get_ipython
|
||||
import sys
|
||||
import weakref
|
||||
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
@ -28,6 +29,7 @@ class Interface:
|
||||
The Interface class represents a general input/output interface for a machine learning model. During construction,
|
||||
the appropriate inputs and outputs
|
||||
"""
|
||||
instances = weakref.WeakSet()
|
||||
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
|
||||
live=False, show_input=True, show_output=True,
|
||||
@ -83,6 +85,8 @@ class Interface:
|
||||
self.description = description
|
||||
self.thumbnail = thumbnail
|
||||
self.examples = examples
|
||||
self.server_port = None
|
||||
Interface.instances.add(self)
|
||||
|
||||
def get_config_file(self):
|
||||
config = {
|
||||
@ -205,6 +209,11 @@ class Interface:
|
||||
return
|
||||
raise RuntimeError("Validation did not pass")
|
||||
|
||||
def close(self):
|
||||
if self.server_port:
|
||||
print("Closing Gradio server on port {}...".format(self.server_port))
|
||||
networking.close_server(self.simple_server)
|
||||
|
||||
def launch(self, inline=None, inbrowser=None, share=False, validate=True, debug=False):
|
||||
"""
|
||||
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
|
||||
@ -224,22 +233,13 @@ class Interface:
|
||||
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
|
||||
pass
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
if self.verbose:
|
||||
print("Closing existing server...")
|
||||
if self.simple_server is not None:
|
||||
try:
|
||||
networking.close_server(self.simple_server)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
||||
@ -344,3 +344,12 @@ class Interface:
|
||||
time.sleep(0.1)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls):
|
||||
return list(Interface.instances) #Returns list of all current instances
|
||||
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
io.close()
|
||||
|
@ -114,16 +114,9 @@ def get_first_available_port(initial, final):
|
||||
)
|
||||
|
||||
|
||||
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME, stdout=None):
|
||||
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
# def __init__(self, *args):
|
||||
# if not(stdout is None):
|
||||
# sys.stdout = stdout
|
||||
# else:
|
||||
# print('out is None')
|
||||
# super().__init__(*args)
|
||||
|
||||
def _set_headers(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
@ -203,13 +196,13 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
|
||||
# Now loop forever
|
||||
def serve_forever():
|
||||
# try:
|
||||
sys.stdout = stdout
|
||||
while True:
|
||||
sys.stdout.flush()
|
||||
httpd.serve_forever()
|
||||
# except (KeyboardInterrupt, OSError):
|
||||
# httpd.server_close()
|
||||
try:
|
||||
while True:
|
||||
sys.stdout.flush()
|
||||
httpd.serve_forever()
|
||||
except (KeyboardInterrupt, OSError):
|
||||
httpd.shutdown()
|
||||
httpd.server_close()
|
||||
|
||||
thread = threading.Thread(target=serve_forever, daemon=False)
|
||||
thread.start()
|
||||
@ -221,13 +214,11 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
httpd = serve_files_in_background(
|
||||
interface, port, directory_to_serve, server_name, stdout=sys.stdout)
|
||||
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
|
||||
return port, httpd
|
||||
|
||||
|
||||
def close_server(server):
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user