gradio/Test Sklearn.ipynb
2019-04-09 21:06:02 -07:00

179 lines
7.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from sklearn import datasets, svm\n",
"import gradio\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load the digits dataset from scikit-learn; later cells use `digits.images`\n",
"# (reshaped into feature vectors) and `digits.target` (the labels).\n",
"digits = datasets.load_digits()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
" decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',\n",
" max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Flatten each 2-D digit image into a 1-D feature vector so the dataset\n",
"# becomes a (samples, features) matrix that the classifier can consume.\n",
"n_samples = digits.images.shape[0]\n",
"data = digits.images.reshape(n_samples, -1)\n",
"\n",
"# Support vector classifier\n",
"classifier = svm.SVC(gamma=0.001)\n",
"\n",
"# Fit on the first half of the samples only; a later cell predicts on the\n",
"# second half.\n",
"classifier.fit(X=data[:n_samples // 2], y=digits.target[:n_samples // 2])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16.0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Largest pixel intensity in the flattened data (16.0 per the saved output).\n",
"# NOTE(review): presumably why the Sketchpad input in a later cell uses\n",
"# scale=16/255 - confirm.\n",
"data.max()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAB4CAYAAADbsbjHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACUZJREFUeJzt3V2MXVUZxvHnkYrEFDptlAsQMq1cYIy2aQkJ0UgbaYJB7RClJkJiMdIm3thoSHuBBJTENkEtmmgGvxqDGlovaCAx2BpahQjS6jQRjZq2E6x8JFCmfDVo7evFPpUJlNnrTPc55z27/1/SZE7nPXuteTvznD377NXliBAAIK+3DXoCAICZEdQAkBxBDQDJEdQAkBxBDQDJEdQAkNxQBrXts2y/bPviJmtBb3uJ3vZO23vbl6DuNOXknxO2j017fH23x4uI/0bE3Ih4ssnaJti+2fYzto/a/qHts3s83hnRW9uLbf/a9vO2j/d6vM6YZ0pvP2/7j7ZftH3Y9jdsn9XjMc+U3l5v+2+dPHjW9k9sz+36OP1e8GJ7UtIXImLXDDVzIqIvP4xNsn2NpB9JWiHpWUk7JO2JiFv6NP6k2tvb90m6QtKUpG0RMafP40+qvb39oqT9kh6XdL6kByTdExF39mn8SbW3txdLejUinrN9rqQfSHoqIr7czXFSXPqwfYfte23/wvZLkm6wfYXtR21P2X7a9ndsv71TP8d22B7tPL6n8/lf2X7J9u9tL+y2tvP5j9n+e+cV8Lu2H7G9pvBL+ZykuyPirxFxRNIdkkqf2xNt6W2npz+W9JcG23NaWtTb70XEIxHx74g4LOnnkj7UXKe616LePhkRz037qxOSLum2HymCuuNaVd8g8yTdK+m4pC9Jepeqb5qrJa2b4fmflfRVSQskPSnp693W2j5f0jZJN3fGPSTp8pNPsr2w801ywVsc9/2qzkxO2i/pQtvzZphLP7Sht1m1sbcfkfREYW0vtaK3tq+0fVTSi5I+KWnLDPM4pUxB/XBE3B8RJyLiWEQ8HhGPRcTxiDgo6W5JV87w/F9GxN6I+I+kn0laMovaj0uaiIgdnc99W9L/Xw0j4lBEjETEU29x3LmSjk57fPLjc2eYSz+0obdZtaq3tm+S9EFJ36qr7YNW9DYi9kTEPEkXSbpT1QtBV/p6na/GP6c/sH2ppG9KWibpnarm+tgMz39m2sevqgrNbmsvmD6PiAjbh2tn/rqXJZ037fF50/5+kNrQ26xa01vbn1J1JvnRzqW7QWtNbzvPPWx7l6rfEi6vq58u0xn1G9/VHJf0Z0mXRMR5km6V5B7P4WlJ7zn5wLYlXdjF85+QtHja48WS/hURU81Mb9ba0NusWtFbV2+Ef1/SNRGR4bKH1JLevsEcSe/t9kmZgvqNzlV16eAVV+/4z3QtqikPSFpq+xO256i6HvbuLp7/U0k32b7U9gJJt0ja2vw0T9vQ9daVcySd3Xl8jnt86+MsDWNvV6r63r02Ivb1aI5NGMbe3mD7os7Ho6p+Y/lNt5PIHNRfUXUXxUuqXknv7fWAEfGspM+ouj73vKpXvj9Jek2SbC9ydZ/nKd84iIgHVF3D+q2kSUn/kPS1Xs97Foaut536Y6reoD2r83GaO0CmGcbe3qrqDbsH/fq9zPf3et6zMIy9/YCkR22/IulhVb91d/0C0/f7qIeJq5v+n5L06Yj43aDn0yb0tnfobe8MqreZz6gHwvbVtufZfoeq23WOS/rDgKfVCvS2d+ht72ToLUH9Zh+WdFDVLThXSxqLiNcGO6XWoLe9Q297Z+C95dIHACTHGTUAJEdQA0ByvVqZ2Mj1lO3bt9fWbNiwobZm5cqVReNt2rSptmb+/PlFxyow2xv1+3atavny5bU1U1Nla3luv/322ppVq1YVHatA+t7u3r27tmZsbKzoWEuWzLQyuny8QqezwKSR/m7evLm2ZuPGjb
U1CxcurK2RpH376m8t73UucEYNAMkR1ACQHEENAMkR1ACQHEENAMkR1ACQHEENAMkR1ACQXKatuN6kZDHLoUOHamteeOGFovEWLFhQW7Nt27bamuuuu65ovOxGRkZqa/bs2VN0rIceeqi2psEFLwM1MTFRW7NixYramnnzyvZEnpycLKobBiULVUp+BsfHx2tr1q0r+2+hSxa8XHXVVUXHmi3OqAEgOYIaAJIjqAEgOYIaAJIjqAEgOYIaAJIjqAEgOYIaAJIb2IKXkpvISxazHDhwoLZm0aJFRXMq2QmmZN7DsOClZFFGg7uCFO1C0hb33Xdfbc3ixYtra0p3eCnZPWdYrF27tramZCHcsmXLamtKd3jp9WKWEpxRA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJDewBS8lu64sXbq0tqZ0MUuJkpvkh8GWLVtqa2677bbamqNHjzYwm8ry5csbO1Z269evr60ZHR1t5DhSe3bGkcp+ng8ePFhbU7JYrnQhS0lWzZ8/v+hYs8UZNQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHKpF7yU7LjSpAw3tjehZKHEmjVramua/FqnpqYaO9YglXwdJQuOSnaBKbV169bGjjUMShbFHDlypLamdMFLSd2uXbtqa07n54kzagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIbmArE0tW6ezbt6+RsUpWHErS3r17a2tWr159utM5I01MTNTWLFmypA8zOT0lW5jdddddjYxVunpxZGSkkfHapCRfSlYTStK6detqazZv3lxbs2nTpqLxToUzagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQGtuClZDudkgUo27dvb6Sm1IYNGxo7FoZPyRZmu3fvrq3Zv39/bc3Y2FjBjKRVq1bV1tx4442NHCeDjRs31taUbJ9VuhBu586dtTW9XgjHGTUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByqRe8lOyaULIA5bLLLiuaU1M7ygyDkl1BShZA7Nixo2i8kkUgJYtJBq1kF5qS3WxKakp2k5HK/g1GR0dra4ZlwUvJ7i1r165tbLySxSzj4+ONjXcqnFEDQHIENQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHIENQAk54gY9BwAADPgjBoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkvsf2PN/nyaodHgAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Pair every image with its label, then preview the first four samples,\n",
"# each titled with the digit it represents.\n",
"images_and_labels = list(zip(digits.images, digits.target))\n",
"for position, (image, label) in enumerate(images_and_labels[:4], start=1):\n",
"    plt.subplot(2, 4, position)\n",
"    plt.axis('off')\n",
"    plt.imshow(image, interpolation='nearest', cmap=plt.cm.gray_r)\n",
"    plt.title('Training: %i' % label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: predict() requires the samples to classify (X); calling it\n",
"# with no arguments raises a TypeError and breaks Restart & Run All.\n",
"# Classify the first four samples of the second half of the data, which\n",
"# the classifier was not fit on.\n",
"classifier.predict(data[n_samples // 2:n_samples // 2 + 4])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate on the second half of the data (the classifier was fit on the\n",
"# first half only).\n",
"expected = digits.target[n_samples // 2:]\n",
"predicted = classifier.predict(data[n_samples // 2:])\n",
"\n",
"# Fraction of held-out digits classified correctly, shown as the cell\n",
"# result so the predictions are actually inspected.\n",
"(predicted == expected).mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketchpad input configured to mirror the training data: 8x8 images,\n",
"# flattened into a single feature vector.\n",
"# NOTE(review): scale=16/255 presumably maps 0-255 sketch pixels onto the\n",
"# dataset's 0-16 intensity range - confirm against the gradio version in use.\n",
"inp = gradio.inputs.Sketchpad(\n",
"    shape=(8, 8),\n",
"    flatten=True,\n",
"    scale=16 / 255,\n",
"    invert_colors=False,\n",
")\n",
"\n",
"# Wrap the fitted scikit-learn classifier in an interface with a label output.\n",
"io = gradio.Interface(\n",
"    inputs=inp,\n",
"    outputs=\"label\",\n",
"    model_type=\"sklearn\",\n",
"    model=classifier,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Launch the interface created in the previous cell.\n",
"io.launch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}