modularized inputs/outputs

This commit is contained in:
Abubakar Abid 2019-02-17 19:36:10 -08:00
parent 65f1bbbd34
commit 18b773f2ed
5 changed files with 138 additions and 75 deletions

100
.idea/workspace.xml generated
View File

@ -2,11 +2,11 @@
<project version="4">
<component name="ChangeListManager">
<list default="true" id="fd73cd66-e80f-470e-a2ec-e220d3b6b864" name="Default Changelist" comment="">
<change afterPath="$PROJECT_DIR$/outputs.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/Test Notebook.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/Test Notebook.ipynb" afterDir="false" />
<change beforePath="$PROJECT_DIR$/gradio.py" beforeDir="false" afterPath="$PROJECT_DIR$/gradio.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/inputs.py" beforeDir="false" afterPath="$PROJECT_DIR$/inputs.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/outputs.py" beforeDir="false" afterPath="$PROJECT_DIR$/outputs.py" afterDir="false" />
</list>
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
<option name="SHOW_DIALOG" value="false" />
@ -90,13 +90,13 @@
<counts>
<entry key="dummy" value="10" />
<entry key="gitignore" value="2" />
<entry key="py" value="2014" />
<entry key="py" value="2099" />
</counts>
</usages-collector>
<usages-collector id="statistics.file.types.edit">
<counts>
<entry key="PLAIN_TEXT" value="12" />
<entry key="Python" value="2014" />
<entry key="Python" value="2099" />
</counts>
</usages-collector>
</session>
@ -145,8 +145,8 @@
<file pinned="false" current-in-tab="true">
<entry file="file://$PROJECT_DIR$/gradio.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="491">
<caret line="20" column="60" selection-start-line="20" selection-start-column="60" selection-end-line="20" selection-end-column="60" />
<state relative-caret-position="266">
<caret line="50" column="22" lean-forward="true" selection-start-line="50" selection-start-column="22" selection-end-line="50" selection-end-column="22" />
<folding>
<element signature="e#0#14#0" expanded="true" />
</folding>
@ -157,11 +157,8 @@
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/inputs.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="275">
<caret line="11" selection-start-line="11" selection-end-line="11" />
<folding>
<element signature="e#0#35#0" expanded="true" />
</folding>
<state relative-caret-position="50">
<caret line="2" column="21" lean-forward="true" selection-start-line="2" selection-start-column="21" selection-end-line="2" selection-end-column="21" />
</state>
</provider>
</entry>
@ -169,8 +166,8 @@
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/outputs.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="175">
<caret line="7" selection-start-line="6" selection-end-line="7" />
<state relative-caret-position="250">
<caret line="10" column="10" selection-start-line="10" selection-start-column="10" selection-end-line="10" selection-end-column="10" />
</state>
</provider>
</entry>
@ -203,8 +200,8 @@
<file pinned="false" current-in-tab="false">
<entry file="file://$PROJECT_DIR$/preprocessing_utils.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="75">
<caret line="3" column="19" selection-start-line="3" selection-start-column="4" selection-end-line="3" selection-end-column="19" />
<state relative-caret-position="966">
<caret line="53" selection-start-line="3" selection-end-line="53" />
</state>
</provider>
</entry>
@ -287,9 +284,9 @@
<option value="$PROJECT_DIR$/preprocessing_utils.py" />
<option value="$PROJECT_DIR$/networking.py" />
<option value="$PROJECT_DIR$/.gitignore" />
<option value="$PROJECT_DIR$/gradio.py" />
<option value="$PROJECT_DIR$/outputs.py" />
<option value="$PROJECT_DIR$/inputs.py" />
<option value="$PROJECT_DIR$/outputs.py" />
<option value="$PROJECT_DIR$/gradio.py" />
</list>
</option>
</component>
@ -478,34 +475,6 @@
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../khateebi/templates/index.html">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="925">
<caret line="37" column="32" selection-start-line="37" selection-start-column="31" selection-end-line="37" selection-end-column="41" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../khateebi/settings.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="150">
<caret line="6" column="38" selection-start-line="6" selection-start-column="38" selection-end-line="6" selection-end-column="38" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../khateebi/run.sh">
<provider selected="true" editor-type-id="text-editor">
<state>
<caret column="11" selection-start-column="11" selection-end-column="11" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/preprocessing_utils.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="75">
<caret line="3" column="19" selection-start-line="3" selection-start-column="4" selection-end-line="3" selection-end-column="19" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/networking.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="450">
@ -527,27 +496,52 @@
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../khateebi/run.sh">
<provider selected="true" editor-type-id="text-editor">
<state>
<caret column="11" selection-start-column="11" selection-end-column="11" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../khateebi/settings.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="150">
<caret line="6" column="38" selection-start-line="6" selection-start-column="38" selection-end-line="6" selection-end-column="38" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/../khateebi/templates/index.html">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="925">
<caret line="37" column="32" selection-start-line="37" selection-start-column="31" selection-end-line="37" selection-end-column="41" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/preprocessing_utils.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="966">
<caret line="53" selection-start-line="3" selection-end-line="53" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/outputs.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="175">
<caret line="7" selection-start-line="6" selection-end-line="7" />
<state relative-caret-position="250">
<caret line="10" column="10" selection-start-line="10" selection-start-column="10" selection-end-line="10" selection-end-column="10" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/inputs.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="275">
<caret line="11" selection-start-line="11" selection-end-line="11" />
<folding>
<element signature="e#0#35#0" expanded="true" />
</folding>
<state relative-caret-position="50">
<caret line="2" column="21" lean-forward="true" selection-start-line="2" selection-start-column="21" selection-end-line="2" selection-end-column="21" />
</state>
</provider>
</entry>
<entry file="file://$PROJECT_DIR$/gradio.py">
<provider selected="true" editor-type-id="text-editor">
<state relative-caret-position="491">
<caret line="20" column="60" selection-start-line="20" selection-start-column="60" selection-end-line="20" selection-end-column="60" />
<state relative-caret-position="266">
<caret line="50" column="22" lean-forward="true" selection-start-line="50" selection-start-column="22" selection-end-line="50" selection-end-column="22" />
<folding>
<element signature="e#0#14#0" expanded="true" />
</folding>

