{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Handwritten digit recognition with an SVM and Gradio\n", "\n", "Fit a support vector classifier on the first half of scikit-learn's digits dataset, predict on the held-out second half, and serve the model through an interactive Gradio sketchpad." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import gradio\n", "import matplotlib.pyplot as plt\n", "from sklearn import datasets, svm\n", "\n", "# The digits dataset\n", "digits = datasets.load_digits()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# To apply a classifier on this data, we need to flatten each image,\n", "# turning the data into a (samples, features) matrix:\n", "n_samples = len(digits.images)\n", "data = digits.images.reshape((n_samples, -1))\n", "\n", "# Create a classifier: a support vector classifier\n", "classifier = svm.SVC(gamma=0.001)\n", "\n", "# Learn the digits on the first half of the samples; the second half is held out\n", "classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16.0" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Maximum pixel intensity in the training data; the Gradio sketchpad input\n", "# is rescaled into this range via scale=16/255\n", "data.max()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAB4CAYAAADbsbjHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACUZJREFUeJzt3V2MXVUZxvHnkYrEFDptlAsQMq1cYIy2aQkJ0UgbaYJB7RClJkJiMdIm3thoSHuBBJTENkEtmmgGvxqDGlovaCAx2BpahQjS6jQRjZq2E6x8JFCmfDVo7evFPpUJlNnrTPc55z27/1/SZE7nPXuteTvznD377NXliBAAIK+3DXoCAICZEdQAkBxBDQDJEdQAkBxBDQDJEdQAkNxQBrXts2y/bPviJmtBb3uJ3vZO23vbl6DuNOXknxO2j017fH23x4uI/0bE3Ih4ssnaJti+2fYzto/a/qHts3s83hnRW9uLbf/a9vO2j/d6vM6YZ0pvP2/7j7ZftH3Y9jdsn9XjMc+U3l5v+2+dPHjW9k9sz+36OP1e8GJ7UtIXImLXDDVzIqIvP4xNsn2NpB9JWiHpWUk7JO2JiFv6NP6k2tvb90m6QtKUpG0RMafP40+qvb39oqT9kh6XdL6kByTdExF39mn8SbW3txdLejUinrN9rqQfSHoqIr7czXFSXPqwfYfte23/wvZLkm6wfYXtR21P2X7a9ndsv71TP8d22B7tPL6n8/lf2X7J9u9tL+y2tvP5j9n+e+cV8Lu2H7G9pvBL+ZykuyPirxFxRNIdkkqf2xNt6W2npz+W9JcG23NaWtTb70XEIxHx74g4LOnnkj7UXKe616LePhkRz037qxOSLum2HymCuuNaVd8g8yTdK+m4pC9Jepeqb5qrJa2b4fmflfRVSQskPSnp693W2j5f0jZJN3fGPSTp8pNPsr2w801ywVsc9/2qzkxO2i/pQtvzZphLP7Sht1m1sbcfkfREYW0vtaK3tq+0fVTSi5I+KWnLDPM4pUxB/XBE3B8RJyLiWEQ8HhGPRcTxiDgo6W5JV87w/F9GxN6I+I+kn0laMovaj0uaiIgdnc99W9L/Xw0j4lBEjETEU29x3LmSjk57fPLjc2eYSz+0obdZtaq3tm+S9EFJ36qr7YNW9DYi9kTEPEkXSbpT1QtBV/p6na/GP6c/sH2ppG9KWibpnarm+tgMz39m2sevqgrNbmsvmD6PiAjbh2tn/rqXJZ037fF50/5+kNrQ26xa01vbn1J1JvnRzqW7QWtNbzvPPWx7l6rfEi6vq58u0xn1G9/VHJf0Z0mXRMR5km6V5B7P4WlJ7zn5wLYlXdjF85+QtHja48WS/hURU81Mb9ba0NusWtFbV2+Ef1/SNRGR4bKH1JLevsEcSe/t9kmZgvqNzlV16eAVV+/4z3QtqikPSFpq+xO256i6HvbuLp7/U0k32b7U9gJJt0ja2vw0T9vQ9daVcySd3Xl8jnt86+MsDWNvV6r63r02Ivb1aI5NGMbe3mD7os7Ho6p+Y/lNt5PIHNRfUXUXxUuqXknv7fWAEfGspM+ouj73vKpXvj9Jek2SbC9ydZ/nKd84iIgHVF3D+q2kSUn/kPS1Xs97Foaut536Y6reoD2r83GaO0CmGcbe3qrqDbsH/fq9zPf3et6zMIy9/YCkR22/IulhVb91d/0C0/f7qIeJq5v+n5L06Yj43aDn0yb0tnfobe8MqreZz6gHwvbVtufZfoeq23WOS/rDgKfVCvS2d+ht72ToLUH9Zh+WdFDVLThXSxqLiNcGO6XWoLe9Q297Z+C95dIHACTHGTUAJEdQA0ByvVqZ2Mj1lO3bt9fWbNiwobZm5cqVReNt2rSptmb+/PlFxyow2xv1+3atavny5bU1U1Nla3luv/322ppVq1YVHatA+t7u3r27tmZsbKzoWEuWzLQyuny8QqezwKSR/m7evLm2ZuPGjbU1CxcurK2RpH3
76m8t73UucEYNAMkR1ACQHEENAMkR1ACQHEENAMkR1ACQHEENAMkR1ACQXKatuN6kZDHLoUOHamteeOGFovEWLFhQW7Nt27bamuuuu65ovOxGRkZqa/bs2VN0rIceeqi2psEFLwM1MTFRW7NixYramnnzyvZEnpycLKobBiULVUp+BsfHx2tr1q0r+2+hSxa8XHXVVUXHmi3OqAEgOYIaAJIjqAEgOYIaAJIjqAEgOYIaAJIjqAEgOYIaAJIb2IKXkpvISxazHDhwoLZm0aJFRXMq2QmmZN7DsOClZFFGg7uCFO1C0hb33Xdfbc3ixYtra0p3eCnZPWdYrF27tramZCHcsmXLamtKd3jp9WKWEpxRA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJDewBS8lu64sXbq0tqZ0MUuJkpvkh8GWLVtqa2677bbamqNHjzYwm8ry5csbO1Z269evr60ZHR1t5DhSe3bGkcp+ng8ePFhbU7JYrnQhS0lWzZ8/v+hYs8UZNQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHKpF7yU7LjSpAw3tjehZKHEmjVramua/FqnpqYaO9YglXwdJQuOSnaBKbV169bGjjUMShbFHDlypLamdMFLSd2uXbtqa07n54kzagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIbmArE0tW6ezbt6+RsUpWHErS3r17a2tWr159utM5I01MTNTWLFmypA8zOT0lW5jdddddjYxVunpxZGSkkfHapCRfSlYTStK6detqazZv3lxbs2nTpqLxToUzagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQGtuClZDudkgUo27dvb6Sm1IYNGxo7FoZPyRZmu3fvrq3Zv39/bc3Y2FjBjKRVq1bV1tx4442NHCeDjRs31taUbJ9VuhBu586dtTW9XgjHGTUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByqRe8lOyaULIA5bLLLiuaU1M7ygyDkl1BShZA7Nixo2i8kkUgJYtJBq1kF5qS3WxKakp2k5HK/g1GR0dra4ZlwUvJ7i1r165tbLySxSzj4+ONjXcqnFEDQHIENQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHIENQAk54gY9BwAADPgjBoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkvsf2PN/nyaodHgAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Preview the first four training images with their labels\n", "images_and_labels = list(zip(digits.images, digits.target))\n", "for index, (image, label) in enumerate(images_and_labels[:4]):\n", "    plt.subplot(2, 4, index + 1)\n", "    plt.axis('off')\n", "    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n", "    plt.title('Training: %i' % label)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate on the held-out half\n", "\n", "The classifier was fit on the first half of the samples only, so the second half is unseen data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "expected = digits.target[n_samples // 2:]\n", "predicted = classifier.predict(data[n_samples // 2:])\n", "assert len(predicted) == len(expected)\n", "\n", "# Fraction of held-out digits classified correctly\n", "(predicted == expected).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Interactive demo\n", "\n", "`data.max()` is 16, so the sketchpad input is rescaled by `16/255` to match the training range (this assumes the sketchpad emits 0-255 intensities)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inp = gradio.inputs.Sketchpad(shape=(8, 8), flatten=True, scale=16/255, invert_colors=False)\n", "interface = gradio.Interface(inputs=inp, outputs=\"label\", model_type=\"sklearn\", model=classifier)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interface.launch()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }