diff --git a/gradio/inputs.py b/gradio/inputs.py index f57d8731e0..4635d965cf 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -22,18 +22,18 @@ class AbstractInput(ABC): if preprocessing_fn is not None: if not callable(preprocessing_fn): raise ValueError('`preprocessing_fn` must be a callable function') - self._preprocess = preprocessing_fn + self.preprocess = preprocessing_fn super().__init__() @abstractmethod - def _get_template_path(self): + def get_template_path(self): """ All interfaces should define a method that returns the path to its template. """ pass @abstractmethod - def _preprocess(self, inp): + def preprocess(self, inp): """ All interfaces should define a default preprocessing method """ @@ -42,10 +42,10 @@ class AbstractInput(ABC): class Sketchpad(AbstractInput): - def _get_template_path(self): + def get_template_path(self): return 'templates/sketchpad_input.html' - def _preprocess(self, inp): + def preprocess(self, inp): """ Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28 """ @@ -59,10 +59,10 @@ class Sketchpad(AbstractInput): class Webcam(AbstractInput): - def _get_template_path(self): + def get_template_path(self): return 'templates/webcam_input.html' - def _preprocess(self, inp): + def preprocess(self, inp): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ @@ -76,10 +76,10 @@ class Webcam(AbstractInput): class Textbox(AbstractInput): - def _get_template_path(self): + def get_template_path(self): return 'templates/textbox_input.html' - def _preprocess(self, inp): + def preprocess(self, inp): """ By default, no pre-processing is applied to text. """ @@ -88,10 +88,10 @@ class Textbox(AbstractInput): class ImageUpload(AbstractInput): - def _get_template_path(self): + def get_template_path(self): return 'templates/image_upload_input.html' - def _preprocess(self, inp): + def preprocess(self, inp): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ diff --git a/gradio/interface.py b/gradio/interface.py index 29fc8719f5..bbf03a62e5 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -7,13 +7,9 @@ import asyncio import websockets import nest_asyncio import webbrowser -import pkg_resources -from bs4 import BeautifulSoup -from gradio import inputs -from gradio import outputs +import gradio.inputs +import gradio.outputs from gradio import networking -import os -import shutil import tempfile nest_asyncio.apply() @@ -23,43 +19,41 @@ INITIAL_WEBSOCKET_PORT = 9200 TRY_NUM_PORTS = 100 -BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html') -JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/') -CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/') -JS_PATH_TEMP = 'js/' -CSS_PATH_TEMP = 'css/' -TEMPLATE_TEMP = 'interface.html' -BASE_JS_FILE = 'js/all-io.js' - - class Interface: """ + The Interface class represents a general input/output interface for a machine learning model. During construction, + the appropriate inputs and outputs """ # Dictionary in which each key is a valid `model_type` argument to constructor, and the value being the description. VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'keras model', 'function': 'python function'} - def __init__(self, input, output, model, model_type=None, preprocessing_fn=None, postprocessing_fn=None): + def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None, + verbose=True): """ - :param model_type: what kind of trained model, can be 'keras' or 'sklearn'. + :param inputs: a string representing the input interface. + :param outputs: a string representing the output interface. :param model_obj: the model object, such as a sklearn classifier or keras model. - :param model_params: additional model parameters. + :param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not + provided. + :param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface. + :param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface. """ - self.input_interface = inputs.registry[input](preprocessing_fn) - self.output_interface = outputs.registry[output](postprocessing_fn) + self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns) + self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns) self.model_obj = model if model_type is None: model_type = self._infer_model_type(model) - if model_type is None: - raise ValueError("model_type could not be inferred, please specify parameter `model_type`") - else: + if verbose: print("Model type not explicitly identified, inferred to be: {}".format( - self.VALID_MODEL_TYPES[model_type])) + self.VALID_MODEL_TYPES[model_type])) elif not(model_type.lower() in self.VALID_MODEL_TYPES): ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES)) self.model_type = model_type - def _infer_model_type(self, model): + @staticmethod + def _infer_model_type(model): + """ Helper method that attempts to identify the type of trained ML model.""" try: import sklearn if isinstance(model, sklearn.base.BaseEstimator): @@ -84,124 +78,75 @@ class Interface: if callable(model): return 'function' - return None - - def _build_template(self, temp_dir): - input_template_path = pkg_resources.resource_filename( - 'gradio', self.input_interface._get_template_path()) - output_template_path = pkg_resources.resource_filename( - 'gradio', self.output_interface._get_template_path()) - input_page = open(input_template_path) - output_page = open(output_template_path) - input_soup = BeautifulSoup(input_page.read(), features="html.parser") - output_soup = BeautifulSoup(output_page.read(), features="html.parser") - - all_io_page = open(BASE_TEMPLATE) - all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser") - input_tag = all_io_soup.find("div", {"id": "input"}) - output_tag = all_io_soup.find("div", {"id": "output"}) - - input_tag.replace_with(input_soup) - output_tag.replace_with(output_soup) - - f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w") - f.write(str(all_io_soup.prettify)) - - self._copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP)) - self._copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP)) - return - - def _copy_files(self, src_dir, dest_dir): - if not os.path.exists(dest_dir): - os.makedirs(dest_dir) - src_files = os.listdir(src_dir) - for file_name in src_files: - full_file_name = os.path.join(src_dir, file_name) - if os.path.isfile(full_file_name): - shutil.copy(full_file_name, dest_dir) - - def _set_socket_url_in_js(self, temp_dir, socket_url): - with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin: - lines = fin.readlines() - lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws')) - - with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout: - for line in lines: - fout.write(line) - - def _set_socket_port_in_js(self, temp_dir, socket_port): - with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin: - lines = fin.readlines() - lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port) - - with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout: - for line in lines: - fout.write(line) - - def predict(self, array): - if self.model_type=='sklearn': - return self.model_obj.predict(array) - elif self.model_type=='keras': - return self.model_obj.predict(array) - elif self.model_type=='function': - return self.model_obj(array) - else: - ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES)) + raise ValueError("model_type could not be inferred, please specify parameter `model_type`") async def communicate(self, websocket, path): """ - Method that defines how this interface communicates with the websocket. - :param websocket: a Websocket object used to communicate with the interface frontend - :param path: ignored + Method that defines how this interface should communicates with the websocket. (1) When an input is received by + the websocket, it is passed into the input interface and preprocssed. (2) Then the model is called to make a + prediction. (3) Finally, the prediction is postprocessed to get something to be displayed by the output. + :param websocket: a Websocket server used to communicate with the interface frontend + :param path: not used, but required for compliance with websocket library """ while True: try: msg = await websocket.recv() - processed_input = self.input_interface._pre_process(msg) + processed_input = self.input_interface.preprocess(msg) prediction = self.predict(processed_input) - processed_output = self.output_interface._post_process(prediction) + processed_output = self.output_interface.postprocess(prediction) await websocket.send(str(processed_output)) except websockets.exceptions.ConnectionClosed: pass - def launch(self, share_link=False, verbose=True): + def predict(self, preprocessed_input): """ - Standard method shared by interfaces that launches a websocket at a specified IP address. + Method that calls the relevant method of the model object to make a prediction. + :param preprocessed_input: the preprocessed input returned by the input interface + """ + if self.model_type=='sklearn': + return self.model_obj.predict(preprocessed_input) + elif self.model_type=='keras': + return self.model_obj.predict(preprocessed_input) + elif self.model_type=='function': + return self.model_obj(preprocessed_input) + else: + ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES)) + + def launch(self, share=False): + """ + Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it. + :param share: boolean. If True, then a share link is generated using ngrok is displayed to the user. """ output_directory = tempfile.mkdtemp() + + # Set up a port to serve the directory containing the static files with interface. server_port = networking.start_simple_server(output_directory) path_to_server = 'http://localhost:{}/'.format(server_port) - self._build_template(output_directory) + networking.build_template(output_directory, self.input_interface, self.output_interface) - ports_in_use = networking.get_ports_in_use(INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS) - for i in range(TRY_NUM_PORTS): - if not ((INITIAL_WEBSOCKET_PORT + i) in ports_in_use): - break - else: - raise OSError("All ports from {} to {} are in use. Please close a port.".format( - INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)) - - start_server = websockets.serve(self.communicate, LOCALHOST_IP, INITIAL_WEBSOCKET_PORT + i) - self._set_socket_port_in_js(output_directory, INITIAL_WEBSOCKET_PORT + i) - if verbose: + # Set up a port to serve a websocket that sets up the communication between the front-end and model. + websocket_port = networking.get_first_available_port( + INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS) + start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port) + networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file. + if self.verbose: print("NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu") - print("Model available locally at: {}".format(path_to_server + TEMPLATE_TEMP)) + print("Model available locally at: {}".format(path_to_server + networking.TEMPLATE_TEMP)) - if share_link: - networking.kill_processes([4040, 4041]) - site_ngrok_url = networking.setup_ngrok(server_port) - socket_ngrok_url = networking.setup_ngrok(INITIAL_WEBSOCKET_PORT, api_url=networking.NGROK_TUNNELS_API_URL2) - self._set_socket_url_in_js(output_directory, socket_ngrok_url) - if verbose: - print("Model available publicly for 8 hours at: {}".format(site_ngrok_url + '/' + TEMPLATE_TEMP)) + if share: + site_ngrok_url = networking.setup_ngrok(server_port, websocket_port, output_directory) + if self.verbose: + print("Model available publicly for 8 hours at: {}".format( + site_ngrok_url + '/' + networking.TEMPLATE_TEMP)) else: - if verbose: - print("To create a public link, set `share_link=True` in the argument to `launch()`") + if self.verbose: + print("To create a public link, set `share=True` in the argument to `launch()`") + # Keep the server running in the background. asyncio.get_event_loop().run_until_complete(start_server) try: asyncio.get_event_loop().run_forever() except RuntimeError: # Runtime errors are thrown in jupyter notebooks because of async. pass - webbrowser.open(path_to_server + TEMPLATE_TEMP) + webbrowser.open(path_to_server + networking.TEMPLATE_TEMP) # Open a browser tab with the interface. diff --git a/gradio/networking.py b/gradio/networking.py index 905b4d44b6..67074587bd 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -16,6 +16,9 @@ from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler import stat from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry +import pkg_resources +from bs4 import BeautifulSoup +import shutil INITIAL_PORT_VALUE = 7860 TRY_NUM_PORTS = 100 @@ -24,6 +27,16 @@ LOCALHOST_PREFIX = 'localhost:' NGROK_TUNNELS_API_URL = "http://localhost:4040/api/tunnels" # TODO(this should be captured from output) NGROK_TUNNELS_API_URL2 = "http://localhost:4041/api/tunnels" # TODO(this should be captured from output) + +BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html') +JS_PATH_LIB = pkg_resources.resource_filename('gradio', 'js/') +CSS_PATH_LIB = pkg_resources.resource_filename('gradio', 'css/') +JS_PATH_TEMP = 'js/' +CSS_PATH_TEMP = 'css/' +TEMPLATE_TEMP = 'interface.html' +BASE_JS_FILE = 'js/all-io.js' + + NGROK_ZIP_URLS = { "linux": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip", "darwin": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-darwin-amd64.zip", @@ -31,48 +44,71 @@ NGROK_ZIP_URLS = { } -def get_ports_in_use(start, stop): - ports_in_use = [] - for port in range(start, stop): +def build_template(temp_dir, input_interface, output_interface): + input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path()) + output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path()) + input_page = open(input_template_path) + output_page = open(output_template_path) + input_soup = BeautifulSoup(input_page.read(), features="html.parser") + output_soup = BeautifulSoup(output_page.read(), features="html.parser") + + all_io_page = open(BASE_TEMPLATE) + all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser") + input_tag = all_io_soup.find("div", {"id": "input"}) + output_tag = all_io_soup.find("div", {"id": "output"}) + + input_tag.replace_with(input_soup) + output_tag.replace_with(output_soup) + + f = open(os.path.join(temp_dir, TEMPLATE_TEMP), "w") + f.write(str(all_io_soup.prettify)) + + copy_files(JS_PATH_LIB, os.path.join(temp_dir, JS_PATH_TEMP)) + copy_files(CSS_PATH_LIB, os.path.join(temp_dir, CSS_PATH_TEMP)) + return + + +def copy_files(src_dir, dest_dir): + if not os.path.exists(dest_dir): + os.makedirs(dest_dir) + src_files = os.listdir(src_dir) + for file_name in src_files: + full_file_name = os.path.join(src_dir, file_name) + if os.path.isfile(full_file_name): + shutil.copy(full_file_name, dest_dir) + +def set_socket_url_in_js(temp_dir, socket_url): + with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin: + lines = fin.readlines() + lines[0] = 'var NGROK_URL = "{}"\n'.format(socket_url.replace('http', 'ws')) + + with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout: + for line in lines: + fout.write(line) + +def set_socket_port_in_js(temp_dir, socket_port): + with open(os.path.join(temp_dir, BASE_JS_FILE)) as fin: + lines = fin.readlines() + lines[1] = 'var SOCKET_PORT = {}\n'.format(socket_port) + + with open(os.path.join(temp_dir, BASE_JS_FILE), 'w') as fout: + for line in lines: + fout.write(line) + + +def get_first_available_port(initial, final): + for port in range(initial, final): try: s = socket.socket() # create a socket object s.bind((LOCALHOST_NAME, port)) # Bind to the port s.close() + return port except OSError: - ports_in_use.append(port) - return ports_in_use - # ports_in_use = [] - # try: - # for proc in process_iter(): - # for conns in proc.connections(kind='inet'): - # ports_in_use.append(conns.laddr.port) - # except AccessDenied: - # pass # TODO(abidlabs): somehow find a way to handle this issue? - # return ports_in_use + pass + raise OSError("All ports from {} to {} are in use. Please close a port.".format(initial, final)) def serve_files_in_background(port, directory_to_serve=None): - # class Handler(http.server.SimpleHTTPRequestHandler): - # def __init__(self, *args, **kwargs): - # super().__init__(*args, directory=directory_to_serve, **kwargs) - # - # server = socketserver.ThreadingTCPServer(('localhost', port), Handler) - # # Ensures that Ctrl-C cleanly kills all spawned threads - # server.daemon_threads = True - # # Quicker rebinding - # server.allow_reuse_address = True - # - # # A custom signal handle to allow us to Ctrl-C out of the process - # def signal_handler(signal, frame): - # print('Exiting http server (Ctrl+C pressed)') - # try: - # if (server): - # server.server_close() - # finally: - # sys.exit(0) - # - # # Install the keyboard interrupt handler - # signal.signal(signal.SIGINT, signal_handler) class HTTPHandler(SimpleHTTPRequestHandler): """This handler uses server.base_path instead of always using os.getcwd()""" @@ -106,20 +142,9 @@ def serve_files_in_background(port, directory_to_serve=None): def start_simple_server(directory_to_serve=None): # TODO(abidlabs): increment port number until free port is found - ports_in_use = get_ports_in_use(start=INITIAL_PORT_VALUE, stop=INITIAL_PORT_VALUE + TRY_NUM_PORTS) - for i in range(TRY_NUM_PORTS): - if not((INITIAL_PORT_VALUE + i) in ports_in_use): - break - else: - raise OSError("All ports from {} to {} are in use. Please close a port.".format( - INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)) - serve_files_in_background(INITIAL_PORT_VALUE + i, directory_to_serve) - # if directory_to_serve is None: - # subprocess.Popen(['python', '-m', 'http.server', str(INITIAL_PORT_VALUE + i)]) - # else: - # cmd = ' '.join(['python', '-m', 'http.server', '-d', directory_to_serve, str(INITIAL_PORT_VALUE + i)]) - # subprocess.Popen(cmd, shell=True) # Doesn't seem to work if list is passed for some reason. - return INITIAL_PORT_VALUE + i + port = get_first_available_port (INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS) + serve_files_in_background(port, directory_to_serve) + return port def download_ngrok(): @@ -137,7 +162,7 @@ def download_ngrok(): os.chmod('ngrok', st.st_mode | stat.S_IEXEC) -def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL): +def create_ngrok_tunnel(local_port, api_url): if not(os.path.isfile('ngrok.exe') or os.path.isfile('ngrok')): download_ngrok() if sys.platform == 'win32': @@ -156,6 +181,13 @@ def setup_ngrok(local_port, api_url=NGROK_TUNNELS_API_URL): raise RuntimeError("Not able to retrieve ngrok public URL") +def setup_ngrok(server_port, websocket_port, output_directory): + site_ngrok_url = create_ngrok_tunnel(server_port, NGROK_TUNNELS_API_URL) + socket_ngrok_url = create_ngrok_tunnel(websocket_port, NGROK_TUNNELS_API_URL2) + set_socket_url_in_js(output_directory, socket_ngrok_url) + return site_ngrok_url + + def kill_processes(process_ids): for proc in process_iter(): try: diff --git a/gradio/outputs.py b/gradio/outputs.py index 926fbd3cf7..97449168d7 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -18,18 +18,18 @@ class AbstractOutput(ABC): """ """ if postprocessing_fn is not None: - self._postprocess = postprocessing_fn + self.postprocess = postprocessing_fn super().__init__() @abstractmethod - def _get_template_path(self): + def get_template_path(self): """ All interfaces should define a method that returns the path to its template. """ pass @abstractmethod - def _postprocess(self, prediction): + def postprocess(self, prediction): """ All interfaces should define a default postprocessing method """ @@ -38,10 +38,10 @@ class AbstractOutput(ABC): class Label(AbstractOutput): - def _get_template_path(self): + def get_template_path(self): return 'templates/label_output.html' - def _postprocess(self, prediction): + def postprocess(self, prediction): """ """ if isinstance(prediction, np.ndarray): @@ -58,10 +58,10 @@ class Label(AbstractOutput): class Textbox(AbstractOutput): - def _get_template_path(self): + def get_template_path(self): return 'templates/textbox_output.html' - def _postprocess(self, prediction): + def postprocess(self, prediction): """ """ return prediction diff --git a/test/test_inputs.py b/test/test_inputs.py index e8ef2a6fe4..f3e6d20f88 100644 --- a/test/test_inputs.py +++ b/test/test_inputs.py @@ -10,48 +10,48 @@ PACKAGE_NAME = 'gradio' class TestSketchpad(unittest.TestCase): def test_path_exists(self): inp = inputs.Sketchpad() - path = inp._get_template_path() + path = inp.get_template_path() self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.Sketchpad() - array = inp._preprocess(BASE64_IMG) + array = inp.preprocess(BASE64_IMG) self.assertEqual(array.shape, (1, 28, 28, 1)) class TestWebcam(unittest.TestCase): def test_path_exists(self): inp = inputs.Webcam() - path = inp._get_template_path() + path = inp.get_template_path() self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.Webcam() - array = inp._preprocess(BASE64_IMG) + array = inp.preprocess(BASE64_IMG) self.assertEqual(array.shape, (1, 48, 48, 1)) class TestTextbox(unittest.TestCase): def test_path_exists(self): inp = inputs.Textbox() - path = inp._get_template_path() + path = inp.get_template_path() self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.Textbox() - string = inp._preprocess(RAND_STRING) + string = inp.preprocess(RAND_STRING) self.assertEqual(string, RAND_STRING) class TestImageUpload(unittest.TestCase): def test_path_exists(self): inp = inputs.ImageUpload() - path = inp._get_template_path() + path = inp.get_template_path() self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.ImageUpload() - array = inp._preprocess(BASE64_IMG) + array = inp.preprocess(BASE64_IMG) self.assertEqual(array.shape, (1, 48, 48, 1)) diff --git a/test/test_interface.py b/test/test_interface.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_outputs.py b/test/test_outputs.py index 1a97cd237b..830a37d9aa 100644 --- a/test/test_outputs.py +++ b/test/test_outputs.py @@ -9,40 +9,40 @@ PACKAGE_NAME = 'gradio' class TestLabel(unittest.TestCase): def test_path_exists(self): out = outputs.Label() - path = out._get_template_path() + path = out.get_template_path() self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_postprocessing_string(self): string = 'happy' out = outputs.Label() - label = out._postprocess(string) + label = out.postprocess(string) self.assertEqual(label, string) def test_postprocessing_one_hot(self): one_hot = np.array([0, 0, 0, 1, 0]) true_label = 3 out = outputs.Label() - label = out._postprocess(one_hot) + label = out.postprocess(one_hot) self.assertEqual(label, true_label) def test_postprocessing_int(self): true_label_array = np.array([[[3]]]) true_label = 3 out = outputs.Label() - label = out._postprocess(true_label_array) + label = out.postprocess(true_label_array) self.assertEqual(label, true_label) class TestTextbox(unittest.TestCase): def test_path_exists(self): out = outputs.Textbox() - path = out._get_template_path() + path = out.get_template_path() self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_postprocessing(self): string = 'happy' out = outputs.Textbox() - string = out._postprocess(string) + string = out.postprocess(string) self.assertEqual(string, string)