refactored inputs.py to support js and template rendering

This commit is contained in:
Abubakar Abid 2019-03-18 05:24:25 -07:00
parent 61e9c01219
commit e18646af15
5 changed files with 79 additions and 35 deletions

View File

@ -34,11 +34,11 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299, num_channels=3)\n",
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299, num_channels=3, aspect_ratio=1.0)\n",
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=8, num_top_classes=5)\n",
"\n",
"io = gradio.Interface(inputs=inp, \n",
@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 123,
"metadata": {
"scrolled": false
},
@ -58,9 +58,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Closing existing server...\n",
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7882/interface.html\n",
"Model is running locally at: http://localhost:7895/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
}

View File

@ -11,6 +11,10 @@ from io import BytesIO
import numpy as np
from PIL import Image, ImageOps
# Where to find the static resources associated with each template.
BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'templates/input/{}.html'
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
class AbstractInput(ABC):
"""
@ -29,12 +33,28 @@ class AbstractInput(ABC):
super().__init__()
def get_validation_inputs(self):
"""
An interface can optionally implement a method that returns a list of examples inputs that it should be able to
accept and preprocess for validation purposes.
"""
return []
@abstractmethod
def get_template_path(self):
def get_js_context(self):
"""
All interfaces should define a method that returns the path to its template.
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {}
@abstractmethod
def get_name(self):
"""
All interfaces should define a method that returns a name used for identifying the related static resources.
"""
pass
@ -54,8 +74,8 @@ class Sketchpad(AbstractInput):
self.invert_colors = invert_colors
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/input/sketchpad.html'
def get_name(self):
return 'sketchpad'
def preprocess(self, inp):
"""
@ -81,8 +101,8 @@ class Webcam(AbstractInput):
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_template_path(self):
return 'templates/input/webcam.html'
def get_name(self):
return 'webcam'
def preprocess(self, inp):
"""
@ -100,8 +120,8 @@ class Textbox(AbstractInput):
def get_validation_inputs(self):
return validation_data.ENGLISH_TEXTS
def get_template_path(self):
return 'templates/input/textbox.html'
def get_name(self):
return 'textbox'
def preprocess(self, inp):
"""
@ -112,20 +132,24 @@ class Textbox(AbstractInput):
class ImageUpload(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB',
scale = 1/127.5, shift = -1):
scale=1/127.5, shift=-1, aspect_ratio="false"):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
self.image_mode = image_mode
self.scale = scale
self.shift = shift
self.aspect_ratio = aspect_ratio
super().__init__(preprocessing_fn=preprocessing_fn)
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_template_path(self):
return 'templates/input/image_upload.html'
def get_name(self):
return 'image_upload'
def get_js_context(self):
return {'aspect_ratio': self.aspect_ratio}
def preprocess(self, inp):
"""
@ -146,12 +170,13 @@ class ImageUpload(AbstractInput):
class CSV(AbstractInput):
def get_template_path(self):
return 'templates/input/csv.html'
def get_name(self):
# return 'templates/input/csv.html'
return 'csv'
def preprocess(self, inp):
"""
By default, no pre-processing is applied to text.
By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this)
"""
return inp

View File

@ -19,6 +19,7 @@ from requests.packages.urllib3.util.retry import Retry
import pkg_resources
from bs4 import BeautifulSoup
from distutils import dir_util
from gradio import inputs, outputs
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
@ -49,11 +50,14 @@ def build_template(temp_dir, input_interface, output_interface):
:param input_interface: an AbstractInput object which includes is used to get the input template
:param output_interface: an AbstractInput object which includes is used to get the input template
"""
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
input_template_path = pkg_resources.resource_filename(
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
input_page = open(input_template_path)
output_page = open(output_template_path)
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
input_soup = BeautifulSoup(render_string_or_list_with_tags(input_page.read(),
input_interface.get_template_context()),
features="html.parser")
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
all_io_page = open(BASE_TEMPLATE)
@ -68,6 +72,9 @@ def build_template(temp_dir, input_interface, output_interface):
f.write(str(all_io_soup))
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
render_template_with_tags(os.path.join(temp_dir,
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
input_interface.get_js_context())
def copy_files(src_dir, dest_dir):
@ -88,16 +95,28 @@ def render_template_with_tags(template_path, context):
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', value)
new_lines.append(line)
new_lines = render_string_or_list_with_tags(old_lines, context)
with open(template_path, 'w') as fout:
for line in new_lines:
fout.write(line)
def render_string_or_list_with_tags(old_lines, context):
# Handle string case
if isinstance(old_lines, str):
for key, value in context.items():
old_lines = old_lines.replace(r'{{' + key + r'}}', str(value))
return old_lines
# Handle list case
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r'{{' + key + r'}}', str(value))
new_lines.append(line)
return new_lines
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')

View File

@ -1,4 +1,5 @@
var cropper;
var aspectRatio = "{{aspect_ratio}}"
$('body').on('click', ".input_image.drop_mode", function (e) {
$(this).parent().find(".hidden_upload").click();
@ -18,7 +19,7 @@ function loadPreviewFromFiles(files) {
var image = $(".input_image img")
image.attr("src", this.result)
image.cropper({
aspectRatio : 1.0,
aspectRatio : aspectRatio,
background: false
});
if (!cropper) {

View File

@ -10,8 +10,8 @@ PACKAGE_NAME = 'gradio'
class TestSketchpad(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Sketchpad()
path = inp.get_template_path()
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Sketchpad()
@ -22,8 +22,8 @@ class TestSketchpad(unittest.TestCase):
class TestWebcam(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Webcam()
path = inp.get_template_path()
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
self.assertFalse(os.path.exists(os.path.join(PACKAGE_NAME, path))) # Note implemented yet.
def test_preprocessing(self):
inp = inputs.Webcam()
@ -34,8 +34,8 @@ class TestWebcam(unittest.TestCase):
class TestTextbox(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Textbox()
path = inp.get_template_path()
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Textbox()
@ -46,7 +46,7 @@ class TestTextbox(unittest.TestCase):
class TestImageUpload(unittest.TestCase):
def test_path_exists(self):
inp = inputs.ImageUpload()
path = inp.get_template_path()
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):