mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
added preprocessing so imagenet models work out of the box
This commit is contained in:
parent
2dc3e5a591
commit
d7bd7f9355
@ -19,100 +19,35 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
|
||||
"x_train, x_test = x_train / 255.0, x_test / 255.0"
|
||||
"model = tf.keras.applications.inception_v3.InceptionV3()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(60000, 28, 28)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_train.shape"
|
||||
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299)\n",
|
||||
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = tf.keras.models.Sequential([\n",
|
||||
" tf.keras.layers.Flatten(),\n",
|
||||
" tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
|
||||
" tf.keras.layers.Dropout(0.2),\n",
|
||||
" tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"model.compile(optimizer='adam',\n",
|
||||
" loss='sparse_categorical_crossentropy',\n",
|
||||
" metrics=['accuracy'])"
|
||||
"iface = gradio.Interface(inputs=inp, \n",
|
||||
" outputs=out,\n",
|
||||
" model=model, \n",
|
||||
" model_type='keras')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 1/1\n",
|
||||
"3/3 [==============================] - 6s 2s/step - loss: 2.1068 - acc: 0.3067\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tensorflow.python.keras.callbacks.History at 0x28f0f75c550>"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.fit(x_train, y_train, epochs=1, steps_per_epoch=3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inp = gradio.inputs.ImageUpload(image_width=28, image_height=28, num_channels=None, image_mode='L')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"iface = gradio.Interface(inputs=inp, outputs=\"label\", model=model, model_type='keras')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -120,7 +55,7 @@
|
||||
"text": [
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
|
||||
"Model is running locally at: http://localhost:7860/interface.html\n",
|
||||
"To create a public link, set `share=True` in the argument to `launch()`\n"
|
||||
"Model available publicly for 8 hours at: https://efb97fa5.ngrok.io/interface.html\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -137,7 +72,7 @@
|
||||
" "
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.lib.display.IFrame at 0x28f091997f0>"
|
||||
"<IPython.lib.display.IFrame at 0x118f6285eb8>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
@ -146,10 +81,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"('http://localhost:7860/interface.html', None)"
|
||||
"('http://localhost:7860/interface.html', 'https://efb97fa5.ngrok.io')"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
},
|
||||
@ -157,15 +92,23 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"127.0.0.1 - - [06/Mar/2019 10:19:46] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [06/Mar/2019 10:19:46] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [06/Mar/2019 10:19:46] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [06/Mar/2019 10:19:46] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n"
|
||||
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /interface.html HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
|
||||
"127.0.0.1 - - [06/Mar/2019 11:24:00] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"dddddddddddddddddddd cleaver,\n",
|
||||
"dddddddddddddddddddd cleaver,\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"iface.launch(browser=True, share=False)"
|
||||
"iface.launch(inline=True, browser=True, share=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
1000
gradio/imagenet_class_labels.py
Normal file
1000
gradio/imagenet_class_labels.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -11,7 +11,6 @@ from io import BytesIO
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class AbstractInput(ABC):
|
||||
"""
|
||||
An abstract class for defining the methods that all gradio inputs should have.
|
||||
@ -99,11 +98,14 @@ class Textbox(AbstractInput):
|
||||
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB'):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB',
|
||||
scale = 1/127.5, shift = -1):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
self.image_mode = image_mode
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
@ -117,10 +119,12 @@ class ImageUpload(AbstractInput):
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode)
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
im = np.array(im).flatten()
|
||||
im = im * self.scale + self.shift
|
||||
if self.num_channels is None:
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height)
|
||||
array = im.reshape(1, self.image_width, self.image_height)
|
||||
else:
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
array = im.reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
|
@ -108,8 +108,8 @@ class Interface:
|
||||
await websocket.send(str(processed_output))
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
|
||||
def predict(self, preprocessed_input):
|
||||
"""
|
||||
|
@ -7,6 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import json
|
||||
from gradio import imagenet_class_labels
|
||||
|
||||
class AbstractOutput(ABC):
|
||||
"""
|
||||
@ -42,11 +43,25 @@ class Label(AbstractOutput):
|
||||
CONFIDENCES_KEY = 'confidences'
|
||||
CONFIDENCE_KEY = 'confidence'
|
||||
|
||||
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True):
|
||||
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True, label_names=None,
|
||||
max_label_length=None):
|
||||
self.num_top_classes = num_top_classes
|
||||
self.show_confidences = show_confidences
|
||||
self.label_names = label_names
|
||||
self.max_label_length = max_label_length
|
||||
super().__init__(postprocessing_fn=postprocessing_fn)
|
||||
|
||||
def get_label_name(self, label):
|
||||
if self.label_names is None:
|
||||
name = label
|
||||
elif self.label_names == 'imagenet1000':
|
||||
name = imagenet_class_labels.NAMES1000[label]
|
||||
else: # if list or dictionary
|
||||
name = self.label_names[label]
|
||||
if self.max_label_length is not None:
|
||||
name = name[:self.max_label_length]
|
||||
return name
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/label_output.html'
|
||||
|
||||
@ -54,18 +69,19 @@ class Label(AbstractOutput):
|
||||
"""
|
||||
"""
|
||||
response = dict()
|
||||
print('dddddddddddddddddddd', self.get_label_name(499))
|
||||
# TODO(abidlabs): check if list, if so convert to numpy array
|
||||
if isinstance(prediction, np.ndarray):
|
||||
prediction = prediction.squeeze()
|
||||
if prediction.size == 1: # if it's single value
|
||||
response[Label.LABEL_KEY] = np.asscalar(prediction)
|
||||
response[Label.LABEL_KEY] = self.get_label_name(np.asscalar(prediction))
|
||||
elif len(prediction.shape) == 1: # if a 1D
|
||||
response[Label.LABEL_KEY] = int(prediction.argmax())
|
||||
response[Label.LABEL_KEY] = self.get_label_name(int(prediction.argmax()))
|
||||
if self.show_confidences:
|
||||
response[Label.CONFIDENCES_KEY] = []
|
||||
for i in range(self.num_top_classes):
|
||||
response[Label.CONFIDENCES_KEY].append({
|
||||
Label.LABEL_KEY: int(prediction.argmax()),
|
||||
Label.LABEL_KEY: self.get_label_name(int(prediction.argmax())),
|
||||
Label.CONFIDENCE_KEY: float(prediction.max()),
|
||||
})
|
||||
prediction[prediction.argmax()] = 0
|
||||
|
Loading…
Reference in New Issue
Block a user