mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-09 02:00:44 +08:00
refactored outputs.py to support js and template rendering
This commit is contained in:
parent
e18646af15
commit
f89d84c6d8
@ -52,13 +52,15 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
"""
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
|
||||
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
|
||||
output_template_path = pkg_resources.resource_filename(
|
||||
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(render_string_or_list_with_tags(input_page.read(),
|
||||
input_interface.get_template_context()),
|
||||
features="html.parser")
|
||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
||||
input_soup = BeautifulSoup(render_string_or_list_with_tags(
|
||||
input_page.read(), input_interface.get_template_context()), features="html.parser")
|
||||
output_soup = BeautifulSoup(
|
||||
render_string_or_list_with_tags(
|
||||
output_page.read(), output_interface.get_template_context()), features="html.parser")
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
@ -72,9 +74,12 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
f.write(str(all_io_soup))
|
||||
|
||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
render_template_with_tags(os.path.join(temp_dir,
|
||||
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
||||
input_interface.get_js_context())
|
||||
render_template_with_tags(os.path.join(
|
||||
temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
||||
input_interface.get_js_context())
|
||||
render_template_with_tags(os.path.join(
|
||||
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
|
||||
output_interface.get_js_context())
|
||||
|
||||
|
||||
def copy_files(src_dir, dest_dir):
|
||||
|
@ -9,6 +9,11 @@ import numpy as np
|
||||
import json
|
||||
from gradio import imagenet_class_labels
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
BASE_OUTPUT_INTERFACE_TEMPLATE_PATH = 'templates/output/{}.html'
|
||||
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
|
||||
|
||||
|
||||
class AbstractOutput(ABC):
|
||||
"""
|
||||
An abstract class for defining the methods that all gradio inputs should have.
|
||||
@ -23,10 +28,22 @@ class AbstractOutput(ABC):
|
||||
self.postprocess = postprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def get_template_path(self):
|
||||
def get_js_context(self):
|
||||
"""
|
||||
All interfaces should define a method that returns the path to its template.
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_template_context(self):
|
||||
"""
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
"""
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
"""
|
||||
All outputs should define a method that returns a name used for identifying the related static resources.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -51,10 +68,13 @@ class Label(AbstractOutput):
|
||||
self.max_label_length = max_label_length
|
||||
super().__init__(postprocessing_fn=postprocessing_fn)
|
||||
|
||||
def get_name(self):
|
||||
return 'label'
|
||||
|
||||
def get_label_name(self, label):
|
||||
if self.label_names is None:
|
||||
name = label
|
||||
elif self.label_names == 'imagenet1000':
|
||||
elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
|
||||
name = imagenet_class_labels.NAMES1000[label]
|
||||
else: # if list or dictionary
|
||||
name = self.label_names[label]
|
||||
@ -62,9 +82,6 @@ class Label(AbstractOutput):
|
||||
name = name[:self.max_label_length]
|
||||
return name
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/output/label.html'
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
@ -93,18 +110,19 @@ class Label(AbstractOutput):
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/output/textbox.html'
|
||||
def get_name(self):
|
||||
return 'textbox'
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
return prediction
|
||||
|
||||
|
||||
class Image(AbstractOutput):
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/output/image.html'
|
||||
def get_name(self):
|
||||
return 'image'
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user