mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
added tests for networking
This commit is contained in:
parent
395f3df686
commit
8857383131
@ -19,6 +19,9 @@ class AbstractInput(ABC):
|
||||
"""
|
||||
|
||||
def __init__(self, preprocessing_fn=None):
|
||||
"""
|
||||
:param preprocessing_fn: an optional preprocessing function that overrides the default
|
||||
"""
|
||||
if preprocessing_fn is not None:
|
||||
if not callable(preprocessing_fn):
|
||||
raise ValueError('`preprocessing_fn` must be a callable function')
|
||||
|
@ -20,10 +20,9 @@ import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
import shutil
|
||||
|
||||
INITIAL_PORT_VALUE = 7860
|
||||
TRY_NUM_PORTS = 100
|
||||
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
LOCALHOST_PREFIX = 'localhost:'
|
||||
NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output)
|
||||
NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output)
|
||||
|
||||
@ -45,6 +44,12 @@ NGROK_ZIP_URLS = {
|
||||
|
||||
|
||||
def build_template(temp_dir, input_interface, output_interface):
|
||||
"""
|
||||
Builds a complete HTML template with supporting JS and CSS files in a given directory.
|
||||
:param temp_dir: string with path to temp directory in which the template should be built
|
||||
:param input_interface: an AbstractInput object which includes is used to get the input template
|
||||
:param output_interface: an AbstractInput object which includes is used to get the input template
|
||||
"""
|
||||
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
|
||||
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
|
||||
input_page = open(input_template_path)
|
||||
@ -65,10 +70,14 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
|
||||
copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
|
||||
copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
|
||||
return
|
||||
|
||||
|
||||
def copy_files(src_dir, dest_dir):
|
||||
"""
|
||||
Copies all the files from one directory to another
|
||||
:param src_dir: string path to source directory
|
||||
:param dest_dir: string path to destination directory
|
||||
"""
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
src_files = os.listdir(src_dir)
|
||||
@ -100,6 +109,12 @@ def set_socket_port_in_js(temp_dir, socket_port):
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
"""
|
||||
Gets the first open port in a specified range of port numbers
|
||||
:param initial: the initial value in the range of port numbers
|
||||
:param final: final (exclusive) value in the range of port numbers, should be greater than `initial`
|
||||
:return:
|
||||
"""
|
||||
for port in range(initial, final):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
@ -144,8 +159,7 @@ def serve_files_in_background(port, directory_to_serve=None):
|
||||
|
||||
|
||||
def start_simple_server(directory_to_serve=None):
|
||||
# TODO(abidlabs): increment port number until free port is found
|
||||
port = get_first_available_port (INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
serve_files_in_background(port, directory_to_serve)
|
||||
return port
|
||||
|
||||
@ -179,7 +193,7 @@ def create_ngrok_tunnel(local_port, api_url):
|
||||
session.mount('https://', adapter)
|
||||
r = session.get(api_url)
|
||||
for tunnel in r.json()['tunnels']:
|
||||
if LOCALHOST_PREFIX + str(local_port) in tunnel['config']['addr']:
|
||||
if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr']:
|
||||
return tunnel['public_url']
|
||||
raise RuntimeError("Not able to retrieve ngrok public URL")
|
||||
|
||||
@ -192,7 +206,7 @@ def setup_ngrok(server_port, websocket_port, output_directory):
|
||||
return site_ngrok_url
|
||||
|
||||
|
||||
def kill_processes(process_ids):
|
||||
def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill
|
||||
for proc in process_iter():
|
||||
try:
|
||||
for conns in proc.connections(kind='inet'):
|
||||
|
@ -16,6 +16,7 @@ class AbstractOutput(ABC):
|
||||
|
||||
def __init__(self, postprocessing_fn=None):
|
||||
"""
|
||||
:param postprocessing_fn: an optional postprocessing function that overrides the default
|
||||
"""
|
||||
if postprocessing_fn is not None:
|
||||
self.postprocess = postprocessing_fn
|
||||
|
44
test/test_networking.py
Normal file
44
test/test_networking.py
Normal file
@ -0,0 +1,44 @@
|
||||
import unittest
|
||||
from gradio import networking
|
||||
import socket
|
||||
import tempfile
|
||||
import os
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
|
||||
|
||||
class TestGetAvailablePort(unittest.TestCase):
|
||||
def test_get_first_available_port_by_blocking_port(self):
|
||||
initial = 7000
|
||||
final = 8000
|
||||
port_found = False
|
||||
for port in range(initial, final):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.close()
|
||||
port_found = True
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
if port_found:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((LOCALHOST_NAME, port)) # Bind to the port
|
||||
new_port = networking.get_first_available_port(initial, final)
|
||||
s.close()
|
||||
self.assertFalse(port==new_port)
|
||||
|
||||
|
||||
class TestCopyFiles(unittest.TestCase):
|
||||
def test_copy_files(self):
|
||||
filename = "a.txt"
|
||||
with tempfile.TemporaryDirectory() as temp_src:
|
||||
with open(os.path.join(temp_src, "a.txt"), "w+") as f:
|
||||
f.write('Hi')
|
||||
with tempfile.TemporaryDirectory() as temp_dest:
|
||||
self.assertFalse(os.path.exists(os.path.join(temp_dest, filename)))
|
||||
networking.copy_files(temp_src, temp_dest)
|
||||
self.assertTrue(os.path.exists(os.path.join(temp_dest, filename)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user