View File

@ -43,8 +43,85 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"import numpy as np\n",
"\n",
"def resize_and_crop(img, size, crop_type='top'):\n",
" \"\"\"\n",
" Resize and crop an image to fit the specified size.\n",
" args:\n",
" img_path: path for the image to resize.\n",
" modified_path: path to store the modified image.\n",
" size: `(width, height)` tuple.\n",
" crop_type: can be 'top', 'middle' or 'bottom', depending on this\n",
" value, the image will cropped getting the 'top/left', 'midle' or\n",
" 'bottom/rigth' of the image to fit the size.\n",
" raises:\n",
" Exception: if can not open the file in img_path of there is problems\n",
" to save the image.\n",
" ValueError: if an invalid `crop_type` is provided.\n",
" \"\"\"\n",
" # Get current and desired ratio for the images\n",
" img_ratio = img.size[0] / float(img.size[1])\n",
" ratio = size[0] / float(size[1])\n",
" # The image is scaled/cropped vertically or horizontally depending on the ratio\n",
" if ratio > img_ratio:\n",
" img = img.resize((size[0], size[0] * img.size[1] / img.size[0]),\n",
" Image.ANTIALIAS)\n",
" # Crop in the top, middle or bottom\n",
" if crop_type == 'top':\n",
" box = (0, 0, img.size[0], size[1])\n",
" elif crop_type == 'middle':\n",
" box = (0, (img.size[1] - size[1]) / 2, img.size[0], (img.size[1] + size[1]) / 2)\n",
" elif crop_type == 'bottom':\n",
" box = (0, img.size[1] - size[1], img.size[0], img.size[1])\n",
" else:\n",
" raise ValueError('ERROR: invalid value for crop_type')\n",
" img = img.crop(box)\n",
" elif ratio < img_ratio:\n",
" img = img.resize((size[1] * img.size[0] / img.size[1], size[1]),\n",
" Image.ANTIALIAS)\n",
" # Crop in the top, middle or bottom\n",
" if crop_type == 'top':\n",
" box = (0, 0, size[0], img.size[1])\n",
" elif crop_type == 'middle':\n",
" box = ((img.size[0] - size[0]) / 2, 0, (img.size[0] + size[0]) / 2, img.size[1])\n",
" elif crop_type == 'bottom':\n",
" box = (img.size[0] - size[0], 0, img.size[0], img.size[1])\n",
" else:\n",
" raise ValueError('ERROR: invalid value for crop_type')\n",
" img = img.crop(box)\n",
" else:\n",
" img = img.resize((size[0], size[1]),\n",
" Image.ANTIALIAS)\n",
" # If the scale is the same, we do not need to crop\n",
" return img\n",
"\n",
"\n",
"def _pre_process(imgstring):\n",
" \"\"\"\n",
" \"\"\"\n",
" content = imgstring.split(';')[1]\n",
" image_encoded = content.split(',')[1]\n",
" body = base64.decodebytes(image_encoded.encode('utf-8'))\n",
" im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')\n",
" im = resize_and_crop(im, (28, 28))\n",
" array = np.array(im).flatten().reshape(1, 28, 28, 1)\n",
" return array"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
@ -56,7 +133,7 @@
}
],
"source": [
"iface = gradio.Interface(input='sketchpad',output='class',model_obj=model, model_type='keras')\n",
"iface = gradio.Interface(input='sketchpad', output='class', model=model, model_type='keras', preprocessing_fn=_pre_process)\n",
"iface.launch()"
]
}

