added tests for networking

This commit is contained in:
Abubakar Abid 2019-03-01 20:05:39 -08:00
parent 395f3df686
commit 8857383131
4 changed files with 70 additions and 8 deletions

View File

@ -19,6 +19,9 @@ class AbstractInput(ABC):
"""
def __init__(self, preprocessing_fn=None):
"""
:param preprocessing_fn: an optional preprocessing function that overrides the default
"""
if preprocessing_fn is not None:
if not callable(preprocessing_fn):
raise ValueError('`preprocessing_fn` must be a callable function')

View File

@ -20,10 +20,9 @@ import pkg_resources
from bs4 import BeautifulSoup
import shutil
INITIAL_PORT_VALUE = 7860
TRY_NUM_PORTS = 100
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
LOCALHOST_NAME = 'localhost'
LOCALHOST_PREFIX = 'localhost:'
NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output)
NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output)
@ -45,6 +44,12 @@ NGROK_ZIP_URLS = {
def build_template(temp_dir, input_interface, output_interface):
"""
Builds a complete HTML template with supporting JS and CSS files in a given directory.
:param temp_dir: string with path to temp directory in which the template should be built
:param input_interface: an AbstractInput object which includes is used to get the input template
:param output_interface: an AbstractInput object which includes is used to get the input template
"""
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
input_page = open(input_template_path)
@ -65,10 +70,14 @@ def build_template(temp_dir, input_interface, output_interface):
copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
return
def copy_files(src_dir, dest_dir):
"""
Copies all the files from one directory to another
:param src_dir: string path to source directory
:param dest_dir: string path to destination directory
"""
if not os.path.exists(dest_dir):
os.makedirs(dest_dir)
src_files = os.listdir(src_dir)
@ -100,6 +109,12 @@ def set_socket_port_in_js(temp_dir, socket_port):
def get_first_available_port(initial, final):
"""
Gets the first open port in a specified range of port numbers
:param initial: the initial value in the range of port numbers
:param final: final (exclusive) value in the range of port numbers, should be greater than `initial`
:return:
"""
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
@ -144,8 +159,7 @@ def serve_files_in_background(port, directory_to_serve=None):
def start_simple_server(directory_to_serve=None):
# TODO(abidlabs): increment port number until free port is found
port = get_first_available_port (INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
serve_files_in_background(port, directory_to_serve)
return port
@ -179,7 +193,7 @@ def create_ngrok_tunnel(local_port, api_url):
session.mount('https://', adapter)
r = session.get(api_url)
for tunnel in r.json()['tunnels']:
if LOCALHOST_PREFIX + str(local_port) in tunnel['config']['addr']:
if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr']:
return tunnel['public_url']
raise RuntimeError("Not able to retrieve ngrok public URL")
@ -192,7 +206,7 @@ def setup_ngrok(server_port, websocket_port, output_directory):
return site_ngrok_url
def kill_processes(process_ids):
def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill
for proc in process_iter():
try:
for conns in proc.connections(kind='inet'):

View File

@ -16,6 +16,7 @@ class AbstractOutput(ABC):
def __init__(self, postprocessing_fn=None):
"""
:param postprocessing_fn: an optional postprocessing function that overrides the default
"""
if postprocessing_fn is not None:
self.postprocess = postprocessing_fn

44
test/test_networking.py Normal file
View File

@ -0,0 +1,44 @@
import unittest
from gradio import networking
import socket
import tempfile
import os
LOCALHOST_NAME = 'localhost'
class TestGetAvailablePort(unittest.TestCase):
def test_get_first_available_port_by_blocking_port(self):
initial = 7000
final = 8000
port_found = False
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.bind((LOCALHOST_NAME, port)) # Bind to the port
s.close()
port_found = True
break
except OSError:
pass
if port_found:
s = socket.socket() # create a socket object
s.bind((LOCALHOST_NAME, port)) # Bind to the port
new_port = networking.get_first_available_port(initial, final)
s.close()
self.assertFalse(port==new_port)
class TestCopyFiles(unittest.TestCase):
def test_copy_files(self):
filename = "a.txt"
with tempfile.TemporaryDirectory() as temp_src:
with open(os.path.join(temp_src, "a.txt"), "w+") as f:
f.write('Hi')
with tempfile.TemporaryDirectory() as temp_dest:
self.assertFalse(os.path.exists(os.path.join(temp_dest, filename)))
networking.copy_files(temp_src, temp_dest)
self.assertTrue(os.path.exists(os.path.join(temp_dest, filename)))
if __name__ == '__main__':
unittest.main()