refactored outputs.py to support js and template rendering

This commit is contained in:
Abubakar Abid 2019-03-18 05:32:57 -07:00
parent e18646af15
commit f89d84c6d8
2 changed files with 42 additions and 19 deletions

View File

@ -52,13 +52,15 @@ def build_template(temp_dir, input_interface, output_interface):
""" """
input_template_path = pkg_resources.resource_filename( input_template_path = pkg_resources.resource_filename(
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name())) 'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path()) output_template_path = pkg_resources.resource_filename(
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
input_page = open(input_template_path) input_page = open(input_template_path)
output_page = open(output_template_path) output_page = open(output_template_path)
input_soup = BeautifulSoup(render_string_or_list_with_tags(input_page.read(), input_soup = BeautifulSoup(render_string_or_list_with_tags(
input_interface.get_template_context()), input_page.read(), input_interface.get_template_context()), features="html.parser")
features="html.parser") output_soup = BeautifulSoup(
output_soup = BeautifulSoup(output_page.read(), features="html.parser") render_string_or_list_with_tags(
output_page.read(), output_interface.get_template_context()), features="html.parser")
all_io_page = open(BASE_TEMPLATE) all_io_page = open(BASE_TEMPLATE)
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser") all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
@ -72,9 +74,12 @@ def build_template(temp_dir, input_interface, output_interface):
f.write(str(all_io_soup)) f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP)) copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
render_template_with_tags(os.path.join(temp_dir, render_template_with_tags(os.path.join(
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())), temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
input_interface.get_js_context()) input_interface.get_js_context())
render_template_with_tags(os.path.join(
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
output_interface.get_js_context())
def copy_files(src_dir, dest_dir): def copy_files(src_dir, dest_dir):

View File

@ -9,6 +9,11 @@ import numpy as np
import json import json
from gradio import imagenet_class_labels from gradio import imagenet_class_labels
# Where to find the static resources associated with each template.
BASE_OUTPUT_INTERFACE_TEMPLATE_PATH = 'templates/output/{}.html'
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
class AbstractOutput(ABC): class AbstractOutput(ABC):
""" """
An abstract class for defining the methods that all gradio inputs should have. An abstract class for defining the methods that all gradio inputs should have.
@ -23,10 +28,22 @@ class AbstractOutput(ABC):
self.postprocess = postprocessing_fn self.postprocess = postprocessing_fn
super().__init__() super().__init__()
@abstractmethod def get_js_context(self):
def get_template_path(self):
""" """
All interfaces should define a method that returns the path to its template. :return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
@abstractmethod
def get_name(self):
"""
All outputs should define a method that returns a name used for identifying the related static resources.
""" """
pass pass
@ -51,10 +68,13 @@ class Label(AbstractOutput):
self.max_label_length = max_label_length self.max_label_length = max_label_length
super().__init__(postprocessing_fn=postprocessing_fn) super().__init__(postprocessing_fn=postprocessing_fn)
def get_name(self):
return 'label'
def get_label_name(self, label): def get_label_name(self, label):
if self.label_names is None: if self.label_names is None:
name = label name = label
elif self.label_names == 'imagenet1000': elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
name = imagenet_class_labels.NAMES1000[label] name = imagenet_class_labels.NAMES1000[label]
else: # if list or dictionary else: # if list or dictionary
name = self.label_names[label] name = self.label_names[label]
@ -62,9 +82,6 @@ class Label(AbstractOutput):
name = name[:self.max_label_length] name = name[:self.max_label_length]
return name return name
def get_template_path(self):
return 'templates/output/label.html'
def postprocess(self, prediction): def postprocess(self, prediction):
""" """
""" """
@ -93,18 +110,19 @@ class Label(AbstractOutput):
class Textbox(AbstractOutput): class Textbox(AbstractOutput):
def get_template_path(self): def get_name(self):
return 'templates/output/textbox.html' return 'textbox'
def postprocess(self, prediction): def postprocess(self, prediction):
""" """
""" """
return prediction return prediction
class Image(AbstractOutput): class Image(AbstractOutput):
def get_template_path(self): def get_name(self):
return 'templates/output/image.html' return 'image'
def postprocess(self, prediction): def postprocess(self, prediction):
""" """