View File

@ -18,18 +18,16 @@ class Interface():
"""
build_template_path = 'templates/tmp_html.html'
def __init__(self, input, output, model_obj, model_type, preprocessing_fn=None, postprocessing_fn=None):
def __init__(self, input, output, model, model_type, preprocessing_fn=None, postprocessing_fn=None):
"""
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
:param model_obj: the model object, such as a sklearn classifier or keras model.
:param model_params: additional model parameters.
"""
self.input_interface = inputs.registry[input]()
self.output_interface = outputs.registry[output]()
self.input_interface = inputs.registry[input](preprocessing_fn)
self.output_interface = outputs.registry[output](postprocessing_fn)
self.model_type = model_type
self.model_obj = model_obj
self.preprocessing_fn = preprocessing_fn
self.postprocessing_fn = postprocessing_fn
self.model_obj = model
def _build_template(self):
input_template_path = self.input_interface._get_template_path()
@ -71,19 +69,9 @@ class Interface():
while True:
try:
msg = await websocket.recv()
if self.preprocessing_fn is None:
processed_input = self.input_interface._pre_process(await websocket.recv())
else:
processed_input = self.preprocessing_fn(await websocket.recv())
processed_input = self.input_interface._pre_process(msg)
prediction = self.predict(processed_input)
if self.postprocessing_fn is None:
processed_output = self.output_interface._post_process(prediction)
else:
processed_output = self.postprocessing_fn(prediction)
processed_output = self.output_interface._post_process(prediction)
await websocket.send(str(processed_output))
except websockets.exceptions.ConnectionClosed:
pass

View File

@ -11,7 +11,9 @@ class AbstractInput(ABC):
When this is subclassed, it is automatically added to the registry
"""
def __init__(self):
def __init__(self, preprocessing_fn=None):
if preprocessing_fn is not None:
self._pre_process = preprocessing_fn
super().__init__()
@abstractmethod

View File

@ -7,9 +7,11 @@ class AbstractOutput(ABC):
When this is subclassed, it is automatically added to the registry
"""
def __init__(self):
def __init__(self, postprocessing_fn=None):
"""
"""
if postprocessing_fn is not None:
self._post_process = postprocessing_fn
super().__init__()
@abstractmethod