diff --git a/gradio/inputs.py b/gradio/inputs.py index 4635d965cf..d7f000fb5d 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -19,6 +19,9 @@ class AbstractInput(ABC): """ def __init__(self, preprocessing_fn=None): + """ + :param preprocessing_fn: an optional preprocessing function that overrides the default + """ if preprocessing_fn is not None: if not callable(preprocessing_fn): raise ValueError('`preprocessing_fn` must be a callable function') diff --git a/gradio/networking.py b/gradio/networking.py index 65f38d8410..623857ac79 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -20,10 +20,9 @@ import pkg_resources from bs4 import BeautifulSoup import shutil -INITIAL_PORT_VALUE = 7860 -TRY_NUM_PORTS = 100 +INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc. +TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception. LOCALHOST_NAME = 'localhost' -LOCALHOST_PREFIX = 'localhost:' NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output) NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output) @@ -45,6 +44,12 @@ NGROK_ZIP_URLS = { def build_template(temp_dir, input_interface, output_interface): + """ + Builds a complete HTML template with supporting JS and CSS files in a given directory. + :param temp_dir: string with path to temp directory in which the template should be built + :param input_interface: an AbstractInput object which includes is used to get the input template + :param output_interface: an AbstractInput object which includes is used to get the input template + """ input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path()) output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path()) input_page = open(input_template_path) @@ -65,10 +70,14 @@ def build_template(temp_dir, input_interface, output_interface): copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP)) copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP)) - return def copy_files(src_dir, dest_dir): + """ + Copies all the files from one directory to another + :param src_dir: string path to source directory + :param dest_dir: string path to destination directory + """ if not os.path.exists(dest_dir): os.makedirs(dest_dir) src_files = os.listdir(src_dir) @@ -100,6 +109,12 @@ def set_socket_port_in_js(temp_dir, socket_port): def get_first_available_port(initial, final): + """ + Gets the first open port in a specified range of port numbers + :param initial: the initial value in the range of port numbers + :param final: final (exclusive) value in the range of port numbers, should be greater than `initial` + :return: + """ for port in range(initial, final): try: s = socket.socket() # create a socket object @@ -144,8 +159,7 @@ def serve_files_in_background(port, directory_to_serve=None): def start_simple_server(directory_to_serve=None): - # TODO(abidlabs): increment port number until free port is found - port = get_first_available_port (INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS) + port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS) serve_files_in_background(port, directory_to_serve) return port @@ -179,7 +193,7 @@ def create_ngrok_tunnel(local_port, api_url): session.mount('https://', adapter) r = session.get(api_url) for tunnel in r.json()['tunnels']: - if LOCALHOST_PREFIX + str(local_port) in tunnel['config']['addr']: + if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr']: return tunnel['public_url'] raise RuntimeError("Not able to retrieve ngrok public URL") @@ -192,7 +206,7 @@ def setup_ngrok(server_port, websocket_port, output_directory): return site_ngrok_url -def kill_processes(process_ids): +def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill for proc in process_iter(): try: for conns in proc.connections(kind='inet'): diff --git a/gradio/outputs.py b/gradio/outputs.py index 97449168d7..dc50714206 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -16,6 +16,7 @@ class AbstractOutput(ABC): def __init__(self, postprocessing_fn=None): """ + :param postprocessing_fn: an optional postprocessing function that overrides the default """ if postprocessing_fn is not None: self.postprocess = postprocessing_fn diff --git a/test/test_networking.py b/test/test_networking.py new file mode 100644 index 0000000000..96660898ec --- /dev/null +++ b/test/test_networking.py @@ -0,0 +1,44 @@ +import unittest +from gradio import networking +import socket +import tempfile +import os +LOCALHOST_NAME = 'localhost' + + +class TestGetAvailablePort(unittest.TestCase): + def test_get_first_available_port_by_blocking_port(self): + initial = 7000 + final = 8000 + port_found = False + for port in range(initial, final): + try: + s = socket.socket() # create a socket object + s.bind((LOCALHOST_NAME, port)) # Bind to the port + s.close() + port_found = True + break + except OSError: + pass + if port_found: + s = socket.socket() # create a socket object + s.bind((LOCALHOST_NAME, port)) # Bind to the port + new_port = networking.get_first_available_port(initial, final) + s.close() + self.assertFalse(port==new_port) + + +class TestCopyFiles(unittest.TestCase): + def test_copy_files(self): + filename = "a.txt" + with tempfile.TemporaryDirectory() as temp_src: + with open(os.path.join(temp_src, "a.txt"), "w+") as f: + f.write('Hi') + with tempfile.TemporaryDirectory() as temp_dest: + self.assertFalse(os.path.exists(os.path.join(temp_dest, filename))) + networking.copy_files(temp_src, temp_dest) + self.assertTrue(os.path.exists(os.path.join(temp_dest, filename))) + + +if __name__ == '__main__': + unittest.main()