<html>
<head>
<title>Gradio</title>
<link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet">
<link href="../style/style.css" rel="stylesheet">
<link href="../style/blog.css" rel="stylesheet">
<link href="../style/gradio.css" rel="stylesheet">
<link rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.15.6/styles/github.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.15.6/highlight.min.js"></script>
<script>hljs.initHighlightingOnLoad();</script>
</head>
<body>
<nav>
<img src="../img/logo_inline.png" />
<a href="../index.html">Gradio</a>
<a href="../getting_started.html">Getting Started</a>
<a href="../sharing.html">Sharing</a>
<a href="../blog.html">Blog</a>
</nav>
<div class="content">
<div class="row">
<div class="leftcolumn">
<div class="card">
<h2>Beginner's Tutorial: Creating a Sketchpad for a Keras MNIST Model</h2>
<h4>Abubakar Abid, April 9, 2019</h4>
<img src="../img/mnist-sketchpad-screenshot.png">
<p>Gradio is a Python library that makes it easy to turn your machine learning models into visual interfaces!
This tutorial shows you how to do that with the "Hello World" of machine learning models: a model that we train
from scratch to classify handwritten digits from the MNIST dataset. By the end, you will have an interface
that lets you draw digits and see what the classifier predicts. This post comes with a companion
Colaboratory notebook that allows you to run the code (and embed the interface) directly in a browser window.
<a href="https://colab.research.google.com/drive/1DQSuxGARUZ-v4ZOAuw-Hf-8zqegpmes-">Check out the Colab notebook here.</a></p>
<h3>Installing Gradio</h3>
<p>If you haven't already installed Gradio, go ahead and do so. It's super easy as long as you already have Python 3 on your machine:</p>
<pre><code class="bash">pip install gradio</code></pre>
<h3>The MNIST Dataset</h3>
<p>The MNIST dataset consists of images of handwritten digits. We'll be training a model to classify each image
by the digit written in it, from 0 through 9, so let's load the data and scale the pixel values to the range 0 to 1.</p>
<pre><code class="python">import tensorflow as tf

# Load the training and test splits (images and their integer labels)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale pixel values from 0..255 down to 0..1
x_train, x_test = x_train / 255.0, x_test / 255.0</code></pre>
<p>Here is a sample of handwritten digits.</p>
<img src="../img/mnist-examples.jpg">
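<p>To see what the scaling step above does, here is a minimal sketch on made-up data. It assumes the MNIST arrays are
unsigned 8-bit integers with shape (number of images, 28, 28); the <code>fake_images</code> array is a random stand-in,
not real MNIST data, so no download is needed to run it.</p>

```python
import numpy as np

# Random stand-in for a small batch of MNIST images (assumption: the real
# arrays are uint8, shaped (num_images, 28, 28), with values from 0 to 255).
rng = np.random.default_rng(seed=0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same scaling as in the tutorial: dividing by 255.0 maps pixels into [0, 1]
# and turns the integer array into a floating-point one.
scaled = fake_images / 255.0

print(fake_images.dtype, fake_images.shape)
print(scaled.dtype, scaled.min(), scaled.max())
```

<p>Keeping the inputs in a small, fixed range like this generally makes neural networks easier to train.</p>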
<h3>Training a Keras Model</h3>
<p>Using the Keras API from the TensorFlow package, we can train a model in just a few lines of code. We're not
going to train a very complicated model here; it'll just be a fully connected neural network with one hidden layer.
Since we're not going for record accuracy, let's train it for only 5 epochs.</p>
<pre><code class="python">model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),                           # 28x28 image -> vector of 784 values
    tf.keras.layers.Dense(512, activation=tf.nn.relu),   # hidden layer
    tf.keras.layers.Dropout(0.2),                        # regularization during training
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)  # one probability per digit
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)</code></pre>
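<p>If the softmax output layer and the <code>sparse_categorical_crossentropy</code> loss feel like magic, here is a rough
sketch of what they compute for a single example. This is a plain NumPy illustration, not Keras's actual implementation,
and the scores in <code>logits</code> are made up.</p>

```python
import numpy as np

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

def sparse_categorical_crossentropy(true_label, probabilities):
    """Loss for one example: negative log of the probability of the true class.

    "Sparse" means the label is a plain integer (e.g. 2), not a one-hot vector.
    """
    return -np.log(probabilities[true_label])

# Hypothetical raw scores for the 10 digit classes (not from a real model)
logits = np.array([0.1, 0.2, 3.0, 0.1, 0.0, 0.3, 0.1, 0.2, 0.1, 0.0])

probs = softmax(logits)
predicted_digit = int(np.argmax(probs))

# The loss is small when the true digit gets high probability, large otherwise.
loss_if_true_is_2 = sparse_categorical_crossentropy(2, probs)
loss_if_true_is_7 = sparse_categorical_crossentropy(7, probs)

print(predicted_digit, loss_if_true_is_2, loss_if_true_is_7)
```

<p>Training nudges the weights so that this loss, averaged over the training images, goes down.</p>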
<h3>Launching a Gradio Interface</h3>
<p>Now that we have our Keras model trained, we need to define the interface. What's the appropriate
interface to use? For the input, we can use a Sketchpad, so that users can draw new digits with their cursor and
test the model (we call this process <em>interactive inference</em>). The output of the model is simply a label,
so we will use the Label interface.</p>
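<pre><code class="python">import gradio

# Wire the sketchpad input and the label output to the trained Keras model
io = gradio.Interface(
    inputs="sketchpad",
    outputs="label",
    model=model,
    model_type='keras')
# inline=True embeds the interface in the notebook
io.launch(inline=True, share=False)
</code></pre>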
<p>And that's it. Try it out <a href="https://colab.research.google.com/drive/1DQSuxGARUZ-v4ZOAuw-Hf-8zqegpmes-">in the Colab notebook here.</a>
You'll notice that the interface is embedded directly in the Colab notebook!</p>
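<p>Conceptually, a label output has to turn the model's 10 probabilities into a predicted digit. Here is a rough sketch
of that idea. It is an illustration only, not Gradio's actual code: the <code>probabilities_to_label</code> helper,
its return format, and the <code>fake_output</code> values are all hypothetical.</p>

```python
def probabilities_to_label(probabilities, top_k=3):
    """Return the most likely digit plus the top_k (digit, confidence) pairs.

    Hypothetical helper for illustration; not part of Gradio.
    """
    # Pair each probability with its digit (its index), highest confidence first
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    top = ranked[:top_k]
    return {"label": top[0][0], "confidences": top}

# Made-up model output for one sketch: one probability per digit, 0 through 9
fake_output = [0.01, 0.02, 0.05, 0.70, 0.02, 0.10, 0.03, 0.03, 0.02, 0.02]

result = probabilities_to_label(fake_output)
print(result["label"])
print(result["confidences"])
```

<p>With the made-up output above, the digit 3 comes out on top, followed by 5 and 2.</p>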
</div>
</div>
</div>
</div>
<footer>
<img src="../img/logo_inline.png" />
</footer>
</body>
</html>