mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
refactored inputs.py to support js and template rendering
This commit is contained in:
parent
61e9c01219
commit
e18646af15
@ -34,11 +34,11 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 101,
|
||||
"execution_count": 121,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299, num_channels=3)\n",
|
||||
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299, num_channels=3, aspect_ratio=1.0)\n",
|
||||
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=8, num_top_classes=5)\n",
|
||||
"\n",
|
||||
"io = gradio.Interface(inputs=inp, \n",
|
||||
@ -49,7 +49,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 103,
|
||||
"execution_count": 123,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
@ -58,9 +58,8 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Closing existing server...\n",
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
|
||||
"Model is running locally at: http://localhost:7882/interface.html\n",
|
||||
"Model is running locally at: http://localhost:7895/interface.html\n",
|
||||
"To create a public link, set `share=True` in the argument to `launch()`\n"
|
||||
]
|
||||
}
|
||||
|
@ -11,6 +11,10 @@ from io import BytesIO
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'templates/input/{}.html'
|
||||
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
|
||||
|
||||
|
||||
class AbstractInput(ABC):
|
||||
"""
|
||||
@ -29,12 +33,28 @@ class AbstractInput(ABC):
|
||||
super().__init__()
|
||||
|
||||
def get_validation_inputs(self):
|
||||
"""
|
||||
An interface can optionally implement a method that returns a list of examples inputs that it should be able to
|
||||
accept and preprocess for validation purposes.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_template_path(self):
|
||||
def get_js_context(self):
|
||||
"""
|
||||
All interfaces should define a method that returns the path to its template.
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_template_context(self):
|
||||
"""
|
||||
:return: a dictionary with context variables for the javascript file associated with the context
|
||||
"""
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
"""
|
||||
All interfaces should define a method that returns a name used for identifying the related static resources.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -54,8 +74,8 @@ class Sketchpad(AbstractInput):
|
||||
self.invert_colors = invert_colors
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/sketchpad.html'
|
||||
def get_name(self):
|
||||
return 'sketchpad'
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
@ -81,8 +101,8 @@ class Webcam(AbstractInput):
|
||||
def get_validation_inputs(self):
|
||||
return validation_data.BASE64_COLOR_IMAGES
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/webcam.html'
|
||||
def get_name(self):
|
||||
return 'webcam'
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
@ -100,8 +120,8 @@ class Textbox(AbstractInput):
|
||||
def get_validation_inputs(self):
|
||||
return validation_data.ENGLISH_TEXTS
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/textbox.html'
|
||||
def get_name(self):
|
||||
return 'textbox'
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
@ -112,20 +132,24 @@ class Textbox(AbstractInput):
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB',
|
||||
scale = 1/127.5, shift = -1):
|
||||
scale=1/127.5, shift=-1, aspect_ratio="false"):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
self.image_mode = image_mode
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
self.aspect_ratio = aspect_ratio
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
return validation_data.BASE64_COLOR_IMAGES
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/image_upload.html'
|
||||
def get_name(self):
|
||||
return 'image_upload'
|
||||
|
||||
def get_js_context(self):
|
||||
return {'aspect_ratio': self.aspect_ratio}
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
@ -146,12 +170,13 @@ class ImageUpload(AbstractInput):
|
||||
|
||||
class CSV(AbstractInput):
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/csv.html'
|
||||
def get_name(self):
|
||||
# return 'templates/input/csv.html'
|
||||
return 'csv'
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
By default, no pre-processing is applied to text.
|
||||
By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this)
|
||||
"""
|
||||
return inp
|
||||
|
||||
|
@ -19,6 +19,7 @@ from requests.packages.urllib3.util.retry import Retry
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
from distutils import dir_util
|
||||
from gradio import inputs, outputs
|
||||
|
||||
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
|
||||
@ -49,11 +50,14 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
:param input_interface: an AbstractInput object which includes is used to get the input template
|
||||
:param output_interface: an AbstractInput object which includes is used to get the input template
|
||||
"""
|
||||
input_template_path = pkg_resources.resource_filename('gradio', input_interface.get_template_path())
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
|
||||
output_template_path = pkg_resources.resource_filename('gradio', output_interface.get_template_path())
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(input_page.read(), features="html.parser")
|
||||
input_soup = BeautifulSoup(render_string_or_list_with_tags(input_page.read(),
|
||||
input_interface.get_template_context()),
|
||||
features="html.parser")
|
||||
output_soup = BeautifulSoup(output_page.read(), features="html.parser")
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
@ -68,6 +72,9 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
f.write(str(all_io_soup))
|
||||
|
||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
render_template_with_tags(os.path.join(temp_dir,
|
||||
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
||||
input_interface.get_js_context())
|
||||
|
||||
|
||||
def copy_files(src_dir, dest_dir):
|
||||
@ -88,16 +95,28 @@ def render_template_with_tags(template_path, context):
|
||||
"""
|
||||
with open(template_path) as fin:
|
||||
old_lines = fin.readlines()
|
||||
new_lines = []
|
||||
for line in old_lines:
|
||||
for key, value in context.items():
|
||||
line = line.replace(r'{{' + key + r'}}', value)
|
||||
new_lines.append(line)
|
||||
new_lines = render_string_or_list_with_tags(old_lines, context)
|
||||
with open(template_path, 'w') as fout:
|
||||
for line in new_lines:
|
||||
fout.write(line)
|
||||
|
||||
|
||||
def render_string_or_list_with_tags(old_lines, context):
|
||||
# Handle string case
|
||||
if isinstance(old_lines, str):
|
||||
for key, value in context.items():
|
||||
old_lines = old_lines.replace(r'{{' + key + r'}}', str(value))
|
||||
return old_lines
|
||||
|
||||
# Handle list case
|
||||
new_lines = []
|
||||
for line in old_lines:
|
||||
for key, value in context.items():
|
||||
line = line.replace(r'{{' + key + r'}}', str(value))
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
|
||||
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
|
||||
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
|
||||
|
@ -1,4 +1,5 @@
|
||||
var cropper;
|
||||
var aspectRatio = "{{aspect_ratio}}"
|
||||
|
||||
$('body').on('click', ".input_image.drop_mode", function (e) {
|
||||
$(this).parent().find(".hidden_upload").click();
|
||||
@ -18,7 +19,7 @@ function loadPreviewFromFiles(files) {
|
||||
var image = $(".input_image img")
|
||||
image.attr("src", this.result)
|
||||
image.cropper({
|
||||
aspectRatio : 1.0,
|
||||
aspectRatio : aspectRatio,
|
||||
background: false
|
||||
});
|
||||
if (!cropper) {
|
||||
|
@ -10,8 +10,8 @@ PACKAGE_NAME = 'gradio'
|
||||
class TestSketchpad(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Sketchpad()
|
||||
path = inp.get_template_path()
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Sketchpad()
|
||||
@ -22,8 +22,8 @@ class TestSketchpad(unittest.TestCase):
|
||||
class TestWebcam(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Webcam()
|
||||
path = inp.get_template_path()
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
|
||||
self.assertFalse(os.path.exists(os.path.join(PACKAGE_NAME, path))) # Note implemented yet.
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Webcam()
|
||||
@ -34,8 +34,8 @@ class TestWebcam(unittest.TestCase):
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Textbox()
|
||||
path = inp.get_template_path()
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Textbox()
|
||||
@ -46,7 +46,7 @@ class TestTextbox(unittest.TestCase):
|
||||
class TestImageUpload(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.ImageUpload()
|
||||
path = inp.get_template_path()
|
||||
path = inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(inp.get_name())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
|
Loading…
Reference in New Issue
Block a user