mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-07 11:46:51 +08:00
reorganized files to clean up interface.py
This commit is contained in:
parent
a9ac43c706
commit
6d91825899
@ -22,18 +22,18 @@ class AbstractInput(ABC):
|
||||
if preprocessing_fn is not None:
|
||||
if not callable(preprocessing_fn):
|
||||
raise ValueError('`preprocessing_fn` must be a callable function')
|
||||
self._preprocess = preprocessing_fn
|
||||
self.preprocess = preprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
"""
|
||||
All interfaces should define a method that returns the path to its template.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _preprocess(self, inp):
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
All interfaces should define a default preprocessing method
|
||||
"""
|
||||
@ -42,10 +42,10 @@ class AbstractInput(ABC):
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
return 'templates/sketchpad_input.html'
|
||||
|
||||
def _preprocess(self, inp):
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
|
||||
"""
|
||||
@ -59,10 +59,10 @@ class Sketchpad(AbstractInput):
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
return 'templates/webcam_input.html'
|
||||
|
||||
def _preprocess(self, inp):
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
@ -76,10 +76,10 @@ class Webcam(AbstractInput):
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
return 'templates/textbox_input.html'
|
||||
|
||||
def _preprocess(self, inp):
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
By default, no pre-processing is applied to text.
|
||||
"""
|
||||
@ -88,10 +88,10 @@ class Textbox(AbstractInput):
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
return 'templates/image_upload_input.html'
|
||||
|
||||
def _preprocess(self, inp):
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
|
@ -7,13 +7,9 @@ import asyncio
|
||||
import websockets
|
||||
import nest_asyncio
|
||||
import webbrowser
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
from gradio import inputs
|
||||
from gradio import outputs
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
from gradio import networking
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
nest_asyncio.apply()
|
||||
@ -23,43 +19,41 @@ INITIAL_WEBSOCKET_PORT = 9200
|
||||
TRY_NUM_PORTS = 100
|
||||
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
|
||||
JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/')
|
||||
CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/')
|
||||
JS_PATH_TEMP = 'js/'
|
||||
CSS_PATH_TEMP = 'css/'
|
||||
TEMPLATE_TEMP = 'interface.html'
|
||||
BASE_JS_FILE = 'js/all-io.js'
|
||||
|
||||
|
||||
class Interface:
|
||||
"""
|
||||
The Interface class represents a general input/output interface for a machine learning model. During construction,
|
||||
the appropriate inputs and outputs
|
||||
"""
|
||||
|
||||
# Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description.
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'}
|
||||
|
||||
def __init__(self, input, output, model, model_type=None, preprocessing_fn=None, postprocessing_fn=None):
|
||||
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
|
||||
verbose=True):
|
||||
"""
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
|
||||
:param inputs: a string representing the input interface.
|
||||
:param outputs: a string representing the output interface.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
:param model_params: additional model parameters.
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
|
||||
provided.
|
||||
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
|
||||
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
|
||||
"""
|
||||
self.input_interface = inputs.registry[input](preprocessing_fn)
|
||||
self.output_interface = outputs.registry[output](postprocessing_fn)
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
if model_type is None:
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
else:
|
||||
if verbose:
|
||||
print("Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]))
|
||||
self.VALID_MODEL_TYPES[model_type]))
|
||||
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
|
||||
def _infer_model_type(self, model):
|
||||
@staticmethod
|
||||
def _infer_model_type(model):
|
||||
""" Helper method that attempts to identify the type of trained ML model."""
|
||||
try:
|
||||
import sklearn
|
||||
if isinstance(model, sklearn.base.BaseEstimator):
|
||||
@ -84,124 +78,75 @@ class Interface:
|
||||
if callable(model):
|
||||
return 'function'
|
||||
|
||||
return None
|
||||
|
||||
def _build_template(self, temp_dir):
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', self.input_interface._get_template_path())
|
||||
output_template_path = pkg_resources.resource_filename(
|
||||
'gradio', self.output_interface._get_template_path())
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
|
||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
input_tag = all_io_soup.find("div", {"id": "input"})
|
||||
output_tag = all_io_soup.find("div", {"id": "output"})
|
||||
|
||||
input_tag.replace_with(input_soup)
|
||||
output_tag.replace_with(output_soup)
|
||||
|
||||
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
|
||||
f.write(str(all_io_soup.prettify))
|
||||
|
||||
self._copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
|
||||
self._copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
|
||||
return
|
||||
|
||||
def _copy_files(self, src_dir, dest_dir):
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
src_files = os.listdir(src_dir)
|
||||
for file_name in src_files:
|
||||
full_file_name = os.path.join(src_dir, file_name)
|
||||
if os.path.isfile(full_file_name):
|
||||
shutil.copy(full_file_name, dest_dir)
|
||||
|
||||
def _set_socket_url_in_js(self, temp_dir, socket_url):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def _set_socket_port_in_js(self, temp_dir, socket_port):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def predict(self, array):
|
||||
if self.model_type=='sklearn':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='keras':
|
||||
return self.model_obj.predict(array)
|
||||
elif self.model_type=='function':
|
||||
return self.model_obj(array)
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
|
||||
async def communicate(self, websocket, path):
|
||||
"""
|
||||
Method that defines how this interface communicates with the websocket.
|
||||
:param websocket: a Websocket object used to communicate with the interface frontend
|
||||
:param path: ignored
|
||||
Method that defines how this interface should communicates with the websocket. (1) When an input is received by
|
||||
the websocket, it is passed into the input interface and preprocssed. (2) Then the model is called to make a
|
||||
prediction. (3) Finally, the prediction is postprocessed to get something to be displayed by the output.
|
||||
:param websocket: a Websocket server used to communicate with the interface frontend
|
||||
:param path: not used, but required for compliance with websocket library
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
processed_input = self.input_interface._pre_process(msg)
|
||||
processed_input = self.input_interface.preprocess(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
processed_output = self.output_interface._post_process(prediction)
|
||||
processed_output = self.output_interface.postprocess(prediction)
|
||||
await websocket.send(str(processed_output))
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
||||
def launch(self, share_link=False, verbose=True):
|
||||
def predict(self, preprocessed_input):
|
||||
"""
|
||||
Standard method shared by interfaces that launches a websocket at a specified IP address.
|
||||
Method that calls the relevant method of the model object to make a prediction.
|
||||
:param preprocessed_input: the preprocessed input returned by the input interface
|
||||
"""
|
||||
if self.model_type=='sklearn':
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='keras':
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='function':
|
||||
return self.model_obj(preprocessed_input)
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
|
||||
def launch(self, share=False):
|
||||
"""
|
||||
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
|
||||
:param share: boolean. If True, then a share link is generated using ngrok is displayed to the user.
|
||||
"""
|
||||
output_directory = tempfile.mkdtemp()
|
||||
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port = networking.start_simple_server(output_directory)
|
||||
path_to_server = 'http://localhost:{}/'.format(server_port)
|
||||
self._build_template(output_directory)
|
||||
networking.build_template(output_directory, self.input_interface, self.output_interface)
|
||||
|
||||
ports_in_use = networking.get_ports_in_use(INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
|
||||
for i in range(TRY_NUM_PORTS):
|
||||
if not ((INITIAL_WEBSOCKET_PORT + i) in ports_in_use):
|
||||
break
|
||||
else:
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
|
||||
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS))
|
||||
|
||||
start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i)
|
||||
self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i)
|
||||
if verbose:
|
||||
# Set up a port to serve a websocket that sets up the communication between the front-end and model.
|
||||
websocket_port = networking.get_first_available_port(
|
||||
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
|
||||
start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port)
|
||||
networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file.
|
||||
if self.verbose:
|
||||
print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu")
|
||||
print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP))
|
||||
print("Model available locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP))
|
||||
|
||||
if share_link:
|
||||
networking.kill_processes([4040, 4041])
|
||||
site_ngrok_url = networking.setup_ngrok(server_port)
|
||||
socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2)
|
||||
self._set_socket_url_in_js(output_directory, socket_ngrok_url)
|
||||
if verbose:
|
||||
print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP))
|
||||
if share:
|
||||
site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory)
|
||||
if self.verbose:
|
||||
print("Model available publicly for 8 hours at: {}".format(
|
||||
site_ngrok_url + '/' + networking.TEMPLATE_TEMP))
|
||||
else:
|
||||
if verbose:
|
||||
print("To create a public link, set `share_link=True` in the argument to `launch()`")
|
||||
if self.verbose:
|
||||
print("To create a public link, set `share=True` in the argument to `launch()`")
|
||||
|
||||
# Keep the server running in the background.
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
try:
|
||||
asyncio.get_event_loop().run_forever()
|
||||
except RuntimeError: # Runtime errors are thrown in jupyter notebooks because of async.
|
||||
pass
|
||||
|
||||
webbrowser.open(path_to_server + TEMPLATE_TEMP)
|
||||
webbrowser.open(path_to_server + networking.TEMPLATE_TEMP) # Open a browser tab with the interface.
|
||||
|
@ -16,6 +16,9 @@ from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
import stat
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
import shutil
|
||||
|
||||
INITIAL_PORT_VALUE = 7860
|
||||
TRY_NUM_PORTS = 100
|
||||
@ -24,6 +27,16 @@ LOCALHOST_PREFIX = 'localhost:'
|
||||
NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output)
|
||||
NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output)
|
||||
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
|
||||
JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/')
|
||||
CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/')
|
||||
JS_PATH_TEMP = 'js/'
|
||||
CSS_PATH_TEMP = 'css/'
|
||||
TEMPLATE_TEMP = 'interface.html'
|
||||
BASE_JS_FILE = 'js/all-io.js'
|
||||
|
||||
|
||||
NGROK_ZIP_URLS = {
|
||||
"linux": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip",
|
||||
"darwin": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-darwin-amd64.zip",
|
||||
@ -31,48 +44,71 @@ NGROK_ZIP_URLS = {
|
||||
}
|
||||
|
||||
|
||||
def get_ports_in_use(start, stop):
|
||||
ports_in_use = []
|
||||
for port in range(start, stop):
|
||||
def build_template(temp_dir, input_interface, output_interface):
|
||||
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
|
||||
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
|
||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
input_tag = all_io_soup.find("div", {"id": "input"})
|
||||
output_tag = all_io_soup.find("div", {"id": "output"})
|
||||
|
||||
input_tag.replace_with(input_soup)
|
||||
output_tag.replace_with(output_soup)
|
||||
|
||||
f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w")
|
||||
f.write(str(all_io_soup.prettify))
|
||||
|
||||
copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP))
|
||||
copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP))
|
||||
return
|
||||
|
||||
|
||||
def copy_files(src_dir, dest_dir):
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
src_files = os.listdir(src_dir)
|
||||
for file_name in src_files:
|
||||
full_file_name = os.path.join(src_dir, file_name)
|
||||
if os.path.isfile(full_file_name):
|
||||
shutil.copy(full_file_name, dest_dir)
|
||||
|
||||
def set_socket_url_in_js(temp_dir, socket_url):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws'))
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
def set_socket_port_in_js(temp_dir, socket_port):
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin:
|
||||
lines = fin.readlines()
|
||||
lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port)
|
||||
|
||||
with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout:
|
||||
for line in lines:
|
||||
fout.write(line)
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
for port in range(initial, final):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.close()
|
||||
return port
|
||||
except OSError:
|
||||
ports_in_use.append(port)
|
||||
return ports_in_use
|
||||
# ports_in_use = []
|
||||
# try:
|
||||
# for proc in process_iter():
|
||||
# for conns in proc.connections(kind='inet'):
|
||||
# ports_in_use.append(conns.laddr.port)
|
||||
# except AccessDenied:
|
||||
# pass # TODO(abidlabs): somehow find a way to handle this issue?
|
||||
# return ports_in_use
|
||||
pass
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(initial, final))
|
||||
|
||||
|
||||
def serve_files_in_background(port, directory_to_serve=None):
|
||||
# class Handler(http.server.SimpleHTTPRequestHandler):
|
||||
# def __init__(self, *args, **kwargs):
|
||||
# super().__init__(*args, directory=directory_to_serve, **kwargs)
|
||||
#
|
||||
# server = socketserver.ThreadingTCPServer(('localhost', port), Handler)
|
||||
# # Ensures that Ctrl-C cleanly kills all spawned threads
|
||||
# server.daemon_threads = True
|
||||
# # Quicker rebinding
|
||||
# server.allow_reuse_address = True
|
||||
#
|
||||
# # A custom signal handle to allow us to Ctrl-C out of the process
|
||||
# def signal_handler(signal, frame):
|
||||
# print('Exiting http server (Ctrl+C pressed)')
|
||||
# try:
|
||||
# if (server):
|
||||
# server.server_close()
|
||||
# finally:
|
||||
# sys.exit(0)
|
||||
#
|
||||
# # Install the keyboard interrupt handler
|
||||
# signal.signal(signal.SIGINT, signal_handler)
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
|
||||
@ -106,20 +142,9 @@ def serve_files_in_background(port, directory_to_serve=None):
|
||||
|
||||
def start_simple_server(directory_to_serve=None):
|
||||
# TODO(abidlabs): increment port number until free port is found
|
||||
ports_in_use = get_ports_in_use(start=INITIAL_PORT_VALUE, stop=INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
for i in range(TRY_NUM_PORTS):
|
||||
if not((INITIAL_PORT_VALUE + i) in ports_in_use):
|
||||
break
|
||||
else:
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS))
|
||||
serve_files_in_background(INITIAL_PORT_VALUE + i, directory_to_serve)
|
||||
# if directory_to_serve is None:
|
||||
# subprocess.Popen(['python', '-m', 'http.server', str(INITIAL_PORT_VALUE + i)])
|
||||
# else:
|
||||
# cmd = ' '.join(['python', '-m', 'http.server', '-d', directory_to_serve, str(INITIAL_PORT_VALUE + i)])
|
||||
# subprocess.Popen(cmd, shell=True) # Doesn't seem to work if list is passed for some reason.
|
||||
return INITIAL_PORT_VALUE + i
|
||||
port = get_first_available_port (INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
serve_files_in_background(port, directory_to_serve)
|
||||
return port
|
||||
|
||||
|
||||
def download_ngrok():
|
||||
@ -137,7 +162,7 @@ def download_ngrok():
|
||||
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
|
||||
|
||||
|
||||
def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL):
|
||||
def create_ngrok_tunnel(local_port, api_url):
|
||||
if not(os.path.isfile('ngrok.exe') or os.path.isfile('ngrok')):
|
||||
download_ngrok()
|
||||
if sys.platform == 'win32':
|
||||
@ -156,6 +181,13 @@ def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL):
|
||||
raise RuntimeError("Not able to retrieve ngrok public URL")
|
||||
|
||||
|
||||
def setup_ngrok(server_port, websocket_port, output_directory):
|
||||
site_ngrok_url = create_ngrok_tunnel(server_port, NGROK_TUNNELS_API_URL)
|
||||
socket_ngrok_url = create_ngrok_tunnel(websocket_port, NGROK_TUNNELS_API_URL2)
|
||||
set_socket_url_in_js(output_directory, socket_ngrok_url)
|
||||
return site_ngrok_url
|
||||
|
||||
|
||||
def kill_processes(process_ids):
|
||||
for proc in process_iter():
|
||||
try:
|
||||
|
@ -18,18 +18,18 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
"""
|
||||
if postprocessing_fn is not None:
|
||||
self._postprocess = postprocessing_fn
|
||||
self.postprocess = postprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
"""
|
||||
All interfaces should define a method that returns the path to its template.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _postprocess(self, prediction):
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
All interfaces should define a default postprocessing method
|
||||
"""
|
||||
@ -38,10 +38,10 @@ class AbstractOutput(ABC):
|
||||
|
||||
class Label(AbstractOutput):
|
||||
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
return 'templates/label_output.html'
|
||||
|
||||
def _postprocess(self, prediction):
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
if isinstance(prediction, np.ndarray):
|
||||
@ -58,10 +58,10 @@ class Label(AbstractOutput):
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
|
||||
def _get_template_path(self):
|
||||
def get_template_path(self):
|
||||
return 'templates/textbox_output.html'
|
||||
|
||||
def _postprocess(self, prediction):
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
return prediction
|
||||
|
@ -10,48 +10,48 @@ PACKAGE_NAME = 'gradio'
|
||||
class TestSketchpad(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Sketchpad()
|
||||
path = inp._get_template_path()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Sketchpad()
|
||||
array = inp._preprocess(BASE64_IMG)
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 28, 28, 1))
|
||||
|
||||
|
||||
class TestWebcam(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Webcam()
|
||||
path = inp._get_template_path()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Webcam()
|
||||
array = inp._preprocess(BASE64_IMG)
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Textbox()
|
||||
path = inp._get_template_path()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Textbox()
|
||||
string = inp._preprocess(RAND_STRING)
|
||||
string = inp.preprocess(RAND_STRING)
|
||||
self.assertEqual(string, RAND_STRING)
|
||||
|
||||
|
||||
class TestImageUpload(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.ImageUpload()
|
||||
path = inp._get_template_path()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.ImageUpload()
|
||||
array = inp._preprocess(BASE64_IMG)
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
|
||||
|
||||
|
0
test/test_interface.py
Normal file
0
test/test_interface.py
Normal file
@ -9,40 +9,40 @@ PACKAGE_NAME = 'gradio'
|
||||
class TestLabel(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Label()
|
||||
path = out._get_template_path()
|
||||
path = out.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing_string(self):
|
||||
string = 'happy'
|
||||
out = outputs.Label()
|
||||
label = out._postprocess(string)
|
||||
label = out.postprocess(string)
|
||||
self.assertEqual(label, string)
|
||||
|
||||
def test_postprocessing_one_hot(self):
|
||||
one_hot = np.array([0, 0, 0, 1, 0])
|
||||
true_label = 3
|
||||
out = outputs.Label()
|
||||
label = out._postprocess(one_hot)
|
||||
label = out.postprocess(one_hot)
|
||||
self.assertEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_int(self):
|
||||
true_label_array = np.array([[[3]]])
|
||||
true_label = 3
|
||||
out = outputs.Label()
|
||||
label = out._postprocess(true_label_array)
|
||||
label = out.postprocess(true_label_array)
|
||||
self.assertEqual(label, true_label)
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Textbox()
|
||||
path = out._get_template_path()
|
||||
path = out.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
string = 'happy'
|
||||
out = outputs.Textbox()
|
||||
string = out._postprocess(string)
|
||||
string = out.postprocess(string)
|
||||
self.assertEqual(string, string)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user