Beginner's Tutorial: Creating a Sketchpad for a Keras MNIST Model
Abubakar Abid, April 9, 2019
Gradio is a Python library that makes it easy to turn your machine learning models into visual interfaces! This tutorial shows you how to do that with the "Hello World" of machine learning: a model that we train from scratch to classify handwritten digits from the MNIST dataset. By the end, you will have an interface that lets you draw a digit and see the classifier's prediction. This post comes with a companion Colaboratory (Colab) notebook that lets you run the code (and embed the interface) directly in a browser window. Check out the Colab notebook here.
Installing Gradio
If you haven't already installed Gradio, go ahead and do so. It's super easy as long as you have Python 3 on your machine:
pip install gradio
The MNIST Dataset
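The MNIST dataset consists of 28x28-pixel grayscale images of handwritten digits: 60,000 for training and 10,000 for testing. We'll train a model to classify each image as the digit written, from 0 through 9, so let's load the data and scale the pixel values from the range 0 to 255 down to the range 0 to 1.
import tensorflow as tf
# Load the MNIST images (28x28 uint8 arrays) and their integer labels (0-9)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0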
[Figure: a sample of handwritten digits from the MNIST dataset]
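To see what the scaling step does without downloading the dataset, here is a minimal sketch that applies the same division to a stand-in batch of random images. The array `fake_images` is a hypothetical placeholder with the same dtype (uint8) and per-image shape (28x28) as the real MNIST data, not actual digits.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 4 random 28x28 grayscale images stored as uint8 (0-255)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Dividing by 255.0 converts the array to float64 and maps every pixel into [0, 1]
scaled = fake_images / 255.0

print(fake_images.dtype, fake_images.min(), fake_images.max())
print(scaled.dtype, scaled.min(), scaled.max())
```

Whatever preprocessing is applied here has to be repeated at prediction time, since the model only ever sees inputs in the [0, 1] range during training.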
Training a Keras Model
By using the Keras API from the TensorFlow package, we can train a model in just a few lines of code. We're not going to train a very complicated model here: it'll just be a fully connected neural network. Since we're not going for record accuracy, let's train it for only 5 epochs.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),                           # 28x28 image -> 784-element vector
    tf.keras.layers.Dense(512, activation=tf.nn.relu),   # fully connected hidden layer
    tf.keras.layers.Dropout(0.2),                        # randomly zero 20% of activations while training
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)  # one probability per digit
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # labels are integers 0-9, not one-hot vectors
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
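To make the architecture concrete, here is a minimal NumPy sketch of what the network computes at prediction time: flatten, a 512-unit ReLU layer, then a 10-way softmax. The weights are random placeholders (an assumption standing in for the trained values), and Dropout is omitted because Keras only applies it during training. The sketch also counts the trainable parameters and computes the sparse categorical cross-entropy, the loss passed to `model.compile`, for integer labels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder weights with the same shapes as the Keras layers (random, NOT trained)
W1 = rng.normal(0.0, 0.05, size=(784, 512))
b1 = np.zeros(512)
W2 = rng.normal(0.0, 0.05, size=(512, 10))
b2 = np.zeros(10)

def forward(images):
    """Map a batch of 28x28 images (values in [0, 1]) to class probabilities."""
    x = images.reshape(len(images), -1)          # Flatten: (n, 28, 28) -> (n, 784)
    h = np.maximum(0.0, x @ W1 + b1)             # Dense(512) followed by ReLU
    logits = h @ W2 + b2                         # Dense(10) before the softmax
    logits -= logits.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)  # softmax: each row sums to 1

def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(probability assigned to the true class); labels are integers 0-9."""
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

# Trainable parameters: weights plus biases of the two Dense layers
n_params = (784 * 512 + 512) + (512 * 10 + 10)
print(n_params)  # 407050

images = rng.random((3, 28, 28))
labels = np.array([7, 2, 1])
probs = forward(images)
print(probs.shape, probs.sum(axis=1))
print(sparse_categorical_crossentropy(labels, probs))
```

Because the labels stay as plain integers, the loss simply picks out the predicted probability of the correct digit in each row, which is why the "sparse" variant is used instead of one-hot encoding the labels.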
Launching a Gradio Interface
Now that our Keras model is trained, we can define the interface. What's the appropriate interface to use? For the input, we can use a Sketchpad, so that users can draw new digits with their cursor and test the model (we call this process interactive inference). The output of the model is simply a label, so we will use the Label interface.
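import gradio

# A drawing canvas as the input and a class label as the output
io = gradio.Interface(
    inputs="sketchpad",
    outputs="label",
    model=model,
    model_type='keras')
io.launch(inline=True, share=False)  # inline=True displays the interface inside the notebook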
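Conceptually, interactive inference means turning each drawing into the kind of result a label display shows: the most likely classes with their confidences. The sketch below shows one way to do that in plain Python and NumPy, assuming the drawing arrives as a 28x28 grayscale array with values 0 to 255 (an assumption; the exact format a sketchpad delivers depends on the Gradio version). The names `predict_digit` and `stub_model` are hypothetical, and the stub replaces the trained Keras model so the example runs on its own. Newer Gradio releases build an `Interface` from a prediction function passed as `fn=` rather than from `model=` and `model_type=`, and their Label output accepts a dictionary mapping class names to confidences, so a function along these lines is the kind of thing you would supply there.

```python
import numpy as np

def stub_model(batch):
    """Hypothetical stand-in for the trained model's predict(): fixed probabilities per image."""
    probs = np.array([0.01, 0.80, 0.02, 0.03, 0.04, 0.02, 0.01, 0.05, 0.01, 0.01])
    return np.tile(probs, (len(batch), 1))  # one row of 10 probabilities per input image

def predict_digit(drawing, predict=stub_model, top_k=3):
    """Turn one 28x28 drawing (0-255 grayscale) into {digit: confidence} for the top_k digits."""
    image = np.asarray(drawing, dtype=np.float64) / 255.0  # same scaling as the training data
    probs = predict(image.reshape(1, 28, 28))[0]           # the model expects a batch dimension
    top = np.argsort(probs)[::-1][:top_k]                  # indices of the most likely digits
    return {str(d): float(probs[d]) for d in top}

# Usage: a black canvas with a white vertical stroke, like a hand-drawn "1"
canvas = np.zeros((28, 28), dtype=np.uint8)
canvas[4:24, 13:15] = 255
print(predict_digit(canvas))  # {'1': 0.8, '7': 0.05, '4': 0.04}
```

MNIST digits are light strokes on a dark background, so if a canvas produces dark strokes on a light background, the image would also need to be inverted (255 minus each pixel) before scaling.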
And that's it! Try it out in the Colab notebook here.
You'll notice that the interface is embedded directly in the Colab notebook!