summaryrefslogtreecommitdiff
path: root/nn_tutorial.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'nn_tutorial.ipynb')
-rw-r--r--nn_tutorial.ipynb548
1 files changed, 274 insertions, 274 deletions
diff --git a/nn_tutorial.ipynb b/nn_tutorial.ipynb
index 9cd501a..05f6f72 100644
--- a/nn_tutorial.ipynb
+++ b/nn_tutorial.ipynb
@@ -2,21 +2,24 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"\n",
"What is `torch.nn` *really*?\n",
@@ -27,7 +30,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"We recommend running this tutorial as a notebook, not a script. To download the notebook (.ipynb) file,\n",
"click the link at the top of the page.\n",
@@ -67,14 +72,14 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 1,
"metadata": {
"collapsed": false,
"jupyter": {
- "outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"from pathlib import Path\n",
"import requests\n",
@@ -94,7 +99,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"This dataset is in numpy array format, and has been stored using pickle,\n",
"a python-specific format for serializing data.\n",
@@ -103,14 +110,14 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
- "outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"import pickle\n",
"import gzip\n",
@@ -121,8 +128,10 @@
},
{
"cell_type": "code",
- "execution_count": 11,
- "metadata": {},
+ "execution_count": 3,
+ "metadata": {
+ "collapsed": false
+ },
"outputs": [
{
"name": "stdout",
@@ -144,7 +153,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Each image is 28 x 28, and is being stored as a flattened row of length\n",
"784 (=28x28). Let's take a look at one; we need to reshape it to 2d\n",
@@ -154,11 +165,10 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
- "outputs_hidden": false
}
},
"outputs": [
@@ -171,15 +181,20 @@
},
{
"data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAN8klEQVR4nO3df6jVdZ7H8ddrbfojxzI39iZOrWOEUdE6i9nSyjYRTj8o7FYMIzQ0JDl/JDSwyIb7xxSLIVu6rBSDDtXYMus0UJHFMNVm5S6BdDMrs21qoxjlphtmmv1a9b1/3K9xp+75nOs53/PD+34+4HDO+b7P93zffPHl99f53o8jQgAmvj/rdQMAuoOwA0kQdiAJwg4kQdiBJE7o5sJsc+of6LCI8FjT29qy277C9lu237F9ezvfBaCz3Op1dtuTJP1B0gJJOyW9JGlRROwozMOWHeiwTmzZ50l6JyLejYgvJf1G0sI2vg9AB7UT9hmS/jjq/c5q2p+wvcT2kO2hNpYFoE0dP0EXEeskrZPYjQd6qZ0t+y5JZ4x6/51qGoA+1E7YX5J0tu3v2j5R0o8kbaynLQB1a3k3PiIO2V4q6SlJkyQ9EBFv1NYZgFq1fOmtpYVxzA50XEd+VAPg+EHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEi0P2Yzjw6RJk4r1U045paPLX7p0acPaSSedVJx39uzZxfqtt95arN9zzz0Na4sWLSrO+/nnnxfrK1euLNbvvPPOYr0X2gq77fckHZB0WNKhiJhbR1MA6lfHlv3SiPiwhu8B0EEcswNJtBv2kPS07ZdtLxnrA7aX2B6yPdTmsgC0od3d+PkRscv2X0h6xvZ/R8Tm0R+IiHWS1kmS7WhzeQBa1NaWPSJ2Vc97JD0maV4dTQGoX8thtz3Z9pSjryX9QNL2uhoDUK92duMHJD1m++j3/HtE/L6WriaYM888s1g/8cQTi/WLL764WJ8/f37D2tSpU4vzXn/99cV6L+3cubNYX7NmTbE+ODjYsHbgwIHivK+++mqx/sILLxTr/ajlsEfEu5L+qsZeAHQQl96AJAg7kARhB5Ig7EAShB1IwhHd+1HbRP0F3Zw5c4r1TZs2Feudvs20Xx05cqRYv/nmm4v1Tz75pOVlDw8PF+sfffRRsf7WW2+1vOxOiwiPNZ0tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX2GkybNq1Y37JlS7E+a9asOtupVbPe9+3bV6xfeumlDWtffvllcd6svz9oF9fZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJhmyuwd69e4v1ZcuWFetXX311sf7KK68U683+pHLJtm3bivUFCxYU6wcPHizWzzvvvIa12267rTgv6sWWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4H72PnDyyScX682GF167dm3D2uLFi4vz3njjjcX6hg0binX0n5bvZ7f9gO09trePmjbN9jO2366eT62zWQD1G89u/K8kXfG1abdLejYizpb0bPUeQB9rGvaI2Czp678HXShpffV6vaRr620LQN1a/W38QEQcHSzrA0kDjT5oe4mkJS0uB0BN2r4RJiKidOItItZJWidxgg7opVYvve22PV2Squc99bUEoBNaDftGSTdVr2+S9Hg97QDolKa78bY3SPq+pNNs75T0c0krJf3W9mJJ70v6YSebnOj279/f1vwff/xxy/PecsstxfrDDz9crDcbYx39o2nYI2JRg9JlNfcCoIP4uSyQBGEHkiDsQBKEHUiCsANJcIvrBDB58uSGtSeeeKI47yWXXFKsX3nllcX6008/Xayj+xiyGUiOsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BHfWWWcV61u3bi3W9+3bV6w/99xzxfrQ0FDD2n333Vect5v/NicSrrMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJcZ09ucHCwWH/wwQeL9SlTprS87OXLlxfrDz30ULE+PDxcrGfFdXYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSILr7Cg6//zzi/XVq1cX65dd1vpgv2vXri3WV6xYUazv2rWr5WUfz1q+zm77Adt7bG8fNe0O27tsb6seV9XZLID6jWc3/leSrhhj+r9ExJzq8bt62wJQt6Zhj4jNkvZ2oRcAHdTOCbqltl+rdvNPbfQh20tsD9lu/MfIAHRcq2H/haSzJM2RNCxpVaMPRsS6iJgbEXNbXBaAGrQU9ojYHRGHI+KIpF9KmldvWwDq1lLYbU8f9XZQ0vZGnwXQH5peZ7e9QdL3JZ0mabekn1fv50gKSe9J+mlENL25mOvsE8/UqVOL9WuuuaZhrdm98vaYl4u/smnTpmJ9wYIFxfpE1eg6+wnjmHHRGJPvb7sjAF3Fz2WBJAg7kARhB5Ig7EAShB1Igltc0TNffPFFsX7CCeWLRYcOHSrWL7/88oa1559/vjjv8Yw/JQ0kR9iBJAg7kARhB5Ig7EAShB1IgrADSTS96w25XXDBBcX6DTfcUKxfeOGFDWvNrqM3s2PHjmJ98+bNbX3/RMOWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BDd79uxifenSpcX6ddddV6yffvrpx9zTeB0+fLhYHx4u//XyI0eO1NnOcY8tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX240Cza9mLFo010O6IZtfRZ86c2UpLtRgaGirWV6xYUaxv3LixznYmvKZbdttn2H7O9g7bb9i+rZo+zfYztt+unk/tfLsAWjWe3fhDkv4+Is6V9DeSbrV9rqTbJT0bEWdLerZ6D6BPNQ17RAxHxNbq9QFJb0qaIWmhpPXVx9ZLurZDPQKowTEds9ueKel7krZIGoiIoz9O/kDSQIN5lkha0kaPAGow7rPxtr8t6RFJP4uI/aNrMTI65JiDNkbEuoiYGxFz2+oUQFvGFXbb39JI0H8dEY9Wk3fbnl7Vp0va05kWAdSh6W68bUu6X9KbEbF6VGmjpJskrayeH+9IhxPAwMCYRzhfOffcc4v1e++9t1g/55xzjrmnumzZsqVYv/vuuxvWHn+8/E+GW1TrNZ5j9r+V9GNJr9veVk1brpGQ/9b2YknvS/phRzoEUIumYY+I/5I05uDuki6rtx0AncLPZYEkCDuQBGEHkiDsQBKEHUiCW1zHadq0aQ1ra9euLc47Z86cYn3WrFmttFSLF198sVhftWpVsf7UU08V65999tkx94TOYMsOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkuc5+0UUXFevLli0r1ufNm9ewNmPGjJZ6qsunn37asLZmzZrivHfddVexfvDgwZZ6Qv9hyw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaS5zj44ONhWvR07duwo1p988sli/dChQ8V66Z7zffv2FedFHmzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0T5A/YZkh6SNCApJK2LiH+1fYekWyT9b/XR5RHxuybfVV4YgLZFxJijLo8n7NMlTY+IrbanSHpZ0rUaGY/9k4i4Z7xNEHag8xqFfTzjsw9LGq5eH7D9pqTe/mkWAMfsmI7Zbc+U9D1JW6pJS22/ZvsB26c2mGeJ7SHbQ+21CqAdTXfjv/qg/W1JL0haERGP2h6Q9KFGjuP/SSO7+jc3+Q5244EOa/mYXZJsf0vSk5KeiojVY9RnSnoyIs5v8j2EHeiwRmFvuhtv25Lul/Tm6KBXJ+6OGpS0vd0mAXTOeM7Gz5f0n5Jel3Skmrxc0iJJczSyG/+epJ9WJ/NK38WWHeiwtnbj60LYgc5reTcewMRA2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLbQzZ/KOn9Ue9Pq6b1o37trV/7kuitVXX29peNCl29n/0bC7eHImJuzxoo6Nfe+rUvid5a1a3e2I0HkiDsQBK9Dvu6Hi+/pF9769e+JHprVVd66+kxO4Du6fWWHUCXEHYgiZ6E3fYVtt+y/Y7t23vRQyO237P9uu1tvR6frhpDb4/t7aOmTbP9jO23q+cxx9jrUW932N5Vrbtttq/qUW9n2H7O9g7bb9i+rZre03VX6Ksr663rx+y2J0n6g6QFknZKeknSoojY0dVGGrD9nqS5EdHzH2DY/jtJn0h66OjQWrb/WdLeiFhZ/Ud5akT8Q5/0doeOcRjvDvXWaJjxn6iH667O4c9b0Yst+zxJ70TEuxHxpaTfSFrYgz76XkRslrT3a5MXSlpfvV6vkX8sXdegt74QEcMRsbV6fUDS0WHGe7ruCn11RS/CPkPSH0e936n+Gu89JD1t+2XbS3rdzBgGRg2z9YGkgV42M4amw3h309eGGe+bddfK8Oft4gTdN82PiL+WdKWkW6vd1b4UI8dg/XTt9BeSztLIGIDDklb1splqmPFHJP0sIvaPrvVy3Y3RV1fWWy/CvkvSGaPef6ea1hciYlf1vEfSYxo57Ognu4+OoFs97+lxP1+JiN0RcTgijkj6pXq47qphxh+R9OuIeLSa3PN1N1Zf3VpvvQj7S5LOtv1d2ydK+pGkjT3o4xtsT65OnMj2ZEk/UP8NRb1R0k3V65skPd7DXv5Evwzj3WiYcfV43fV8+POI6PpD0lUaOSP/P5L+sRc9NOhrlqRXq8cbve5N0gaN7Nb9n0bObSyW9OeSnpX0tqT/kDStj3r7N40M7f2aRoI1vUe9zdfILvprkrZVj6t6ve4KfXVlvfFzWSAJTtABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL/DyJ7caZa7LphAAAAAElFTkSuQmCC\n",
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAz0AAAM2CAYAAADcr+22AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO3de6yldX3v8c/3MFWORAb1VEkvFuFUSLFKQUWhkVtQaasVhRNNbDktmraBKFZNG6sttLWxab0hHiW1hSOeFBtNtbZUTQQExdowFonBGxVKaFELlPvFDvM7f+w17XTce5g9s+ZZm+9+vZKdZ/azLt9fdOVh3vOs9awaYwQAAKCr/7boBQAAAOxJogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoLUNi17AnlBVNyTZN8mNC14KAAAwHwckuWuM8ZTVPrDGGPNfzoJV1W1JHr/odQAAAPM1xqjVPqbr29tuXPQCAACAtaFr9AAAACQRPQAAQHOiBwAAaE30AAAArYkeAACgtYVGT1X9SFX9WVX9S1U9WFU3VtW7qupxi1wXAADQx8K+nLSqDkpyVZInJvl4kq8leXaS1yZ5YVUdPca4bVHrAwAAeljkmZ7/k6Xgec0Y4yVjjN8cYxyf5J1JDk7y1gWuDQAAaKLGGNMPrTowyT9m6UtEDxpjbNnmtscmuSVJJXniGOPeXXj+TUkOn89qAQCAtWKMUat9zKLO9Bw/23562+BJkjHG3Uk+n+QxSZ4z9cIAAIBeFvWZnoNn22+scPs3kzw/yVOTfGalJ5md0VnOIbu+NAAAoJNFnenZONveucLtW/fvN8FaAACAxhZ29baHsfV9ejv8wNEY44hlH+wzPQAAwMyizvRsPZOzcYXb993ufgAAALtkUdHz9dn2qSvc/uOz7Uqf+QEAANgpi4qey2bb51fVf1nD7JLVRye5P8nfTb0wAACgl4VEzxjjH5N8OskBSc7Y7uZzkuyT5IO78h09AAAA21rIl5MmSVUdlOSqJE9M8vEkX01yZJLjsvS2tqPGGLft4nO7kAEAADT0SPpy0q1ne56Z5MIsxc7rkxyU5Nwkz93V4AEAANjWws707EnO9AAAQE+PqDM9AAAAUxA9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWtuw6AUAMI299tpr0nkbN26cdF5HZ5555qTzHvOYx0w26+CDD55s1hlnnDHZrD/+4z+ebNYrXvGKyWYlyQMPPDDZrLe97W2TzTrnnHMmm8XiONMDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArS0seqrqxqoaK/x8e1HrAgAAeln0l5PemeRdy+y/Z+qFAAAAPS06eu4YY5y94DUAAACN+UwPAADQ2qLP9Dy6ql6Z5MlJ7k1ybZIrxhgPLXZZAABAF4uOnv2TXLTdvhuq6pfGGJ99uAdX1aYVbjpkt1cGAAC0sMi3t12Q5IQshc8+SX4yyflJDkjyt1X1jMUtDQAA6GJhZ3rGGOdst+srSX61qu5J8vokZyc5+WGe44jl9s/OAB0+h2UCAACPcGvxQgbvn22ft9BVAAAALazF6PnubLvPQlcBAAC0sBaj57mz7bcWugoAAKCFhURPVR1aVY9fZv+PJTlv9uuHpl0VAADQ0aIuZHBqkt+sqsuS3JDk7iQHJfnZJHsnuSTJHy9obQAAQCOLip7Lkhyc5Key9Ha2fZLckeRzWfrenovGGGNBawMAABpZSPTMvnj0Yb98FAAAYHetxQsZAAAAzI3oAQAAWhM9AABAa6IHAABobVFXbwMaePKTnzzZrEc96lGTzTrqqKMmm/XTP/3Tk83ab7/9JpuVJC972csmnccjy8033zzZrHPPPXeyWSeffPJks+6+++7JZiXJl7/85clmffazrnfFfDnTAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1mqMseg1zF1VbUpy+KLXAVM77LDDJp136aWXTjZr48aNk82C9WrLli2TzfrlX/7lyWbdc889k82a0i233DLpvH/7t3+bbNbXv/71yWbxyDPGqNU+xpkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWNix6AcD83HTTTZPOu+222yabtXHjxslm8cjzxS9+cbJZd9xxx2SzjjvuuMlmJcn3vve9yWZddNFFk80CcKYHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtbVj0AoD5uf322yed98Y3vnGyWT/3cz832ax/+Id/mGzWueeeO9msqV1zzTWTzTrxxBMnm3XvvfdONuvQQw+dbFaSvPa1r510HsBUnOkBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrNcZY9Brmrqo2JTl80esA5mffffedbNbdd9892azzzz9/slmnn376ZLOS5JWvfOVks/78z/98slkALNYYo1b7GGd6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoLW5RE9VnVJV76mqK6vqrqoaVfWhh3nMUVV1SVXdXlX3VdW1VXVWVe01jzUBAAAkyYY5Pc+bkzwjyT1Jbk5yyI7uXFU/n+SjSR5I8uEktyd5UZJ3Jjk6yalzWhcAALDOzevtba9L8tQk+yb5tR3dsar2TfInSR5KcuwY4/QxxhuTHJbkC0lOqaqXz2ldAADAOjeX6BljXDbG+OYYY+zE3U9J8oNJLh5jXL3NczyQpTNGycOEEwAAwM5axIUMjp9tP7nMbVckuS/JUVX16OmWBAAAdDWvz/SsxsGz7Te2v2GMsbmqbkhyaJIDk3x1R09UVZtWuGmHnykCAADWj0Wc6dk42965wu1b9+83wVoAAIDmFnGm5+HUbPuwnw8aYxyx7BMsnQE6fJ6LAgAAHpkWcaZn65mcjSvcvu929wMAANhli4ier8+2T93+hqrakOQpSTYn+daUiwIAAHpaRPRcOtu+cJnbnpfkMUmuGmM8ON2SAACArhYRPR9JcmuSl1fVM7furKq9k/z+7Nf3LWBdAABAQ3O5kEFVvSTJS2a/7j/bPreqLpz9+dYxxhuSZIxxV1W9Okvxc3lVXZzk9iQvztLlrD+S5MPzWBcAAMC8rt52WJLTttt34OwnSf4pyRu23jDG+FhVHZPkt5K8LMneSa5P8utJzh1jPOyV2wAAAHbGXKJnjHF2krNX+ZjPJ/mZecwHAABYySI+0wMAADAZ0QMAALQmegAAgNZEDwAA0Nq8rt4GsEfdddddi17CHnHnnXcuegl7zKtf/erJZn34w9N908GWLVsmmwXAfDjTAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQWo0xFr2GuauqTUkOX/Q6AB7OPvvsM9msT3ziE5PNSpJjjjlmslknnXTSZLM+/elPTzYLgO83xqjVPsaZHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtFZjjEWvYe6qalOSwxe9DoC15KCDDpp03pe+9KXJZt1xxx2Tzbrssssmm3X11VdPNitJ3vve9042q+PfP4BpjDFqtY9xpgcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK3VGGPRa5i7qtqU5PBFrwNgPTv55JMnm3XBBRdMNuuxj33sZLOm9qY3vWmyWR/84Acnm3XLLbdMNgvY88YYtdrHONMDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWaoyx6DXMXVVtSnL4otcBwDSe9rSnTTbrHe94x2SzTjjhhMlmTe3888+fbNZb3/rWyWb98z//82SzYL0aY9RqH+NMDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQ2l+ipqlOq6j1VdWVV3VVVo6o+tMJ9D5jdvtLPxfNYEwAAQJJsmNPzvDnJM5Lck+TmJIfsxGO+nORjy+z/ypzWBAAAMLfoeV2WYuf6JMckuWwnHnPNGOPsOc0HAABY1lyiZ4zxH5FTVfN4SgAAgLmY15meXfFDVfUrSZ6Q5LYkXxhjXLvA9QAAAA0tMnpOnP38h6q6PMlpY4ybduYJqmrTCjftzGeKAACAdWARl6y+L8nvJTkiyeNmP1s/B3Rsks9U1T4LWBcAANDQ5Gd6xhjfTfLb2+2+oqqen+RzSY5M8qok796J5zpiuf2zM0CH7+ZSAQCABtbMl5OOMTYn+cDs1+ctci0AAEAfayZ6Zv51tvX2NgAAYC7WWvQ8Z7b91kJXAQAAtDF59FTVkVX1qGX2H5+lLzlNkg9NuyoAAKCruVzIoKpekuQls1/3n22fW1UXzv586xjjDbM//2GSQ2eXp755tu/pSY6f/fktY4yr5rEuAACAeV297bAkp22378DZT5L8U5Kt0XNRkpOTPCvJSUl+IMl3kvxFkvPGGFfOaU0AAADziZ4xxtlJzt7J+/5pkj+dx1wAAICHs9YuZAAAADBXogcAAGhN9AAAAK2JHgAAoLUaYyx6DXNXVZuSHL7odQDQz3777TfZrBe96EWTzUqSCy64YLJZVTXZrEsvvXSyWSeeeOJks2C9GmOs+gDiTA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgNdEDAAC0JnoAAIDWRA8AANCa6AEAAFoTPQAAQGs1xlj0GuauqjYlOXzR6wCAR5IHH3xwslkbNmyYbNbmzZsnm/WCF7xgslmXX375ZLNgLRlj1Gof40wPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABa27DoBQDA7nr6058+2axTTjllslnPetazJpuVJBs29PxrwXXXXTfZrCuuuGKyWcDOc6YHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtbVj0AgCYxsEHHzzpvDPPPHOyWS996Usnm7X//vtPNquzhx56aLJZt9xyy2SztmzZMtksYOc50wMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNY2LHoBAGvN/vvvP9msV7ziFZPNOvPMMyeblSQHHHDApPPYfVdfffVks9761rdONuuv/uqvJpsFrE3O9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrux09VfWEqnpVVf1lVV1fVfdX1Z1V9bmqOr2qlp1RVUdV1SVVdXtV3VdV11bVWVW11+6uCQAAYKt5fDnpqUnel+SWJJcluSnJk5K8NMkHkpxUVaeOMcbWB1TVzyf5aJIHknw4ye1JXpTknUmOnj0nAADAbptH9HwjyYuT/M0YY8vWnVX1piR/n+RlWQqgj87275vkT5I8lOTYMcbVs/1vSXJpklOq6uVjjIvnsDYAAGCd2+23t40xLh1jfGLb4Jnt/3aS989+PXabm05J8oNJLt4aPLP7P5DkzbNff2131wUAAJDs+QsZ/Ptsu3mbfcfPtp9c5v5XJLkvyVFV9eg9uTAAAGB9mMfb25ZVVRuS/OLs120D5+DZ9hvbP2aMsbmqbkhyaJIDk3z1YWZsWuGmQ1a3WgAAoKs9eabnbUmeluSSMcanttm/cba9c4XHbd2/355aGAAAsH7skTM9VfWaJK9P8rUkv7Dah8+2Y4f3SjLGOGKF+ZuSHL7KuQAAQENzP9NTVWckeXeS65IcN8a4fbu7bD2TszHL23e7+wEAAOyyuUZPVZ2V5LwkX8lS8Hx7mbt9fbZ96jKP35DkKVm68MG35rk2AABgfZpb9FTVb2Tpy0WvyVLwfHeFu146275wmduel+QxSa4aYzw4r7UBAADr11yiZ/bFom9LsinJCWOMW3dw948kuTXJy6vqmds8x95Jfn/26/vmsS4AAIDdvpBBVZ2W5HeTPJTkyiSvqart73bjGOPCJBlj3FVVr85S/FxeVRcnuT3Ji7N0OeuPJPnw7q4LAAAgmc/V254y2+6V5KwV7vPZJBdu/WWM8bGqOibJbyV5WZK9k1yf5NeTnDvGeNgrtwEAAOyM3Y6eMcbZSc7ehcd9PsnP7O58AACAHdmTX04KAACwcKIHAABoTfQAAACtiR4AAKC1eVy9DVinnvSkJ0026yd+4icmm3XeeedNNuuQQw6ZbBbz8cUvfnGyWX/0R3802awk+fjHPz7ZrC1btkw2C8CZHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1jYsegHQ3eMf//jJZp1//vmTzUqSww47bLJZBx544GSzmI+rrrpqsllvf/vbJ5v1qU99arJZ999//2SzADpzpgcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK1tWPQCYKsjjzxysllvfOMbJ5v17Gc/e7JZP/zDPzzZLObjvvvum2zWueeeO9msJPmDP/iDyWbde++9k80C4JHHmR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALS2YdELgK1OPvnklrM6u+666yab9dd//deTzdq8efNks97+9rdPNuuOO+6YbBYArCXO9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoLUaYyx6DXNXVZuSHL7odQAAAPM1xqjVPsaZHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhtt6Onqp5QVa+qqr+squur6v6qurOqPldVp1fVf9vu/gdU1djBz8W7uyYAAICtNszhOU5N8r4ktyS5LMlNSZ6U5KVJPpDkpKo6dXz/t6B+OcnHlnm+r8xhTQAAAEnmEz3fSPLiJH8zxtiydWdVvSnJ3yd5WZYC6KPbPe6aMcbZc5gPAACwot1+e9sY49Ixxie2DZ7Z/m8nef/s12N3dw4AAMCumMeZnh3599l28zK3/VBV/UqSJyS5LckXxhjX7uH1AAAA68wei56q2pDkF2e/fnKZu5w4+9n2MZcnOW2McdNOzti0wk2H7OQyAQCA5vbkJavfluRpSS4ZY3xqm/33Jfm9JEckedzs55gsXQTh2CSfqap99uC6AACAdaS+/6Jqc3jSqtckeXeSryU5eoxx+048ZkOSzyU5MslZY4x378b8TUkO39XHAwAAa9MYo1b7mLmf6amqM7IUPNclOW5ngidJxhibs3SJ6yR53rzXBQAArE9zjZ6qOivJeVn6rp3jZldwW41/nW29vQ0AAJiLuUVPVf1GkncmuSZLwfPdXXia58y235rXugAAgPVtLtFTVW/J0oULNiU5YYxx6w7ue2RVPWqZ/ccned3s1w/NY10AAAC7fcnqqjotye8meSjJlUleU/V9ny26cYxx4ezPf5jk0NnlqW+e7Xt6kuNnf37LGOOq3V0XAABAMp/v6XnKbLtXkrNWuM9nk1w4+/NFSU5O8qwkJyX5gSTfSfIXSc4bY1w5hzUBAAAk2UOXrF40l6wGAICe1sQlqwEAANYS0QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0JroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALQmegAAgNZEDwAA0Co+k40AAAd+SURBVJroAQAAWhM9AABAa6IHAABoTfQAAACtiR4AAKA10QMAALTWNXoOWPQCAACAtWHDohewh9w12964isccMtt+bb5LoQmvD3bE64Md8fpgR7w+2BGvj//qgPzn3/NXpcYY813KI1RVbUqSMcYRi14La4/XBzvi9cGOeH2wI14f7IjXx/x0fXsbAABAEtEDAAA0J3oAAIDWRA8AANCa6AEAAFpz9TYAAKA1Z3oAAIDWRA8AANCa6AEAAFoTPQAAQGuiBwAAaE30AAAArYkeAACgtXUfPVX1I1X1Z1X1L1X1YFXdWFXvqqrHLXptLN7s9TBW+Pn2otfHnldVp1TVe6rqyqq6a/b//Yce5jFHVdUlVXV7Vd1XVddW1VlVtddU62Yaq3l9VNUBOziejKq6eOr1s+dU1ROq6lVV9ZdVdX1V3V9Vd1bV56rq9Kpa9u9gjh/rw2pfH44fu2/DohewSFV1UJKrkjwxyceTfC3Js5O8NskLq+roMcZtC1wia8OdSd61zP57pl4IC/HmJM/I0v/fNyc5ZEd3rqqfT/LRJA8k+XCS25O8KMk7kxyd5NQ9uVgmt6rXx8yXk3xsmf1fmeO6WLxTk7wvyS1JLktyU5InJXlpkg8kOamqTh3bfEu848e6surXx4zjxy6q7//fcv2oqk8leX6S14wx3rPN/nckeV2S88cYv7qo9bF4VXVjkowxDljsSliUqjouS3+ZvT7JMVn6j9P/G2O8cpn77ju738YkR48xrp7t3zvJpUmem+QVYwz/ItfEKl8fByS5Icn/HWP87+lWySJU1fFJ9knyN2OMLdvs3z/J3yf50SSnjDE+Otvv+LGO7MLr44A4fuyWdfv2tqo6MEvBc2OS92538+8kuTfJL1TVPhMvDVhDxhiXjTG+ucy/ti3nlCQ/mOTirX9hmT3HA1k6I5Akv7YHlsmCrPL1wToyxrh0jPGJbf9CO9v/7STvn/167DY3OX6sI7vw+mA3ree3tx0/2356mRfc3VX1+SxF0XOSfGbqxbGmPLqqXpnkyVmK4WuTXDHGeGixy2IN2npc+eQyt12R5L4kR1XVo8cYD063LNaYH6qqX0nyhCS3JfnCGOPaBa+Jaf37bLt5m32OH2y13OtjK8ePXbSeo+fg2fYbK9z+zSxFz1Mjeta7/ZNctN2+G6rql8YYn13EglizVjyujDE2V9UNSQ5NcmCSr065MNaUE2c//6GqLk9y2hjjpoWsiMlU1YYkvzj7ddvAcfxgR6+PrRw/dtG6fXtblt4zmyx9SH05W/fvN8FaWLsuSHJClsJnnyQ/meT8JAck+duqesbilsYa5LjCjtyX5PeSHJHkcbOfrZ8DOjbJZ7ylel14W5KnJblkjPGpbfY7fpCs/Ppw/NhN6zl6Hk7Ntt6nvY6NMc6Zve/2O2OM+8YYX5ld3OIdSf57krMXu0IeYRxX1rExxnfHGL89xvjSGOOO2c8VWXpXwReT/M8kr1rsKtmTquo1SV6fpavF/sJqHz7bOn40taPXh+PH7lvP0bP1X0w2rnD7vtvdD7a19UOGz1voKlhrHFdYtTHG5ixdojZxTGmrqs5I8u4k1yU5boxx+3Z3cfxYx3bi9bEsx4+dt56j5+uz7VNXuP3HZ9uVPvPD+vbd2dapZLa14nFl9j7tp2Tpg6nfmnJRPCL862zrmNJQVZ2V5LwsfZfKcbMrdG3P8WOd2snXx444fuyE9Rw9l822z1/mW28fm6UvAbs/yd9NvTAeEZ472/qPD9u6dLZ94TK3PS/JY5Jc5cpLLOM5s61jSjNV9RtZ+nLRa7L0F9rvrnBXx491aBWvjx1x/NgJ6zZ6xhj/mOTTWfpA+hnb3XxOlmr5g2OMeydeGmtEVR1aVY9fZv+PZelfZJLkQ9OuijXuI0luTfLyqnrm1p2zLxf8/dmv71vEwli8qjqyqh61zP7js/SF2IljSitV9ZYsfTB9U5ITxhi37uDujh/rzGpeH44fu6/W8/epVdVBSa5K8sQkH8/SJSCPTHJclt7WdtQY47bFrZBFqqqzk/xmls4K3pDk7iQHJfnZJHsnuSTJyWOM7y1qjex5VfWSJC+Z/bp/khdk6V/Trpztu3WM8Ybt7v+RJA8kuTjJ7UlenKXL0X4kyf/yRZZ9rOb1Mbus7KFJLk9y8+z2p+c/v5/lLWOMrX+55RGuqk5LcmGSh5K8J8t/FufGMcaF2zzG8WOdWO3rw/Fj963r6EmSqvrRJL+bpdPJT0hyS5KPJTlnZz9ERk9VdUySX03yU/nPS1bfkaVT0Bcluch/fPqbxe/v7OAu/zTGOGC7xxyd5Ley9DbIvZNcn+TPkpzrS217Wc3ro6pOT3Jyli5H+z+S/ECS7yT5QpLzxhhXrvQkPPLsxGsjST47xjh2u8c5fqwDq319OH7svnUfPQAAQG/r9jM9AADA+iB6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtCZ6AACA1kQPAADQmugBAABaEz0AAEBrogcAAGhN9AAAAK2JHgAAoDXRAwAAtPb/AZlgB+Ge770sAAAAAElFTkSuQmCC",
"text/plain": [
- "<Figure size 432x288 with 1 Axes>"
+ "<Figure size 864x504 with 1 Axes>"
]
},
+ "execution_count": 4,
"metadata": {
+ "image/png": {
+ "height": 411,
+ "width": 414
+ },
"needs_background": "light"
},
- "output_type": "display_data"
+ "output_type": "execute_result"
}
],
"source": [
@@ -192,7 +207,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"PyTorch uses ``torch.tensor``, rather than numpy arrays, so we need to\n",
"convert our data.\n",
@@ -201,11 +218,10 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
- "outputs_hidden": false
}
},
"outputs": [
@@ -239,7 +255,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Neural net from scratch (no torch.nn)\n",
"---------------------------------------------\n",
@@ -266,14 +284,15 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"import math\n",
"\n",
@@ -284,7 +303,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Thanks to PyTorch's ability to calculate gradients automatically, we can\n",
"use any standard Python function (or callable object) as a model! So\n",
@@ -300,14 +321,15 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def log_softmax(x):\n",
" return x - x.exp().sum(-1).log().unsqueeze(-1)\n",
@@ -318,7 +340,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"In the above, the ``@`` stands for the dot product operation. We will call\n",
"our function on one batch of data (in this case, 64 images). This is\n",
@@ -329,7 +353,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -337,14 +361,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([-2.2680, -1.7434, -2.2746, -2.7562, -2.7793, -2.4086, -2.2656, -2.2761,\n",
- " -2.1634, -2.5035], grad_fn=<SelectBackward>) torch.Size([64, 10])\n"
- ]
- }
],
"source": [
"bs = 64 # batch size\n",
@@ -357,7 +373,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"As you see, the ``preds`` tensor contains not only the tensor values, but also a\n",
"gradient function. We'll use this later to do backprop.\n",
@@ -369,14 +387,15 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def nll(input, target):\n",
" return -input[range(target.shape[0]), target].mean()\n",
@@ -386,7 +405,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's check our loss with our random model, so we can see if we improve\n",
"after a backprop pass later.\n",
@@ -395,7 +416,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -403,13 +424,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(2.4159, grad_fn=<NegBackward>)\n"
- ]
- }
],
"source": [
"yb = y_train[0:bs]\n",
@@ -418,7 +432,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's also implement a function to calculate the accuracy of our model.\n",
"For each prediction, if the index with the largest value matches the\n",
@@ -428,14 +444,15 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def accuracy(out, yb):\n",
" preds = torch.argmax(out, dim=1)\n",
@@ -444,7 +461,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's check the accuracy of our random model, so we can see if our\n",
"accuracy improves as our loss improves.\n",
@@ -453,7 +472,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -461,13 +480,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0781)\n"
- ]
- }
],
"source": [
"print(accuracy(preds, yb))"
@@ -475,7 +487,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"We can now run a training loop. For each iteration, we will:\n",
"\n",
@@ -506,14 +520,15 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"from IPython.core.debugger import set_trace\n",
"\n",
@@ -540,7 +555,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"That's it: we've created and trained a minimal neural network (in this case, a\n",
"logistic regression, since we have no hidden layers) entirely from scratch!\n",
@@ -553,7 +570,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -561,13 +578,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0825, grad_fn=<NegBackward>) tensor(1.)\n"
- ]
- }
],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
@@ -575,7 +585,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Using torch.nn.functional\n",
"------------------------------\n",
@@ -603,14 +615,15 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"import torch.nn.functional as F\n",
"\n",
@@ -622,7 +635,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Note that we no longer call ``log_softmax`` in the ``model`` function. Let's\n",
"confirm that our loss and accuracy are the same as before:\n",
@@ -631,7 +646,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -639,13 +654,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0825, grad_fn=<NllLossBackward>) tensor(1.)\n"
- ]
- }
],
"source": [
"print(loss_func(model(xb), yb), accuracy(model(xb), yb))"
@@ -653,7 +661,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Refactor using nn.Module\n",
"-----------------------------\n",
@@ -673,14 +683,15 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"from torch import nn\n",
"\n",
@@ -696,7 +707,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Since we're now using an object instead of just using a function, we\n",
"first have to instantiate our model:\n",
@@ -705,21 +718,24 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"model = Mnist_Logistic()"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Now we can calculate the loss in the same way as before. Note that\n",
"``nn.Module`` objects are used as if they are functions (i.e they are\n",
@@ -730,7 +746,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -738,13 +754,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(2.1396, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"print(loss_func(model(xb), yb))"
@@ -752,7 +761,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Previously for our training loop we had to update the values for each parameter\n",
"by name, and manually zero out the grads for each parameter separately, like this:\n",
@@ -781,14 +792,15 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def fit():\n",
" for epoch in range(epochs):\n",
@@ -811,7 +823,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's double-check that our loss has gone down:\n",
"\n"
@@ -819,7 +833,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -827,13 +841,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0820, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"print(loss_func(model(xb), yb))"
@@ -841,7 +848,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Refactor using nn.Linear\n",
"-------------------------\n",
@@ -858,14 +867,15 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"class Mnist_Logistic(nn.Module):\n",
" def __init__(self):\n",
@@ -878,7 +888,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"We instantiate our model and calculate the loss in the same way as before:\n",
"\n"
@@ -886,7 +898,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -894,13 +906,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(2.2840, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"model = Mnist_Logistic()\n",
@@ -909,7 +914,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"We are still able to use our same ``fit`` method as before.\n",
"\n"
@@ -917,7 +924,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -925,13 +932,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0798, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"fit()\n",
@@ -941,7 +941,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Refactor using optim\n",
"------------------------------\n",
@@ -968,21 +970,24 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"from torch import optim"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"We'll define a little function to create our model and optimizer so we\n",
"can reuse it in the future.\n",
@@ -991,7 +996,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -999,14 +1004,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(2.2706, grad_fn=<NllLossBackward>)\n",
- "tensor(0.0798, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"def get_model():\n",
@@ -1034,7 +1031,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Refactor using Dataset\n",
"------------------------------\n",
@@ -1056,21 +1055,24 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"from torch.utils.data import TensorDataset"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Both ``x_train`` and ``y_train`` can be combined in a single ``TensorDataset``,\n",
"which will be easier to iterate over and slice.\n",
@@ -1079,21 +1081,24 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"train_ds = TensorDataset(x_train, y_train)"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Previously, we had to iterate through minibatches of x and y values separately:\n",
"::\n",
@@ -1110,7 +1115,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1118,13 +1123,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0817, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"model, opt = get_model()\n",
@@ -1144,7 +1142,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Refactor using DataLoader\n",
"------------------------------\n",
@@ -1158,14 +1158,15 @@
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
@@ -1175,7 +1176,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Previously, our loop iterated over batches (xb, yb) like this:\n",
"::\n",
@@ -1192,7 +1195,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1200,13 +1203,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.0815, grad_fn=<NllLossBackward>)\n"
- ]
- }
],
"source": [
"model, opt = get_model()\n",
@@ -1225,7 +1221,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Thanks to Pytorch's ``nn.Module``, ``nn.Parameter``, ``Dataset``, and ``DataLoader``,\n",
"our training loop is now dramatically smaller and easier to understand. Let's\n",
@@ -1255,14 +1253,15 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"train_ds = TensorDataset(x_train, y_train)\n",
"train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
@@ -1273,7 +1272,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"We will calculate and print the validation loss at the end of each epoch.\n",
"\n",
@@ -1285,7 +1286,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1293,14 +1294,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 tensor(0.3232)\n",
- "1 tensor(0.2736)\n"
- ]
- }
],
"source": [
"model, opt = get_model()\n",
@@ -1324,7 +1317,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Create fit() and get_data()\n",
"----------------------------------\n",
@@ -1342,14 +1337,15 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def loss_batch(model, loss_func, xb, yb, opt=None):\n",
" loss = loss_func(model(xb), yb)\n",
@@ -1364,7 +1360,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"``fit`` runs the necessary operations to train our model and compute the\n",
"training and validation losses for each epoch.\n",
@@ -1373,14 +1371,15 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"import numpy as np\n",
"\n",
@@ -1402,7 +1401,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"``get_data`` returns dataloaders for the training and validation sets.\n",
"\n"
@@ -1410,14 +1411,15 @@
},
{
"cell_type": "code",
- "execution_count": 35,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def get_data(train_ds, valid_ds, bs):\n",
" return (\n",
@@ -1428,7 +1430,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Now, our whole process of obtaining the data loaders and fitting the\n",
"model can be run in 3 lines of code:\n",
@@ -1437,7 +1441,7 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1445,14 +1449,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 0.36182342684268953\n",
- "1 0.3086622476875782\n"
- ]
- }
],
"source": [
"train_dl, valid_dl = get_data(train_ds, valid_ds, bs)\n",
@@ -1462,7 +1458,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"You can use these basic 3 lines of code to train a wide variety of models.\n",
"Let's see if we can use them to train a convolutional neural network (CNN)!\n",
@@ -1485,14 +1483,15 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"class Mnist_CNN(nn.Module):\n",
" def __init__(self):\n",
@@ -1514,7 +1513,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"`Momentum <https://cs231n.github.io/neural-networks-3/#sgd>`_ is a variation on\n",
"stochastic gradient descent that takes previous updates into account as well\n",
@@ -1524,7 +1525,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1532,14 +1533,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 0.30878638651371004\n",
- "1 0.25200295938253403\n"
- ]
- }
],
"source": [
"model = Mnist_CNN()\n",
@@ -1550,7 +1543,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"nn.Sequential\n",
"------------------------\n",
@@ -1570,14 +1565,15 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"class Lambda(nn.Module):\n",
" def __init__(self, func):\n",
@@ -1594,7 +1590,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"The model created with ``Sequential`` is simply:\n",
"\n"
@@ -1602,7 +1600,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1610,14 +1608,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 0.32227418100833893\n",
- "1 0.2695485789179802\n"
- ]
- }
],
"source": [
"model = nn.Sequential(\n",
@@ -1639,7 +1629,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Wrapping DataLoader\n",
"-----------------------------\n",
@@ -1657,14 +1649,15 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def preprocess(x, y):\n",
" return x.view(-1, 1, 28, 28), y\n",
@@ -1690,7 +1683,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Next, we can replace ``nn.AvgPool2d`` with ``nn.AdaptiveAvgPool2d``, which\n",
"allows us to define the size of the *output* tensor we want, rather than\n",
@@ -1701,14 +1696,15 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"model = nn.Sequential(\n",
" nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),\n",
@@ -1726,7 +1722,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's try it out:\n",
"\n"
@@ -1734,7 +1732,7 @@
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1742,14 +1740,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 0.3791842395067215\n",
- "1 0.26341770286560057\n"
- ]
- }
],
"source": [
"fit(epochs, model, loss_func, opt, train_dl, valid_dl)"
@@ -1757,7 +1747,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Using your GPU\n",
"---------------\n",
@@ -1771,7 +1763,7 @@
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -1779,13 +1771,6 @@
}
},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "False\n"
- ]
- }
],
"source": [
"print(torch.cuda.is_available())"
@@ -1793,7 +1778,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"And then create a device object for it:\n",
"\n"
@@ -1801,14 +1788,15 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"dev = torch.device(\n",
" \"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
@@ -1816,7 +1804,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Let's update ``preprocess`` to move batches to the GPU:\n",
"\n"
@@ -1824,14 +1814,15 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"def preprocess(x, y):\n",
" return x.view(-1, 1, 28, 28).to(dev), y.to(dev)\n",
@@ -1844,7 +1835,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Finally, we can move our model to the GPU.\n",
"\n"
@@ -1852,14 +1845,15 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"model.to(dev)\n",
"opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)"
@@ -1867,7 +1861,9 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"You should find it runs faster now:\n",
"\n"
@@ -1875,21 +1871,24 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
- "outputs": [],
+ "outputs": [
+ ],
"source": [
"fit(epochs, model, loss_func, opt, train_dl, valid_dl)"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "collapsed": false
+ },
"source": [
"Closing thoughts\n",
"-----------------\n",
@@ -1929,9 +1928,10 @@
],
"metadata": {
"kernelspec": {
- "display_name": "PyTorch",
+ "display_name": "Python 3 (Ubuntu Linux)",
"language": "python",
- "name": "pytorch"
+ "name": "python3-ubuntu",
+ "resource_dir": "/usr/local/share/jupyter/kernels/python3-ubuntu"
},
"language_info": {
"codemirror_mode": {
@@ -1943,9 +1943,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.6"
+ "version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
-}
+} \ No newline at end of file