diff --git a/gradio/networking.py b/gradio/networking.py index b753468b64..6e092f00b1 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -52,13 +52,15 @@ def build_template(temp_dir, input_interface, output_interface): """ input_template_path = pkg_resources.resource_filename( 'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name())) - output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path()) + output_template_path = pkg_resources.resource_filename( + 'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name())) input_page = open(input_template_path) output_page = open(output_template_path) - input_soup = BeautifulSoup(render_string_or_list_with_tags(input_page.read(), - input_interface.get_template_context()), - features="html.parser") - output_soup = BeautifulSoup(output_page.read(), features="html.parser") + input_soup = BeautifulSoup(render_string_or_list_with_tags( + input_page.read(), input_interface.get_template_context()), features="html.parser") + output_soup = BeautifulSoup( + render_string_or_list_with_tags( + output_page.read(), output_interface.get_template_context()), features="html.parser") all_io_page = open(BASE_TEMPLATE) all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser") @@ -72,9 +74,12 @@ def build_template(temp_dir, input_interface, output_interface): f.write(str(all_io_soup)) copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP)) - render_template_with_tags(os.path.join(temp_dir, - inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())), - input_interface.get_js_context()) + render_template_with_tags(os.path.join( + temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())), + input_interface.get_js_context()) + render_template_with_tags(os.path.join( + temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())), + output_interface.get_js_context()) def copy_files(src_dir, dest_dir): diff --git a/gradio/outputs.py b/gradio/outputs.py index c90bb80975..9e403f8233 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -9,6 +9,11 @@ import numpy as np import json from gradio import imagenet_class_labels +# Where to find the static resources associated with each template. +BASE_OUTPUT_INTERFACE_TEMPLATE_PATH = 'templates/output/{}.html' +BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js' + + class AbstractOutput(ABC): """ An abstract class for defining the methods that all gradio inputs should have. @@ -23,10 +28,22 @@ class AbstractOutput(ABC): self.postprocess = postprocessing_fn super().__init__() - @abstractmethod - def get_template_path(self): + def get_js_context(self): """ - All interfaces should define a method that returns the path to its template. + :return: a dictionary with context variables for the javascript file associated with the context + """ + return {} + + def get_template_context(self): + """ + :return: a dictionary with context variables for the javascript file associated with the context + """ + return {} + + @abstractmethod + def get_name(self): + """ + All outputs should define a method that returns a name used for identifying the related static resources. """ pass @@ -51,10 +68,13 @@ class Label(AbstractOutput): self.max_label_length = max_label_length super().__init__(postprocessing_fn=postprocessing_fn) + def get_name(self): + return 'label' + def get_label_name(self, label): if self.label_names is None: name = label - elif self.label_names == 'imagenet1000': + elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this name = imagenet_class_labels.NAMES1000[label] else: # if list or dictionary name = self.label_names[label] @@ -62,9 +82,6 @@ class Label(AbstractOutput): name = name[:self.max_label_length] return name - def get_template_path(self): - return 'templates/output/label.html' - def postprocess(self, prediction): """ """ @@ -93,18 +110,19 @@ class Label(AbstractOutput): class Textbox(AbstractOutput): - def get_template_path(self): - return 'templates/output/textbox.html' + def get_name(self): + return 'textbox' def postprocess(self, prediction): """ """ return prediction + class Image(AbstractOutput): - def get_template_path(self): - return 'templates/output/image.html' + def get_name(self): + return 'image' def postprocess(self, prediction): """