added validation to interfaces

This commit is contained in:
Abubakar Abid 2019-03-10 17:24:03 -07:00
parent b33ad40e60
commit b87ac1c071
4 changed files with 57 additions and 41 deletions

View File

@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@ -16,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -25,12 +34,12 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299, num_channels=3)\n",
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=None, num_top_classes=8)\n",
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=12, num_top_classes=5)\n",
"\n",
"iface = gradio.Interface(inputs=inp, \n",
" outputs=out,\n",
@ -45,14 +54,16 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validating samples: 6 out of 6\r"
"Validating samples: 6/6 [======]\n",
"\n",
"Validation passed successfully!\n"
]
}
],
@ -62,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 46,
"metadata": {
"scrolled": false
},
@ -72,29 +83,8 @@
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7863/interface.html\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Unexpected exception in keepalive ping task\n",
"Traceback (most recent call last):\n",
" File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\websockets\\protocol.py\", line 984, in keepalive_ping\n",
" ping_waiter = yield from self.ping()\n",
" File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\websockets\\protocol.py\", line 583, in ping\n",
" yield from self.ensure_open()\n",
" File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\websockets\\protocol.py\", line 658, in ensure_open\n",
" ) from self.transfer_data_exc\n",
"websockets.exceptions.ConnectionClosed: WebSocket connection is closed: code = 1001 (going away), no reason\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model available publicly for 8 hours at: http://03244fc8.ngrok.io/interface.html\n"
"Model is running locally at: http://localhost:7866/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
},
{
@ -104,14 +94,14 @@
" <iframe\n",
" width=\"1000\"\n",
" height=\"500\"\n",
" src=\"http://localhost:7863/interface.html\"\n",
" src=\"http://localhost:7866/interface.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x28b48b00c50>"
"<IPython.lib.display.IFrame at 0x1b08a51f748>"
]
},
"metadata": {},
@ -119,7 +109,7 @@
}
],
"source": [
"iface.launch(inline=True, browser=False, share=True);"
"iface.launch(inline=True, browser=False, share=False);"
]
}
],

View File

@ -28,6 +28,9 @@ class AbstractInput(ABC):
self.preprocess = preprocessing_fn
super().__init__()
def get_validation_inputs(self):
return []
@abstractmethod
def get_template_path(self):
"""
@ -75,6 +78,9 @@ class Webcam(AbstractInput):
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_template_path(self):
return 'templates/input/webcam.html'
@ -91,6 +97,8 @@ class Webcam(AbstractInput):
class Textbox(AbstractInput):
def get_validation_inputs(self):
return validation_data.ENGLISH_TEXTS
def get_template_path(self):
return 'templates/input/textbox.html'
@ -113,8 +121,7 @@ class ImageUpload(AbstractInput):
self.shift = shift
super().__init__(preprocessing_fn=preprocessing_fn)
@staticmethod
def get_validation_inputs():
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
def get_template_path(self):

View File

@ -138,18 +138,26 @@ class Interface:
def validate(self):
if self.validate_flag:
if self.verbose:
print("Interface already validated")
return
validation_inputs = self.input_interface.get_validation_inputs()
n = len(validation_inputs)
if n == 0:
self.validate_flag = True
if self.verbose:
print("No validation samples for this interface... skipping validation.")
return
for m, msg in enumerate(validation_inputs):
if self.verbose:
print(f"Validating samples: {m+1} out of {len(validation_inputs)}", end='\r')
print(f"Validating samples: {m+1}/{n} [" + "="*(m+1) + "."*(n-m-1) + "]", end='\r')
try:
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
except Exception as e:
if self.verbose:
print("\n----------")
print("Validation failed, likely due to invalid pre-processing. See below:\n")
print("Validation failed, likely due to incompatible pre-processing and model input. See below:\n")
print(traceback.format_exc())
break
try:
@ -157,11 +165,14 @@ class Interface:
except Exception as e:
if self.verbose:
print("\n----------")
print("Validation failed, likely due to invalid post-processing. See below:\n")
print("Validation failed, likely due to incompatible model output and post-processing."
"See below:\n")
print(traceback.format_exc())
break
else: # This means if a break was not explicitly called
self.validate_flag = True
if self.verbose:
print("\n\nValidation passed successfully!")
return
raise RuntimeError("Validation did not pass")
@ -170,8 +181,8 @@ class Interface:
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
:param share: boolean. If True, then a share link is generated using ngrok is displayed to the user.
"""
# if validate and not(self.validate_flag):
# self.validate()
if validate and not self.validate_flag:
self.validate()
self.launch_flag = True
output_directory = tempfile.mkdtemp()

File diff suppressed because one or more lines are too long