mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-09 02:00:44 +08:00
refactored outputs.py to support js and template rendering
This commit is contained in:
parent
e18646af15
commit
f89d84c6d8
@ -52,13 +52,15 @@ def build_template(temp_dir, input_interface, output_interface):
|
|||||||
"""
|
"""
|
||||||
input_template_path = pkg_resources.resource_filename(
|
input_template_path = pkg_resources.resource_filename(
|
||||||
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
|
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
|
||||||
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
|
output_template_path = pkg_resources.resource_filename(
|
||||||
|
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
|
||||||
input_page = open(input_template_path)
|
input_page = open(input_template_path)
|
||||||
output_page = open(output_template_path)
|
output_page = open(output_template_path)
|
||||||
input_soup = BeautifulSoup(render_string_or_list_with_tags(input_page.read(),
|
input_soup = BeautifulSoup(render_string_or_list_with_tags(
|
||||||
input_interface.get_template_context()),
|
input_page.read(), input_interface.get_template_context()), features="html.parser")
|
||||||
features="html.parser")
|
output_soup = BeautifulSoup(
|
||||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
render_string_or_list_with_tags(
|
||||||
|
output_page.read(), output_interface.get_template_context()), features="html.parser")
|
||||||
|
|
||||||
all_io_page = open(BASE_TEMPLATE)
|
all_io_page = open(BASE_TEMPLATE)
|
||||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||||
@ -72,9 +74,12 @@ def build_template(temp_dir, input_interface, output_interface):
|
|||||||
f.write(str(all_io_soup))
|
f.write(str(all_io_soup))
|
||||||
|
|
||||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||||
render_template_with_tags(os.path.join(temp_dir,
|
render_template_with_tags(os.path.join(
|
||||||
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
||||||
input_interface.get_js_context())
|
input_interface.get_js_context())
|
||||||
|
render_template_with_tags(os.path.join(
|
||||||
|
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
|
||||||
|
output_interface.get_js_context())
|
||||||
|
|
||||||
|
|
||||||
def copy_files(src_dir, dest_dir):
|
def copy_files(src_dir, dest_dir):
|
||||||
|
@ -9,6 +9,11 @@ import numpy as np
|
|||||||
import json
|
import json
|
||||||
from gradio import imagenet_class_labels
|
from gradio import imagenet_class_labels
|
||||||
|
|
||||||
|
# Where to find the static resources associated with each template.
|
||||||
|
BASE_OUTPUT_INTERFACE_TEMPLATE_PATH = 'templates/output/{}.html'
|
||||||
|
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'
|
||||||
|
|
||||||
|
|
||||||
class AbstractOutput(ABC):
|
class AbstractOutput(ABC):
|
||||||
"""
|
"""
|
||||||
An abstract class for defining the methods that all gradio inputs should have.
|
An abstract class for defining the methods that all gradio inputs should have.
|
||||||
@ -23,10 +28,22 @@ class AbstractOutput(ABC):
|
|||||||
self.postprocess = postprocessing_fn
|
self.postprocess = postprocessing_fn
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@abstractmethod
|
def get_js_context(self):
|
||||||
def get_template_path(self):
|
|
||||||
"""
|
"""
|
||||||
All interfaces should define a method that returns the path to its template.
|
:return: a dictionary with context variables for the javascript file associated with the context
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_template_context(self):
|
||||||
|
"""
|
||||||
|
:return: a dictionary with context variables for the javascript file associated with the context
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self):
|
||||||
|
"""
|
||||||
|
All outputs should define a method that returns a name used for identifying the related static resources.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -51,10 +68,13 @@ class Label(AbstractOutput):
|
|||||||
self.max_label_length = max_label_length
|
self.max_label_length = max_label_length
|
||||||
super().__init__(postprocessing_fn=postprocessing_fn)
|
super().__init__(postprocessing_fn=postprocessing_fn)
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
return 'label'
|
||||||
|
|
||||||
def get_label_name(self, label):
|
def get_label_name(self, label):
|
||||||
if self.label_names is None:
|
if self.label_names is None:
|
||||||
name = label
|
name = label
|
||||||
elif self.label_names == 'imagenet1000':
|
elif self.label_names == 'imagenet1000': # TODO:(abidlabs) better way to handle this
|
||||||
name = imagenet_class_labels.NAMES1000[label]
|
name = imagenet_class_labels.NAMES1000[label]
|
||||||
else: # if list or dictionary
|
else: # if list or dictionary
|
||||||
name = self.label_names[label]
|
name = self.label_names[label]
|
||||||
@ -62,9 +82,6 @@ class Label(AbstractOutput):
|
|||||||
name = name[:self.max_label_length]
|
name = name[:self.max_label_length]
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def get_template_path(self):
|
|
||||||
return 'templates/output/label.html'
|
|
||||||
|
|
||||||
def postprocess(self, prediction):
|
def postprocess(self, prediction):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
@ -93,18 +110,19 @@ class Label(AbstractOutput):
|
|||||||
|
|
||||||
class Textbox(AbstractOutput):
|
class Textbox(AbstractOutput):
|
||||||
|
|
||||||
def get_template_path(self):
|
def get_name(self):
|
||||||
return 'templates/output/textbox.html'
|
return 'textbox'
|
||||||
|
|
||||||
def postprocess(self, prediction):
|
def postprocess(self, prediction):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
|
|
||||||
class Image(AbstractOutput):
|
class Image(AbstractOutput):
|
||||||
|
|
||||||
def get_template_path(self):
|
def get_name(self):
|
||||||
return 'templates/output/image.html'
|
return 'image'
|
||||||
|
|
||||||
def postprocess(self, prediction):
|
def postprocess(self, prediction):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user