added jquery and a few more test notebooks

This commit is contained in:
Abubakar Abid 2019-03-30 12:43:06 -07:00
parent 9ecba1b89f
commit ca3d080e07
4 changed files with 1157 additions and 0 deletions

746
Test Pytorch.ipynb Normal file
View File

@ -0,0 +1,746 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import gradio"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Device configuration\n",
"device = torch.device('cpu')\n",
"\n",
"# Hyper-parameters \n",
"input_size = 784\n",
"hidden_size = 500\n",
"num_classes = 10\n",
"num_epochs = 2\n",
"batch_size = 100\n",
"learning_rate = 0.001\n",
"\n",
"# MNIST dataset \n",
"train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(), download=True)\n",
"test_dataset = torchvision.datasets.MNIST(root='../../data',train=False, transform=transforms.ToTensor())\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size,shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/2], Step [100/600], Loss: 0.4317\n",
"Epoch [1/2], Step [200/600], Loss: 0.2267\n",
"Epoch [1/2], Step [300/600], Loss: 0.2052\n",
"Epoch [1/2], Step [400/600], Loss: 0.1179\n",
"Epoch [1/2], Step [500/600], Loss: 0.1108\n",
"Epoch [1/2], Step [600/600], Loss: 0.1830\n",
"Epoch [2/2], Step [100/600], Loss: 0.0972\n",
"Epoch [2/2], Step [200/600], Loss: 0.0662\n",
"Epoch [2/2], Step [300/600], Loss: 0.1487\n",
"Epoch [2/2], Step [400/600], Loss: 0.0640\n",
"Epoch [2/2], Step [500/600], Loss: 0.0425\n",
"Epoch [2/2], Step [600/600], Loss: 0.0979\n"
]
}
],
"source": [
"# Fully connected neural network with one hidden layer\n",
"class NeuralNet(nn.Module):\n",
" def __init__(self, input_size, hidden_size, num_classes):\n",
" super(NeuralNet, self).__init__()\n",
" self.fc1 = nn.Linear(input_size, hidden_size) \n",
" self.relu = nn.ReLU()\n",
" self.fc2 = nn.Linear(hidden_size, num_classes) \n",
" \n",
" def forward(self, x):\n",
" out = self.fc1(x)\n",
" out = self.relu(out)\n",
" out = self.fc2(out)\n",
" return out\n",
"\n",
"model = NeuralNet(input_size, hidden_size, num_classes).to(device)\n",
"\n",
"# Loss and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) \n",
"\n",
"# Train the model\n",
"total_step = len(train_loader)\n",
"for epoch in range(num_epochs):\n",
" for i, (images, labels) in enumerate(train_loader): \n",
" # Move tensors to the configured device\n",
" images = images.reshape(-1, 28*28).to(device)\n",
" labels = labels.to(device)\n",
" \n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
" \n",
" # Backward and optimize\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of the network on the 10000 test images: 97.04 %\n"
]
}
],
"source": [
"# Test the model\n",
"# In test phase, we don't need to compute gradients (for memory efficiency)\n",
"with torch.no_grad():\n",
" correct = 0\n",
" total = 0\n",
" for images, labels in test_loader:\n",
" images = images.reshape(-1, 28*28).to(device)\n",
" labels = labels.to(device)\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
" print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.float64\n",
"torch.float64\n"
]
},
{
"ename": "RuntimeError",
"evalue": "Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #4 'mat1'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-39-d6583191b5ef>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mvalue\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mVariable\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0mprediction\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 476\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 477\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 478\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m<ipython-input-9-abba6ac73cbf>\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[0mout\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 476\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 477\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 478\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 53\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 54\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 55\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 56\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 57\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36mlinear\u001b[1;34m(input, weight, bias)\u001b[0m\n\u001b[0;32m 1022\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m2\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mbias\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1023\u001b[0m \u001b[1;31m# fused op is marginally faster\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1024\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1025\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1026\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mRuntimeError\u001b[0m: Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #4 'mat1'"
]
}
],
"source": [
"value = torch.from_numpy(images.numpy())\n",
"print(value.dtype)\n",
"value = torch.autograd.Variable(value)\n",
"print(value.dtype)\n",
"prediction = model(value)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('float64')"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"images.numpy().astype('float64').dtype"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(100, 10)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prediction.data.numpy().shape"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-2.94313431e+00, -1.81460023e+00, -2.08448991e-01,\n",
" -2.29123878e+00, -2.91417217e+00, -7.30904102e-01,\n",
" -1.85286796e+00, -8.89607048e+00, 3.85826755e+00,\n",
" -5.70444298e+00],\n",
" [-5.27852488e+00, -9.87475681e+00, -3.23101878e+00,\n",
" -3.27192068e+00, 2.99915814e+00, -4.19678402e+00,\n",
" -6.34950256e+00, -4.51865005e+00, 7.18662143e-01,\n",
" 4.91613436e+00],\n",
" [ 5.31619835e+00, -4.94643354e+00, -7.60741353e-01,\n",
" -3.37821364e+00, -2.58448744e+00, -1.16258490e+00,\n",
" -2.44758511e+00, -2.42502451e+00, -2.97429585e+00,\n",
" -8.71329665e-01],\n",
" [-6.26879740e+00, 5.35215139e+00, -1.39423239e+00,\n",
" -3.57356954e+00, -1.04397392e+00, -6.51621342e+00,\n",
" -5.03530502e+00, -3.36044490e-01, -1.06999171e+00,\n",
" -5.35540390e+00],\n",
" [ 5.72689712e-01, -4.73341894e+00, 9.67390776e-01,\n",
" -7.11784005e-01, -2.87459540e+00, -3.85147333e-03,\n",
" -1.63910186e+00, -3.20800948e+00, -1.86211896e+00,\n",
" -5.54116011e+00],\n",
" [-2.28098822e+00, -5.37271118e+00, 1.50332046e+00,\n",
" 1.23391628e+00, -8.18955231e+00, -7.10122824e+00,\n",
" -9.54822731e+00, 2.04598665e+00, 2.21477568e-01,\n",
" -5.82763791e-01],\n",
" [-9.81631875e-02, -4.68611860e+00, 4.79472011e-01,\n",
" -5.89810753e+00, 4.02780437e+00, -2.99009085e+00,\n",
" 9.27805245e-01, -3.35206652e+00, -2.87583947e+00,\n",
" -3.54016685e+00],\n",
" [ 7.91082382e-02, -3.29304123e+00, -3.03544235e+00,\n",
" -4.35647297e+00, -2.58279252e+00, 5.38625240e+00,\n",
" -6.60099745e-01, -4.54817867e+00, 3.72485667e-01,\n",
" -5.45329714e+00],\n",
" [-3.40048730e-01, -2.23622179e+00, -1.75288630e+00,\n",
" -4.22681570e+00, -5.96652508e-01, 9.88374472e-01,\n",
" 9.12128639e+00, -6.91706181e+00, -5.71193886e+00,\n",
" -7.31577396e+00],\n",
" [-3.58501768e+00, -9.25465584e+00, -5.46614408e-01,\n",
" -2.43667293e+00, -6.48066759e+00, -3.89760876e+00,\n",
" -1.38017788e+01, 9.51254082e+00, -2.95755482e+00,\n",
" 2.09405303e+00],\n",
" [-1.75172710e+00, -2.98126078e+00, -6.65290546e+00,\n",
" -2.85864210e+00, -5.55760241e+00, 2.44312382e+00,\n",
" -3.61811829e+00, -4.92458248e+00, 4.85441971e+00,\n",
" -3.99161577e+00],\n",
" [ 7.10333920e+00, -5.87376070e+00, 7.35742152e-01,\n",
" -4.57388163e+00, -5.62757587e+00, -8.69627833e-01,\n",
" -3.81240129e+00, -1.22680414e+00, -5.86168003e+00,\n",
" -3.06198215e+00],\n",
" [-4.35662460e+00, 7.14129639e+00, 4.32708621e-01,\n",
" -1.88450491e+00, -3.92650890e+00, -4.76905346e+00,\n",
" -4.78737926e+00, -2.06425619e+00, -1.43031192e+00,\n",
" -8.51775265e+00],\n",
" [-5.01732492e+00, -4.75002670e+00, 7.48586702e+00,\n",
" 8.71298671e-01, -6.77810001e+00, -4.56456995e+00,\n",
" -5.64565229e+00, -2.32957077e+00, 1.09462869e+00,\n",
" -5.92619801e+00],\n",
" [-3.25442362e+00, -3.75993347e+00, -3.17320156e+00,\n",
" 4.08569860e+00, -8.89118862e+00, -1.56907606e+00,\n",
" -1.29745827e+01, -4.13903046e+00, 1.30396795e+00,\n",
" 4.25274998e-01],\n",
" [-3.55768895e+00, -2.09418583e+00, -1.02781892e+00,\n",
" -6.95499659e+00, 6.68295813e+00, -2.23202968e+00,\n",
" -2.39104450e-01, -4.58472347e+00, -4.44918251e+00,\n",
" -5.70568848e+00],\n",
" [-3.22387075e+00, -1.00904818e+01, -1.66163158e+00,\n",
" -4.08348942e+00, -3.42650390e+00, -3.48878241e+00,\n",
" -1.04407053e+01, 6.01407433e+00, -1.70194793e+00,\n",
" 3.92319489e+00],\n",
" [-2.36084223e+00, -7.12867594e+00, 2.79588461e-01,\n",
" -3.45690346e+00, -3.48034048e+00, -2.39581585e+00,\n",
" -2.31899548e+00, -7.42060089e+00, 8.48381615e+00,\n",
" -6.04322863e+00],\n",
" [-4.19490576e+00, -1.02733526e+01, -1.44479012e+00,\n",
" -4.82172012e+00, 3.25319171e-02, -3.11602783e+00,\n",
" -6.72438049e+00, 3.06269407e-01, -1.48180246e+00,\n",
" 4.33638811e+00],\n",
" [-2.75081682e+00, -7.28020811e+00, -2.98303461e+00,\n",
" -2.76366043e+00, -4.09473085e+00, -3.54056692e+00,\n",
" -1.37984486e+01, 8.48108864e+00, -4.28329992e+00,\n",
" 3.44067287e+00],\n",
" [-1.47935200e+00, -4.31553364e+00, -1.80156577e+00,\n",
" -3.10084033e+00, -7.65861988e+00, -2.25040245e+00,\n",
" -5.25622416e+00, -6.60806179e+00, 6.59777069e+00,\n",
" -3.74126458e+00],\n",
" [-3.09584522e+00, -4.03994560e+00, 1.39546502e+00,\n",
" -1.71985483e+00, -3.30831736e-01, -9.78809655e-01,\n",
" 5.82206869e+00, -6.38060808e+00, -4.40905428e+00,\n",
" -5.61296463e+00],\n",
" [-4.33585930e+00, -2.28015685e+00, -6.50762844e+00,\n",
" -7.50386524e+00, 3.49900460e+00, -1.83042765e-01,\n",
" -4.88598967e+00, -2.54932785e+00, -1.70414543e+00,\n",
" -1.71335429e-01],\n",
" [-6.33788157e+00, 8.20950222e+00, -1.09110951e+00,\n",
" -4.89209270e+00, -2.66071391e+00, -5.96939754e+00,\n",
" -4.53781509e+00, -1.31869841e+00, -2.04262447e+00,\n",
" -7.43765354e+00],\n",
" [-1.98968935e+00, -1.08618965e+01, -1.73519349e+00,\n",
" -3.69034433e+00, 7.34611869e-01, -3.01101327e+00,\n",
" -8.48536491e+00, 1.87765157e+00, -1.50498271e+00,\n",
" 4.90581846e+00],\n",
" [-3.59241199e+00, -2.45339823e+00, 2.14817572e+00,\n",
" 9.27214742e-01, -8.07486057e+00, -2.06884146e+00,\n",
" -7.01898956e+00, -1.67429686e+00, 1.15502942e+00,\n",
" -4.88388157e+00],\n",
" [-5.33111954e+00, -1.68560290e+00, 4.15555894e-01,\n",
" -1.83269072e+00, -2.49122381e+00, -1.24994302e+00,\n",
" -1.90831006e+00, -3.09089851e+00, 3.47727752e+00,\n",
" -5.99055147e+00],\n",
" [-2.52312398e+00, -6.40426540e+00, 2.39003658e+00,\n",
" -5.73811722e+00, 4.70678949e+00, -5.00276566e+00,\n",
" 2.86472030e-04, -3.41072738e-01, -4.31854200e+00,\n",
" -1.90522814e+00],\n",
" [-3.09813952e+00, -7.31600761e+00, 1.46198285e+00,\n",
" -6.91232681e+00, 6.62286377e+00, -3.52320147e+00,\n",
" -2.95166254e+00, -1.69830823e+00, -3.61260891e+00,\n",
" -3.76430154e-03],\n",
" [-3.90675402e+00, -8.58440208e+00, 3.84091437e-01,\n",
" -9.11678314e-01, -8.35619164e+00, -5.30601501e+00,\n",
" -1.34841938e+01, 7.37753201e+00, -1.02802634e+00,\n",
" 3.48227167e+00],\n",
" [ 8.09841061e+00, -5.84553242e+00, -4.36034203e-02,\n",
" -3.31476593e+00, -7.94556332e+00, -1.81487560e+00,\n",
" -1.08142841e+00, -4.74964380e+00, -3.62896776e+00,\n",
" -3.67098570e+00],\n",
" [-4.82872438e+00, 6.24776268e+00, -2.81209302e+00,\n",
" -3.99583030e+00, -4.35030222e+00, -5.47072029e+00,\n",
" -5.74521732e+00, -6.83430016e-01, -2.22886491e+00,\n",
" -4.28679466e+00],\n",
" [-1.13239086e+00, -1.07505608e+01, -3.85221720e-01,\n",
" -4.16249514e+00, 1.11317813e-01, -5.01096535e+00,\n",
" -7.09929132e+00, 5.47274947e-01, -2.61468601e+00,\n",
" 5.91940689e+00],\n",
" [-4.76688623e+00, -3.39046216e+00, 8.61355019e+00,\n",
" -9.85053182e-02, -2.67433786e+00, -3.72860909e+00,\n",
" -2.70728278e+00, -5.08575344e+00, -2.89577341e+00,\n",
" -6.25328112e+00],\n",
" [-3.26012516e+00, -3.56679535e+00, -2.13104582e+00,\n",
" -4.59061265e-01, -5.79459000e+00, -1.60959554e+00,\n",
" -5.09219933e+00, -7.62273407e+00, 6.20947170e+00,\n",
" -3.95186377e+00],\n",
" [-1.31348062e+00, -5.19767284e+00, -1.56831324e-01,\n",
" -1.34070158e+00, -8.09649467e+00, -4.45510674e+00,\n",
" -1.38942327e+01, 8.03822708e+00, -4.33768272e+00,\n",
" 2.58261514e+00],\n",
" [-4.15833044e+00, -5.74338055e+00, -5.63697433e+00,\n",
" -1.24962544e+00, -8.88556576e+00, 2.62740111e+00,\n",
" -8.16130829e+00, -6.14461994e+00, 5.57290173e+00,\n",
" -2.04997277e+00],\n",
" [-3.57854271e+00, -4.63059044e+00, 7.85657692e+00,\n",
" 1.90798604e+00, -3.72001743e+00, -2.77965403e+00,\n",
" -2.73498774e+00, -5.69463062e+00, -3.98288202e+00,\n",
" -8.08887482e+00],\n",
" [ 5.72618544e-01, -2.57825613e+00, -1.72792041e+00,\n",
" -3.51139021e+00, -1.85640740e+00, 1.42014265e+00,\n",
" 8.85237503e+00, -6.99086475e+00, -6.19104099e+00,\n",
" -8.19126129e+00],\n",
" [ 8.68018246e+00, -7.40369701e+00, -2.29292154e+00,\n",
" -4.26178265e+00, -4.36462879e+00, -4.42296028e-01,\n",
" -1.77303386e+00, -1.92960644e+00, -5.18078184e+00,\n",
" -3.03363776e+00],\n",
" [ 8.04516435e-01, -2.99887037e+00, -7.78589845e-01,\n",
" -6.35569668e+00, 2.63802457e+00, -2.18808126e+00,\n",
" 3.06124830e+00, -7.12826371e-01, -6.37444162e+00,\n",
" -3.09541106e+00],\n",
" [-2.78797674e+00, -6.94107354e-01, -3.76091480e+00,\n",
" -1.08733892e-01, -4.78449726e+00, 2.34188890e+00,\n",
" 1.54788947e+00, -5.22505283e+00, -2.23338032e+00,\n",
" -4.30411434e+00],\n",
" [-6.92954063e+00, -3.42388296e+00, -4.97031927e+00,\n",
" 8.09408665e+00, -1.31485920e+01, 3.95852685e+00,\n",
" -1.46645641e+01, -1.01325397e+01, 4.84749079e-01,\n",
" -3.03775525e+00],\n",
" [-5.77645302e+00, -3.56791949e+00, -1.40874970e+00,\n",
" 2.98872280e+00, -1.05846815e+01, 1.69422913e+00,\n",
" -1.02357826e+01, -5.04028559e+00, 4.78248119e-01,\n",
" -4.10274601e+00],\n",
" [-7.54200649e+00, -6.78473139e+00, -4.01818991e+00,\n",
" 4.47627008e-01, -5.16251040e+00, 1.08258379e+00,\n",
" -1.01345644e+01, -3.04541993e+00, 3.06609201e+00,\n",
" 1.39591312e+00],\n",
" [-6.21183348e+00, -6.91482210e+00, -5.28124046e+00,\n",
" -5.25641012e+00, 2.62771082e+00, -7.80999720e-01,\n",
" -9.78050613e+00, -6.41314983e-01, -1.48106194e+00,\n",
" 5.20257568e+00],\n",
" [-4.19852495e+00, 5.86203766e+00, -3.27771401e+00,\n",
" -3.64987040e+00, -4.44587040e+00, -5.03493881e+00,\n",
" -5.90556049e+00, -1.82890546e+00, -1.87745023e+00,\n",
" -4.08415413e+00],\n",
" [ 4.50412154e-01, -5.91576576e+00, 2.40946472e-01,\n",
" -5.48061371e+00, 7.73100901e+00, -3.05289888e+00,\n",
" 8.29278767e-01, -1.09182882e+00, -9.85262775e+00,\n",
" -5.13459969e+00],\n",
" [ 7.03677177e+00, -5.27548599e+00, -8.36283922e-01,\n",
" -3.35303903e+00, -8.96849060e+00, 1.27885199e+00,\n",
" -3.88307619e+00, -1.76440930e+00, -4.93410468e-01,\n",
" -4.81816626e+00],\n",
" [ 2.10164666e-01, -2.73290563e+00, -9.83471498e-02,\n",
" -3.72992039e+00, 1.46693587e+00, -3.98113275e+00,\n",
" 8.09136391e+00, -2.84822083e+00, -7.46374559e+00,\n",
" -6.29052639e+00],\n",
" [-4.37877464e+00, 6.51550770e+00, -2.54614544e+00,\n",
" -3.77313089e+00, -2.84641075e+00, -4.26150846e+00,\n",
" -4.18192101e+00, -1.52866042e+00, -1.82835793e+00,\n",
" -5.50512218e+00],\n",
" [ 6.21710587e+00, -6.43846989e+00, -9.54680800e-01,\n",
" -3.34206319e+00, -5.43830872e+00, -1.28298807e+00,\n",
" -5.16852379e+00, 2.62163877e-02, -2.48312187e+00,\n",
" -1.57222879e+00],\n",
" [ 1.11582479e+01, -9.11719894e+00, 3.93074095e-01,\n",
" -6.16349459e+00, -7.34191084e+00, -1.21771574e+00,\n",
" -1.06905496e+00, -4.17815685e+00, -3.26719928e+00,\n",
" -5.55661106e+00],\n",
" [ 6.24382555e-01, -2.50315714e+00, -1.32084340e-01,\n",
" -4.99415159e+00, 9.88243341e-01, -1.59092569e+00,\n",
" 3.79477382e+00, -2.61957884e+00, -4.36313868e+00,\n",
" -3.75159168e+00],\n",
" [-2.41888553e-01, -7.18930769e+00, 7.67069864e+00,\n",
" -1.17135257e-01, -1.03810005e+01, -8.67830658e+00,\n",
" -6.22834063e+00, 9.77339983e-01, -1.33681512e+00,\n",
" -4.75062370e+00],\n",
" [-2.68741202e+00, 5.81236267e+00, 1.25405538e+00,\n",
" -3.28662777e+00, -4.22186947e+00, -5.25563431e+00,\n",
" -4.08522320e+00, -2.46859384e+00, -1.61857891e+00,\n",
" -8.00895596e+00],\n",
" [-5.18400049e+00, 7.41901207e+00, 6.58489108e-01,\n",
" -3.35169506e+00, -2.88367128e+00, -6.68690157e+00,\n",
" -5.40690804e+00, -4.45784926e-01, -2.28699875e+00,\n",
" -6.92785597e+00],\n",
" [-1.77721453e+00, -8.00125504e+00, 1.24888480e-01,\n",
" -1.47997165e+00, -6.09206104e+00, -4.90260983e+00,\n",
" -1.17006941e+01, 7.79809666e+00, -3.09374452e+00,\n",
" 2.66956711e+00],\n",
" [-5.29130554e+00, -1.53717768e+00, -7.01504350e-01,\n",
" 7.18017042e-01, -8.73529816e+00, -5.64459801e+00,\n",
" -1.26275997e+01, 6.27234268e+00, -2.77057409e+00,\n",
" -1.00017822e+00],\n",
" [-4.33235550e+00, -4.33209330e-01, -1.59399450e-01,\n",
" -8.48658442e-01, -4.34042692e+00, -4.76393986e+00,\n",
" -6.11099434e+00, -1.05563331e+00, 2.25537944e+00,\n",
" -2.38121748e+00],\n",
" [-6.10877466e+00, -7.61181355e+00, -1.87066281e+00,\n",
" -6.10769939e+00, 7.97212219e+00, -5.86680508e+00,\n",
" -4.92918396e+00, 1.71273947e-03, -3.98121977e+00,\n",
" 1.64842772e+00],\n",
" [-1.94559216e+00, -1.79278302e+00, 1.03016114e+00,\n",
" -4.36209917e+00, 5.04958749e-01, -1.91686034e+00,\n",
" 8.24356842e+00, -6.12818527e+00, -4.80130959e+00,\n",
" -7.60297966e+00],\n",
" [ 5.55045605e+00, -5.02859211e+00, -8.79769921e-01,\n",
" -1.37008476e+00, -7.79281998e+00, -2.99451381e-01,\n",
" -6.21967697e+00, -5.89890778e-01, -3.06062031e+00,\n",
" -1.72406542e+00],\n",
" [-1.33018994e+00, -9.49907684e+00, -5.35935163e-02,\n",
" 4.34355378e-01, -9.34434700e+00, -2.35429907e+00,\n",
" -1.26184874e+01, 7.76982164e+00, -3.26596594e+00,\n",
" 9.92987037e-01],\n",
" [ 6.22934723e+00, -5.23340511e+00, 6.14945412e-01,\n",
" -2.38412142e+00, -4.16401768e+00, -1.99598610e+00,\n",
" -6.43658400e-01, -6.30997133e+00, -2.06985307e+00,\n",
" -4.36144400e+00],\n",
" [-6.59722328e+00, -2.47505140e+00, -1.26668119e+00,\n",
" 7.06556988e+00, -5.15069532e+00, 1.30961788e+00,\n",
" -8.17326546e+00, -9.53056812e+00, -1.30305231e-01,\n",
" -3.92515922e+00],\n",
" [-1.00758219e+00, -2.31400847e+00, -5.64392984e-01,\n",
" -5.06872416e+00, 1.46974552e+00, -6.33509874e-01,\n",
" 7.41877127e+00, -2.42539763e+00, -8.05217457e+00,\n",
" -6.49750757e+00],\n",
" [-2.42883348e+00, -4.32216549e+00, -3.23260427e+00,\n",
" -1.33107811e-01, -7.12213898e+00, 5.69173694e-01,\n",
" -5.42469597e+00, -8.09215260e+00, 5.31586075e+00,\n",
" -2.59745455e+00],\n",
" [-3.88383985e+00, -4.58199167e+00, 1.53154266e+00,\n",
" 3.95189553e-01, -8.60041714e+00, -5.04962873e+00,\n",
" -1.34752960e+01, 9.81561279e+00, -3.49055672e+00,\n",
" -5.30974925e-01],\n",
" [-5.55807590e+00, 7.13408756e+00, -1.97236156e+00,\n",
" -4.12845612e+00, -2.06563425e+00, -5.16105556e+00,\n",
" -4.59559488e+00, -7.00516820e-01, -2.03985214e+00,\n",
" -5.80674982e+00],\n",
" [-7.74202251e+00, 6.83303058e-01, -4.17243576e+00,\n",
" 4.54120445e+00, -7.56257725e+00, 7.56401730e+00,\n",
" -3.17737961e+00, -8.74649048e+00, -4.01797676e+00,\n",
" -5.54178333e+00],\n",
" [-1.94957161e+00, -5.57329273e+00, 6.73393011e+00,\n",
" -4.60381508e-01, -4.35702658e+00, -5.61905670e+00,\n",
" -5.04878044e+00, -1.16630566e+00, -1.51967692e+00,\n",
" -5.04717875e+00],\n",
" [-6.19228888e+00, -4.36215115e+00, -1.55550504e+00,\n",
" -6.67865229e+00, 9.32039070e+00, -4.30677795e+00,\n",
" -3.91612124e+00, -2.58168817e+00, -4.25660133e+00,\n",
" -1.19800948e-01],\n",
" [-4.99262571e+00, -1.04954376e+01, -3.04880023e+00,\n",
" -4.99716187e+00, 1.70066953e+00, -4.47071743e+00,\n",
" -8.73496056e+00, -2.18293667e-02, 1.81334853e-01,\n",
" 5.78750610e+00],\n",
" [-2.80345392e+00, -4.72813892e+00, 4.39481020e-01,\n",
" -6.51210117e+00, 7.47464228e+00, -4.14801550e+00,\n",
" -2.00983143e+00, -2.08157349e+00, -4.23822260e+00,\n",
" -3.23492408e-01],\n",
" [-3.56214428e+00, -3.25979948e+00, 3.80091906e+00,\n",
" 2.54850864e+00, -7.31272936e+00, -3.85736132e+00,\n",
" -5.33244038e+00, -1.57806063e+00, 6.14682198e-01,\n",
" -5.21179199e+00],\n",
" [ 1.68500233e+00, -5.23059702e+00, 4.90493655e-01,\n",
" -4.55232048e+00, 1.92345440e+00, -5.41529465e+00,\n",
" 5.91364193e+00, -4.34861851e+00, -2.59588194e+00,\n",
" -4.20972347e+00],\n",
" [-3.76849008e+00, -4.20628738e+00, -2.41676593e+00,\n",
" -7.79835606e+00, 6.51037598e+00, -3.28591871e+00,\n",
" -3.44309497e+00, -1.11759353e+00, -3.44168925e+00,\n",
" -7.18272209e-01],\n",
" [-6.25707436e+00, 8.01829052e+00, -8.70854974e-01,\n",
" -2.17074561e+00, -2.51921511e+00, -7.34632683e+00,\n",
" -7.43379021e+00, 1.21018171e+00, -2.92493868e+00,\n",
" -6.24052525e+00],\n",
" [-1.54179430e+00, -1.00130758e+01, -1.60373425e+00,\n",
" -3.60287333e+00, -1.85504389e+00, -4.95446873e+00,\n",
" -1.13778887e+01, 8.44787598e+00, -3.98272920e+00,\n",
" 3.42892790e+00],\n",
" [-2.97215080e+00, -8.74941635e+00, 6.84806108e+00,\n",
" 4.38772535e+00, -1.40652447e+01, -5.17511463e+00,\n",
" -1.13551750e+01, 4.42862844e+00, -2.02880740e+00,\n",
" -8.97097778e+00],\n",
" [ 3.19274902e-01, -3.91941166e+00, -7.31602669e-01,\n",
" -2.14523554e+00, -8.81083727e-01, 1.23677731e+00,\n",
" 6.26848316e+00, -6.33044815e+00, -4.76621914e+00,\n",
" -8.46128654e+00],\n",
" [-2.28181696e+00, -2.24040413e+00, -2.13876939e+00,\n",
" 4.07040000e-01, -4.47353649e+00, 8.65564048e-02,\n",
" 2.86701989e+00, -3.77344036e+00, -5.10816383e+00,\n",
" -6.24403095e+00],\n",
" [ 1.22107010e+01, -8.57083988e+00, -2.32739854e+00,\n",
" -7.13466644e+00, -1.25466442e+01, -1.57235324e+00,\n",
" -4.42084503e+00, 2.85823584e-01, -4.89435768e+00,\n",
" -4.91631460e+00],\n",
" [-7.29047394e+00, 5.65978003e+00, -2.83476830e+00,\n",
" -2.03514028e+00, -5.54133081e+00, -7.70633698e+00,\n",
" -7.37060356e+00, -4.82241184e-01, -1.20001662e+00,\n",
" -3.77805758e+00],\n",
" [-5.97074461e+00, -6.08490705e-02, 9.34150600e+00,\n",
" 1.63726914e+00, -1.67970066e+01, -7.40469837e+00,\n",
" -1.28762379e+01, -4.04681563e-02, -1.38486892e-01,\n",
" -5.77418995e+00],\n",
" [-7.99248219e+00, -2.33761042e-01, 7.26457715e-01,\n",
" 4.49915123e+00, -1.60805607e+01, -1.95129883e+00,\n",
" -1.50286369e+01, -5.68364429e+00, 1.22412658e+00,\n",
" -1.95644116e+00],\n",
" [-8.16811371e+00, -8.41226864e+00, -3.68259668e+00,\n",
" -7.08047009e+00, 1.02391491e+01, -3.60192299e+00,\n",
" -4.16142845e+00, -3.06098413e+00, -2.04577184e+00,\n",
" 4.95172143e-01],\n",
" [-3.68254328e+00, -6.43464565e+00, -6.30004025e+00,\n",
" -2.46165895e+00, -6.77013969e+00, 1.13204098e+01,\n",
" -9.40734673e+00, -9.09241486e+00, 1.30543506e+00,\n",
" -4.77968550e+00],\n",
" [-1.94616914e-01, -6.30213881e+00, 3.80589724e-01,\n",
" -4.76102734e+00, -7.39952683e-01, -2.18785381e+00,\n",
" 1.20491581e+01, -1.19087849e+01, -3.75392103e+00,\n",
" -1.00205212e+01],\n",
" [-6.16003990e+00, -4.73755550e+00, 3.34087610e+00,\n",
" 2.48680592e+00, -1.33565502e+01, -9.27250481e+00,\n",
" -1.76171017e+01, 1.02097149e+01, -1.63418651e+00,\n",
" -2.16037130e+00],\n",
" [-6.05271399e-01, -8.01503754e+00, -1.52590179e+00,\n",
" -1.42467916e-02, -1.16728144e+01, -1.34578657e+00,\n",
" -7.63372087e+00, -7.25058126e+00, 7.69802189e+00,\n",
" -5.38569546e+00],\n",
" [-4.35134459e+00, -1.14437904e+01, -2.29104280e+00,\n",
" -5.45764494e+00, 4.25746632e+00, -5.00416851e+00,\n",
" -6.39387798e+00, -1.44824421e+00, -3.18393517e+00,\n",
" 5.76170349e+00],\n",
" [ 8.75680828e+00, -9.78944206e+00, -5.57915449e-01,\n",
" -5.78618336e+00, -6.20521069e+00, -1.44726467e+00,\n",
" -8.49293590e-01, -9.27665424e+00, 3.29888439e+00,\n",
" -4.34358311e+00],\n",
" [-6.03934145e+00, 6.55214643e+00, -1.52804327e+00,\n",
" -5.59449673e+00, -5.46073723e+00, -6.28587151e+00,\n",
" -3.62543583e+00, -3.13480830e+00, -1.56664109e+00,\n",
" -6.25409508e+00],\n",
" [-5.58640432e+00, 1.46627307e+00, 1.29633894e+01,\n",
" 2.12619233e+00, -1.86619873e+01, -6.41922855e+00,\n",
" -1.27642832e+01, -3.96311998e+00, -5.84557891e-01,\n",
" -9.79300308e+00],\n",
" [-6.58041859e+00, -1.57953396e-01, -4.17795897e-01,\n",
" 8.16717339e+00, -1.60149975e+01, 1.75709569e+00,\n",
" -1.52297287e+01, -5.96669102e+00, -4.64139891e+00,\n",
" -9.46015775e-01],\n",
" [-7.29091740e+00, -9.98701859e+00, -4.99693775e+00,\n",
" -6.66315222e+00, 8.48139668e+00, -5.61966324e+00,\n",
" -4.35089779e+00, -1.19920588e+00, 5.65500855e-01,\n",
" 1.40877223e+00],\n",
" [-3.93445063e+00, -3.84826946e+00, -6.23737240e+00,\n",
" -6.25728607e+00, -3.60181952e+00, 1.01498127e+01,\n",
" -2.82543993e+00, -7.06774330e+00, 1.51875269e+00,\n",
" -8.45338631e+00],\n",
" [-4.64281917e-01, -6.07886744e+00, -5.69949031e-01,\n",
" -4.25031662e+00, 5.37574291e-01, -5.09637237e-01,\n",
" 1.14183540e+01, -9.18782711e+00, -6.76804829e+00,\n",
" -9.76811504e+00]], dtype=float32)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32\n",
"torch.float32\n",
"float32\n",
"torch.float32\n",
"float32\n",
"torch.float32\n"
]
}
],
"source": [
"prediction.data.numpy()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"inp = gradio.inputs.Sketchpad(flatten=True, scale=1/255, dtype='float32')\n",
"io = gradio.Interface(inputs=inp, outputs=\"label\", model_type=\"pytorch\", model=model)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No validation samples for this interface... skipping validation.\n",
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7874/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"1000\"\n",
" height=\"500\"\n",
" src=\"http://localhost:7874/interface.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x14509666898>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<gradio.networking.serve_files_in_background.<locals>.HTTPServer at 0x1450966be48>,\n",
" 'http://localhost:7874/',\n",
" None)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"float32\n",
"torch.float32\n"
]
}
],
"source": [
"io.launch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

211
Test Sklearn.ipynb Normal file
View File

@ -0,0 +1,211 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from sklearn import datasets, svm\n",
"import gradio\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# The digits dataset\n",
"digits = datasets.load_digits()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
" decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',\n",
" max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# To apply a classifier on this data, we need to flatten the image, to\n",
"# turn the data in a (samples, feature) matrix:\n",
"n_samples = len(digits.images)\n",
"data = digits.images.reshape((n_samples, -1))\n",
"\n",
"# Create a classifier: a support vector classifier\n",
"classifier = svm.SVC(gamma=0.001)\n",
"\n",
"# We learn the digits on the first half of the digits\n",
"classifier.fit(data[:n_samples // 2], digits.target[:n_samples // 2])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16.0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.max()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAB4CAYAAADbsbjHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAACUZJREFUeJzt3V2MXVUZxvHnkYrEFDptlAsQMq1cYIy2aQkJ0UgbaYJB7RClJkJiMdIm3thoSHuBBJTENkEtmmgGvxqDGlovaCAx2BpahQjS6jQRjZq2E6x8JFCmfDVo7evFPpUJlNnrTPc55z27/1/SZE7nPXuteTvznD377NXliBAAIK+3DXoCAICZEdQAkBxBDQDJEdQAkBxBDQDJEdQAkNxQBrXts2y/bPviJmtBb3uJ3vZO23vbl6DuNOXknxO2j017fH23x4uI/0bE3Ih4ssnaJti+2fYzto/a/qHts3s83hnRW9uLbf/a9vO2j/d6vM6YZ0pvP2/7j7ZftH3Y9jdsn9XjMc+U3l5v+2+dPHjW9k9sz+36OP1e8GJ7UtIXImLXDDVzIqIvP4xNsn2NpB9JWiHpWUk7JO2JiFv6NP6k2tvb90m6QtKUpG0RMafP40+qvb39oqT9kh6XdL6kByTdExF39mn8SbW3txdLejUinrN9rqQfSHoqIr7czXFSXPqwfYfte23/wvZLkm6wfYXtR21P2X7a9ndsv71TP8d22B7tPL6n8/lf2X7J9u9tL+y2tvP5j9n+e+cV8Lu2H7G9pvBL+ZykuyPirxFxRNIdkkqf2xNt6W2npz+W9JcG23NaWtTb70XEIxHx74g4LOnnkj7UXKe616LePhkRz037qxOSLum2HymCuuNaVd8g8yTdK+m4pC9Jepeqb5qrJa2b4fmflfRVSQskPSnp693W2j5f0jZJN3fGPSTp8pNPsr2w801ywVsc9/2qzkxO2i/pQtvzZphLP7Sht1m1sbcfkfREYW0vtaK3tq+0fVTSi5I+KWnLDPM4pUxB/XBE3B8RJyLiWEQ8HhGPRcTxiDgo6W5JV87w/F9GxN6I+I+kn0laMovaj0uaiIgdnc99W9L/Xw0j4lBEjETEU29x3LmSjk57fPLjc2eYSz+0obdZtaq3tm+S9EFJ36qr7YNW9DYi9kTEPEkXSbpT1QtBV/p6na/GP6c/sH2ppG9KWibpnarm+tgMz39m2sevqgrNbmsvmD6PiAjbh2tn/rqXJZ037fF50/5+kNrQ26xa01vbn1J1JvnRzqW7QWtNbzvPPWx7l6rfEi6vq58u0xn1G9/VHJf0Z0mXRMR5km6V5B7P4WlJ7zn5wLYlXdjF85+QtHja48WS/hURU81Mb9ba0NusWtFbV2+Ef1/SNRGR4bKH1JLevsEcSe/t9kmZgvqNzlV16eAVV+/4z3QtqikPSFpq+xO256i6HvbuLp7/U0k32b7U9gJJt0ja2vw0T9vQ9daVcySd3Xl8jnt86+MsDWNvV6r63r02Ivb1aI5NGMbe3mD7os7Ho6p+Y/lNt5PIHNRfUXUXxUuqXknv7fWAEfGspM+ouj73vKpXvj9Jek2SbC9ydZ/nKd84iIgHVF3D+q2kSUn/kPS1Xs97Foaut536Y6reoD2r83GaO0CmGcbe3qrqDbsH/fq9zPf3et6zMIy9/YCkR22/IulhVb91d/0C0/f7qIeJq5v+n5L06Yj43aDn0yb0tnfobe8MqreZz6gHwvbVtufZfoeq23WOS/rDgKfVCvS2d+ht72ToLUH9Zh+WdFDVLThXSxqLiNcGO6XWoLe9Q297Z+C95dIHACTHGTUAJEdQA0ByvVqZ2Mj1lO3bt9fWbNiwobZm5cqVReNt2rSptmb+/PlFxyow2xv1+3atavny5bU1U1Nla3luv/322ppVq1YVHatA+t7u3r27tmZsbKzoWEuWzLQyuny8QqezwKSR/m7evLm2ZuPGjbU1CxcurK2RpH376m8t73UucEYNAMkR1ACQHEENAMkR1ACQHEENAMkR1ACQHEENAMkR1ACQXKatuN6kZDHLoUOHamteeOGFovEWLFhQW7Nt27bamuuuu65ovOxGRkZqa/bs2VN0rIceeqi2psEFLwM1MTFRW7NixYramnnzyvZEnpycLKobBiULVUp+BsfHx2tr1q0r+2+hSxa8XHXVVUXHmi3OqAEgOYIaAJIjqAEgOYIaAJIjqAEgOYIaAJIjqAEgOYIaAJIb2IKXkpvISxazHDhwoLZm0aJFRXMq2QmmZN7DsOClZFFGg7uCFO1C0hb33Xdfbc3ixYtra0p3eCnZPWdYrF27tramZCHcsmXLamtKd3jp9WKWEpxRA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJDewBS8lu64sXbq0tqZ0MUuJkpvkh8GWLVtqa2677bbamqNHjzYwm8ry5csbO1Z269evr60ZHR1t5DhSe3bGkcp+ng8ePFhbU7JYrnQhS0lWzZ8/v+hYs8UZNQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHKpF7yU7LjSpAw3tjehZKHEmjVramua/FqnpqYaO9YglXwdJQuOSnaBKbV169bGjjUMShbFHDlypLamdMFLSd2uXbtqa07n54kzagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIbmArE0tW6ezbt6+RsUpWHErS3r17a2tWr159utM5I01MTNTWLFmypA8zOT0lW5jdddddjYxVunpxZGSkkfHapCRfSlYTStK6detqazZv3lxbs2nTpqLxToUzagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQGtuClZDudkgUo27dvb6Sm1IYNGxo7FoZPyRZmu3fvrq3Zv39/bc3Y2FjBjKRVq1bV1tx4442NHCeDjRs31taUbJ9VuhBu586dtTW9XgjHGTUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByBDUAJEdQA0ByqRe8lOyaULIA5bLLLiuaU1M7ygyDkl1BShZA7Nixo2i8kkUgJYtJBq1kF5qS3WxKakp2k5HK/g1GR0dra4ZlwUvJ7i1r165tbLySxSzj4+ONjXcqnFEDQHIENQAkR1ADQHIENQAkR1ADQHIENQAkR1ADQHIENQAk54gY9BwAADPgjBoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkvsf2PN/nyaodHgAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"images_and_labels = list(zip(digits.images, digits.target))\n",
"for index, (image, label) in enumerate(images_and_labels[:4]):\n",
" plt.subplot(2, 4, index + 1)\n",
" plt.axis('off')\n",
" plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n",
" plt.title('Training: %i' % label)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"classifier.predict()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"expected = digits.target[n_samples // 2:]\n",
"predicted = classifier.predict(data[n_samples // 2:])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"inp = gradio.inputs.Sketchpad(shape=(8, 8), flatten=True, scale=16/255, invert_colors=False)\n",
"io = gradio.Interface(inputs=inp, outputs=\"label\", model_type=\"sklearn\", model=classifier)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No validation samples for this interface... skipping validation.\n",
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7865/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"1000\"\n",
" height=\"500\"\n",
" src=\"http://localhost:7865/interface.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x2a051defdd8>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<gradio.networking.serve_files_in_background.<locals>.HTTPServer at 0x2a051e271d0>,\n",
" 'http://localhost:7865/',\n",
" None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"io.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6 (tensorflow)",
"language": "python",
"name": "tensorflow"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

196
Test Tensorflow.ipynb Normal file
View File

@ -0,0 +1,196 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import tensorflow as tf\n",
"import gradio"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"n_classes = 10\n",
"(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
"x_train, x_test = x_train.reshape(-1, 784) / 255.0, x_test.reshape(-1, 784) / 255.0\n",
"y_train = tf.keras.utils.to_categorical(y_train, n_classes).astype(float)\n",
"y_test = tf.keras.utils.to_categorical(y_test, n_classes).astype(float)\n",
"\n",
"learning_rate = 0.5\n",
"epochs = 5\n",
"batch_size = 100"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"x = tf.placeholder(tf.float32, [None, 784], name=\"x\")\n",
"y = tf.placeholder(tf.float32, [None, 10], name=\"y\")\n",
"\n",
"W1 = tf.Variable(tf.random_normal([784, 300], stddev=0.03), name='W1')\n",
"b1 = tf.Variable(tf.random_normal([300]), name='b1')\n",
"W2 = tf.Variable(tf.random_normal([300, 10], stddev=0.03), name='W2')\n",
"hidden_out = tf.add(tf.matmul(x, W1), b1)\n",
"hidden_out = tf.nn.relu(hidden_out)\n",
"y_ = tf.matmul(hidden_out, W2)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"probs = tf.nn.softmax(y_)\n",
"cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_, labels=y))\n",
"optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cross_entropy)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"init_op = tf.global_variables_initializer()\n",
"correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
"accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1 cost = 0.317\n",
"Epoch: 2 cost = 0.123\n",
"Epoch: 3 cost = 0.086\n",
"Epoch: 4 cost = 0.066\n",
"Epoch: 5 cost = 0.052\n"
]
}
],
"source": [
"sess = tf.Session()\n",
"sess.run(init_op)\n",
"total_batch = int(len(y_train) / batch_size)\n",
"for epoch in range(epochs):\n",
" avg_cost = 0\n",
" for start, end in zip(range(0, len(y_train), batch_size), range(batch_size, len(y_train)+1, batch_size)): \n",
" batch_x = x_train[start: end]\n",
" batch_y = y_train[start: end]\n",
" _, c = sess.run([optimizer, cross_entropy], feed_dict={x: batch_x, y: batch_y})\n",
" avg_cost += c / total_batch"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def predict(inp):\n",
" return sess.run(probs, feed_dict={x:inp})"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"inp = gradio.inputs.Sketchpad(flatten=True)\n",
"io = gradio.Interface(inputs=inp, outputs=\"label\", model_type=\"pyfunc\", model=predict)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No validation samples for this interface... skipping validation.\n",
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model is running locally at: http://localhost:7868/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"1000\"\n",
" height=\"500\"\n",
" src=\"http://localhost:7868/interface.html\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x2a126711048>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"(<gradio.networking.serve_files_in_background.<locals>.HTTPServer at 0x2a1266b6b38>,\n",
" 'http://localhost:7868/',\n",
" None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"io.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6 (tensorflow)",
"language": "python",
"name": "tensorflow"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long