mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
modularized inputs/outputs
This commit is contained in:
parent
65f1bbbd34
commit
18b773f2ed
100
.idea/workspace.xml
generated
100
.idea/workspace.xml
generated
@ -2,11 +2,11 @@
|
||||
<project version="4">
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="fd73cd66-e80f-470e-a2ec-e220d3b6b864" name="Default Changelist" comment="">
|
||||
<change afterPath="$PROJECT_DIR$/outputs.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/Test Notebook.ipynb" beforeDir="false" afterPath="$PROJECT_DIR$/Test Notebook.ipynb" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/gradio.py" beforeDir="false" afterPath="$PROJECT_DIR$/gradio.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/inputs.py" beforeDir="false" afterPath="$PROJECT_DIR$/inputs.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/outputs.py" beforeDir="false" afterPath="$PROJECT_DIR$/outputs.py" afterDir="false" />
|
||||
</list>
|
||||
<option name="EXCLUDED_CONVERTED_TO_IGNORED" value="true" />
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
@ -90,13 +90,13 @@
|
||||
<counts>
|
||||
<entry key="dummy" value="10" />
|
||||
<entry key="gitignore" value="2" />
|
||||
<entry key="py" value="2014" />
|
||||
<entry key="py" value="2099" />
|
||||
</counts>
|
||||
</usages-collector>
|
||||
<usages-collector id="statistics.file.types.edit">
|
||||
<counts>
|
||||
<entry key="PLAIN_TEXT" value="12" />
|
||||
<entry key="Python" value="2014" />
|
||||
<entry key="Python" value="2099" />
|
||||
</counts>
|
||||
</usages-collector>
|
||||
</session>
|
||||
@ -145,8 +145,8 @@
|
||||
<file pinned="false" current-in-tab="true">
|
||||
<entry file="file://$PROJECT_DIR$/gradio.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="491">
|
||||
<caret line="20" column="60" selection-start-line="20" selection-start-column="60" selection-end-line="20" selection-end-column="60" />
|
||||
<state relative-caret-position="266">
|
||||
<caret line="50" column="22" lean-forward="true" selection-start-line="50" selection-start-column="22" selection-end-line="50" selection-end-column="22" />
|
||||
<folding>
|
||||
<element signature="e#0#14#0" expanded="true" />
|
||||
</folding>
|
||||
@ -157,11 +157,8 @@
|
||||
<file pinned="false" current-in-tab="false">
|
||||
<entry file="file://$PROJECT_DIR$/inputs.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="275">
|
||||
<caret line="11" selection-start-line="11" selection-end-line="11" />
|
||||
<folding>
|
||||
<element signature="e#0#35#0" expanded="true" />
|
||||
</folding>
|
||||
<state relative-caret-position="50">
|
||||
<caret line="2" column="21" lean-forward="true" selection-start-line="2" selection-start-column="21" selection-end-line="2" selection-end-column="21" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
@ -169,8 +166,8 @@
|
||||
<file pinned="false" current-in-tab="false">
|
||||
<entry file="file://$PROJECT_DIR$/outputs.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="175">
|
||||
<caret line="7" selection-start-line="6" selection-end-line="7" />
|
||||
<state relative-caret-position="250">
|
||||
<caret line="10" column="10" selection-start-line="10" selection-start-column="10" selection-end-line="10" selection-end-column="10" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
@ -203,8 +200,8 @@
|
||||
<file pinned="false" current-in-tab="false">
|
||||
<entry file="file://$PROJECT_DIR$/preprocessing_utils.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="75">
|
||||
<caret line="3" column="19" selection-start-line="3" selection-start-column="4" selection-end-line="3" selection-end-column="19" />
|
||||
<state relative-caret-position="966">
|
||||
<caret line="53" selection-start-line="3" selection-end-line="53" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
@ -287,9 +284,9 @@
|
||||
<option value="$PROJECT_DIR$/preprocessing_utils.py" />
|
||||
<option value="$PROJECT_DIR$/networking.py" />
|
||||
<option value="$PROJECT_DIR$/.gitignore" />
|
||||
<option value="$PROJECT_DIR$/gradio.py" />
|
||||
<option value="$PROJECT_DIR$/outputs.py" />
|
||||
<option value="$PROJECT_DIR$/inputs.py" />
|
||||
<option value="$PROJECT_DIR$/outputs.py" />
|
||||
<option value="$PROJECT_DIR$/gradio.py" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
@ -478,34 +475,6 @@
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/../khateebi/templates/index.html">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="925">
|
||||
<caret line="37" column="32" selection-start-line="37" selection-start-column="31" selection-end-line="37" selection-end-column="41" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/../khateebi/settings.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="150">
|
||||
<caret line="6" column="38" selection-start-line="6" selection-start-column="38" selection-end-line="6" selection-end-column="38" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/../khateebi/run.sh">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state>
|
||||
<caret column="11" selection-start-column="11" selection-end-column="11" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/preprocessing_utils.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="75">
|
||||
<caret line="3" column="19" selection-start-line="3" selection-start-column="4" selection-end-line="3" selection-end-column="19" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/networking.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="450">
|
||||
@ -527,27 +496,52 @@
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/../khateebi/run.sh">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state>
|
||||
<caret column="11" selection-start-column="11" selection-end-column="11" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/../khateebi/settings.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="150">
|
||||
<caret line="6" column="38" selection-start-line="6" selection-start-column="38" selection-end-line="6" selection-end-column="38" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/../khateebi/templates/index.html">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="925">
|
||||
<caret line="37" column="32" selection-start-line="37" selection-start-column="31" selection-end-line="37" selection-end-column="41" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/preprocessing_utils.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="966">
|
||||
<caret line="53" selection-start-line="3" selection-end-line="53" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/outputs.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="175">
|
||||
<caret line="7" selection-start-line="6" selection-end-line="7" />
|
||||
<state relative-caret-position="250">
|
||||
<caret line="10" column="10" selection-start-line="10" selection-start-column="10" selection-end-line="10" selection-end-column="10" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/inputs.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="275">
|
||||
<caret line="11" selection-start-line="11" selection-end-line="11" />
|
||||
<folding>
|
||||
<element signature="e#0#35#0" expanded="true" />
|
||||
</folding>
|
||||
<state relative-caret-position="50">
|
||||
<caret line="2" column="21" lean-forward="true" selection-start-line="2" selection-start-column="21" selection-end-line="2" selection-end-column="21" />
|
||||
</state>
|
||||
</provider>
|
||||
</entry>
|
||||
<entry file="file://$PROJECT_DIR$/gradio.py">
|
||||
<provider selected="true" editor-type-id="text-editor">
|
||||
<state relative-caret-position="491">
|
||||
<caret line="20" column="60" selection-start-line="20" selection-start-column="60" selection-end-line="20" selection-end-column="60" />
|
||||
<state relative-caret-position="266">
|
||||
<caret line="50" column="22" lean-forward="true" selection-start-line="50" selection-start-column="22" selection-end-line="50" selection-end-column="22" />
|
||||
<folding>
|
||||
<element signature="e#0#14#0" expanded="true" />
|
||||
</folding>
|
||||
|
@ -43,8 +43,85 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import base64\n",
|
||||
"from PIL import Image\n",
|
||||
"from io import BytesIO\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"def resize_and_crop(img, size, crop_type='top'):\n",
|
||||
" \"\"\"\n",
|
||||
" Resize and crop an image to fit the specified size.\n",
|
||||
" args:\n",
|
||||
" img_path: path for the image to resize.\n",
|
||||
" modified_path: path to store the modified image.\n",
|
||||
" size: `(width, height)` tuple.\n",
|
||||
" crop_type: can be 'top', 'middle' or 'bottom', depending on this\n",
|
||||
" value, the image will cropped getting the 'top/left', 'midle' or\n",
|
||||
" 'bottom/rigth' of the image to fit the size.\n",
|
||||
" raises:\n",
|
||||
" Exception: if can not open the file in img_path of there is problems\n",
|
||||
" to save the image.\n",
|
||||
" ValueError: if an invalid `crop_type` is provided.\n",
|
||||
" \"\"\"\n",
|
||||
" # Get current and desired ratio for the images\n",
|
||||
" img_ratio = img.size[0] / float(img.size[1])\n",
|
||||
" ratio = size[0] / float(size[1])\n",
|
||||
" # The image is scaled/cropped vertically or horizontally depending on the ratio\n",
|
||||
" if ratio > img_ratio:\n",
|
||||
" img = img.resize((size[0], size[0] * img.size[1] / img.size[0]),\n",
|
||||
" Image.ANTIALIAS)\n",
|
||||
" # Crop in the top, middle or bottom\n",
|
||||
" if crop_type == 'top':\n",
|
||||
" box = (0, 0, img.size[0], size[1])\n",
|
||||
" elif crop_type == 'middle':\n",
|
||||
" box = (0, (img.size[1] - size[1]) / 2, img.size[0], (img.size[1] + size[1]) / 2)\n",
|
||||
" elif crop_type == 'bottom':\n",
|
||||
" box = (0, img.size[1] - size[1], img.size[0], img.size[1])\n",
|
||||
" else:\n",
|
||||
" raise ValueError('ERROR: invalid value for crop_type')\n",
|
||||
" img = img.crop(box)\n",
|
||||
" elif ratio < img_ratio:\n",
|
||||
" img = img.resize((size[1] * img.size[0] / img.size[1], size[1]),\n",
|
||||
" Image.ANTIALIAS)\n",
|
||||
" # Crop in the top, middle or bottom\n",
|
||||
" if crop_type == 'top':\n",
|
||||
" box = (0, 0, size[0], img.size[1])\n",
|
||||
" elif crop_type == 'middle':\n",
|
||||
" box = ((img.size[0] - size[0]) / 2, 0, (img.size[0] + size[0]) / 2, img.size[1])\n",
|
||||
" elif crop_type == 'bottom':\n",
|
||||
" box = (img.size[0] - size[0], 0, img.size[0], img.size[1])\n",
|
||||
" else:\n",
|
||||
" raise ValueError('ERROR: invalid value for crop_type')\n",
|
||||
" img = img.crop(box)\n",
|
||||
" else:\n",
|
||||
" img = img.resize((size[0], size[1]),\n",
|
||||
" Image.ANTIALIAS)\n",
|
||||
" # If the scale is the same, we do not need to crop\n",
|
||||
" return img\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _pre_process(imgstring):\n",
|
||||
" \"\"\"\n",
|
||||
" \"\"\"\n",
|
||||
" content = imgstring.split(';')[1]\n",
|
||||
" image_encoded = content.split(',')[1]\n",
|
||||
" body = base64.decodebytes(image_encoded.encode('utf-8'))\n",
|
||||
" im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')\n",
|
||||
" im = resize_and_crop(im, (28, 28))\n",
|
||||
" array = np.array(im).flatten().reshape(1, 28, 28, 1)\n",
|
||||
" return array"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -56,7 +133,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"iface = gradio.Interface(input='sketchpad',output='class',model_obj=model, model_type='keras')\n",
|
||||
"iface = gradio.Interface(input='sketchpad', output='class', model=model, model_type='keras', preprocessing_fn=_pre_process)\n",
|
||||
"iface.launch()"
|
||||
]
|
||||
}
|
||||
|
24
gradio.py
24
gradio.py
@ -18,18 +18,16 @@ class Interface():
|
||||
"""
|
||||
build_template_path = 'templates/tmp_html.html'
|
||||
|
||||
def __init__(self, input, output, model_obj, model_type, preprocessing_fn=None, postprocessing_fn=None):
|
||||
def __init__(self, input, output, model, model_type, preprocessing_fn=None, postprocessing_fn=None):
|
||||
"""
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn'.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
:param model_params: additional model parameters.
|
||||
"""
|
||||
self.input_interface = inputs.registry[input]()
|
||||
self.output_interface = outputs.registry[output]()
|
||||
self.input_interface = inputs.registry[input](preprocessing_fn)
|
||||
self.output_interface = outputs.registry[output](postprocessing_fn)
|
||||
self.model_type = model_type
|
||||
self.model_obj = model_obj
|
||||
self.preprocessing_fn = preprocessing_fn
|
||||
self.postprocessing_fn = postprocessing_fn
|
||||
self.model_obj = model
|
||||
|
||||
def _build_template(self):
|
||||
input_template_path = self.input_interface._get_template_path()
|
||||
@ -71,19 +69,9 @@ class Interface():
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
|
||||
if self.preprocessing_fn is None:
|
||||
processed_input = self.input_interface._pre_process(await websocket.recv())
|
||||
else:
|
||||
processed_input = self.preprocessing_fn(await websocket.recv())
|
||||
|
||||
processed_input = self.input_interface._pre_process(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
|
||||
if self.postprocessing_fn is None:
|
||||
processed_output = self.output_interface._post_process(prediction)
|
||||
else:
|
||||
processed_output = self.postprocessing_fn(prediction)
|
||||
|
||||
processed_output = self.output_interface._post_process(prediction)
|
||||
await websocket.send(str(processed_output))
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
|
@ -11,7 +11,9 @@ class AbstractInput(ABC):
|
||||
When this is subclassed, it is automatically added to the registry
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, preprocessing_fn=None):
|
||||
if preprocessing_fn is not None:
|
||||
self._pre_process = preprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
|
@ -7,9 +7,11 @@ class AbstractOutput(ABC):
|
||||
When this is subclassed, it is automatically added to the registry
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, postprocessing_fn=None):
|
||||
"""
|
||||
"""
|
||||
if postprocessing_fn is not None:
|
||||
self._post_process = postprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user