summaryrefslogtreecommitdiff
path: root/quickstart_tutorial.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'quickstart_tutorial.ipynb')
-rw-r--r--quickstart_tutorial.ipynb307
1 files changed, 220 insertions, 87 deletions
diff --git a/quickstart_tutorial.ipynb b/quickstart_tutorial.ipynb
index 55d4021..4e847d9 100644
--- a/quickstart_tutorial.ipynb
+++ b/quickstart_tutorial.ipynb
@@ -126,7 +126,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 23,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -138,13 +138,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n",
- "Shape of y: torch.Size([64]) torch.int64\n"
+ "Shape of X [N, C, H, W]: torch.Size([200, 1, 28, 28])\n",
+ "Shape of y: torch.Size([200]) torch.int64\n"
]
}
],
"source": [
- "batch_size = 64\n",
+ "batch_size = 200\n",
"\n",
"# Create data loaders.\n",
"train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
@@ -189,7 +189,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 24,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -205,11 +205,11 @@
"NeuralNetwork(\n",
" (flatten): Flatten(start_dim=1, end_dim=-1)\n",
" (linear_relu_stack): Sequential(\n",
- " (0): Linear(in_features=784, out_features=512, bias=True)\n",
+ " (0): Linear(in_features=784, out_features=1024, bias=True)\n",
" (1): ReLU()\n",
- " (2): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (2): Linear(in_features=1024, out_features=1024, bias=True)\n",
" (3): ReLU()\n",
- " (4): Linear(in_features=512, out_features=10, bias=True)\n",
+ " (4): Linear(in_features=1024, out_features=10, bias=True)\n",
" (5): ReLU()\n",
" )\n",
")\n"
@@ -227,11 +227,11 @@
" super(NeuralNetwork, self).__init__()\n",
" self.flatten = nn.Flatten()\n",
" self.linear_relu_stack = nn.Sequential(\n",
- " nn.Linear(28*28, 512),\n",
+ " nn.Linear(28*28, 1024),\n",
" nn.ReLU(),\n",
- " nn.Linear(512, 512),\n",
+ " nn.Linear(1024, 1024),\n",
" nn.ReLU(),\n",
- " nn.Linear(512, 10),\n",
+ " nn.Linear(1024, 10),\n",
" nn.ReLU()\n",
" )\n",
"\n",
@@ -275,7 +275,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 25,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -299,7 +299,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 27,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -337,7 +337,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 28,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -374,7 +374,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 29,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -388,85 +388,170 @@
"text": [
"Epoch 1\n",
"-------------------------------\n",
- "loss: 2.300270 [ 0/60000]\n",
- "loss: 2.290948 [ 6400/60000]\n",
- "loss: 2.280627 [12800/60000]\n",
- "loss: 2.283042 [19200/60000]\n",
- "loss: 2.268817 [25600/60000]\n",
- "loss: 2.262104 [32000/60000]\n",
- "loss: 2.248093 [38400/60000]\n",
- "loss: 2.233517 [44800/60000]\n",
- "loss: 2.234299 [51200/60000]\n",
- "loss: 2.234841 [57600/60000]\n",
+ "loss: 2.301592 [ 0/60000]\n",
+ "loss: 2.289894 [20000/60000]\n",
+ "loss: 2.280160 [40000/60000]\n",
"Test Error: \n",
- " Accuracy: 51.1%, Avg loss: 2.221716 \n",
+ " Accuracy: 30.9%, Avg loss: 2.269258 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
- "loss: 2.203925 [ 0/60000]\n",
- "loss: 2.200477 [ 6400/60000]\n",
- "loss: 2.180246 [12800/60000]\n",
- "loss: 2.213840 [19200/60000]\n",
- "loss: 2.155768 [25600/60000]\n",
- "loss: 2.154045 [32000/60000]\n",
- "loss: 2.130811 [38400/60000]\n",
- "loss: 2.101326 [44800/60000]\n",
- "loss: 2.118105 [51200/60000]\n",
- "loss: 2.116674 [57600/60000]\n",
+ "loss: 2.268190 [ 0/60000]\n",
+ "loss: 2.259592 [20000/60000]\n",
+ "loss: 2.251709 [40000/60000]\n",
"Test Error: \n",
- " Accuracy: 51.0%, Avg loss: 2.093179 \n",
+ " Accuracy: 34.5%, Avg loss: 2.238048 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
- "loss: 2.054594 [ 0/60000]\n",
- "loss: 2.047193 [ 6400/60000]\n",
- "loss: 2.009665 [12800/60000]\n",
- "loss: 2.093936 [19200/60000]\n",
- "loss: 1.965194 [25600/60000]\n",
- "loss: 1.976750 [32000/60000]\n",
- "loss: 1.938592 [38400/60000]\n",
- "loss: 1.889513 [44800/60000]\n",
- "loss: 1.942611 [51200/60000]\n",
- "loss: 1.936963 [57600/60000]\n",
+ "loss: 2.236049 [ 0/60000]\n",
+ "loss: 2.230108 [20000/60000]\n",
+ "loss: 2.223023 [40000/60000]\n",
"Test Error: \n",
- " Accuracy: 51.2%, Avg loss: 1.900150 \n",
+ " Accuracy: 34.4%, Avg loss: 2.204923 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
- "loss: 1.837173 [ 0/60000]\n",
- "loss: 1.822518 [ 6400/60000]\n",
- "loss: 1.775139 [12800/60000]\n",
- "loss: 1.925843 [19200/60000]\n",
- "loss: 1.731390 [25600/60000]\n",
- "loss: 1.778743 [32000/60000]\n",
- "loss: 1.714922 [38400/60000]\n",
- "loss: 1.670009 [44800/60000]\n",
- "loss: 1.755909 [51200/60000]\n",
- "loss: 1.763030 [57600/60000]\n",
+ "loss: 2.201698 [ 0/60000]\n",
+ "loss: 2.199297 [20000/60000]\n",
+ "loss: 2.192280 [40000/60000]\n",
"Test Error: \n",
- " Accuracy: 52.5%, Avg loss: 1.711389 \n",
+ " Accuracy: 34.7%, Avg loss: 2.168218 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
- "loss: 1.621799 [ 0/60000]\n",
- "loss: 1.615258 [ 6400/60000]\n",
- "loss: 1.567131 [12800/60000]\n",
- "loss: 1.768921 [19200/60000]\n",
- "loss: 1.539987 [25600/60000]\n",
- "loss: 1.627408 [32000/60000]\n",
- "loss: 1.533756 [38400/60000]\n",
- "loss: 1.510703 [44800/60000]\n",
- "loss: 1.603583 [51200/60000]\n",
- "loss: 1.636089 [57600/60000]\n",
+ "loss: 2.163056 [ 0/60000]\n",
+ "loss: 2.165860 [20000/60000]\n",
+ "loss: 2.158360 [40000/60000]\n",
"Test Error: \n",
- " Accuracy: 53.2%, Avg loss: 1.568043 \n",
+ " Accuracy: 35.1%, Avg loss: 2.126790 \n",
+ "\n",
+ "Epoch 6\n",
+ "-------------------------------\n",
+ "loss: 2.118984 [ 0/60000]\n",
+ "loss: 2.128839 [20000/60000]\n",
+ "loss: 2.120741 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 35.5%, Avg loss: 2.080902 \n",
+ "\n",
+ "Epoch 7\n",
+ "-------------------------------\n",
+ "loss: 2.069693 [ 0/60000]\n",
+ "loss: 2.088246 [20000/60000]\n",
+ "loss: 2.080452 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 35.9%, Avg loss: 2.032795 \n",
+ "\n",
+ "Epoch 8\n",
+ "-------------------------------\n",
+ "loss: 2.017604 [ 0/60000]\n",
+ "loss: 2.045965 [20000/60000]\n",
+ "loss: 2.040246 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 36.2%, Avg loss: 1.986216 \n",
+ "\n",
+ "Epoch 9\n",
+ "-------------------------------\n",
+ "loss: 1.966350 [ 0/60000]\n",
+ "loss: 2.004781 [20000/60000]\n",
+ "loss: 2.002753 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 36.6%, Avg loss: 1.944349 \n",
+ "\n",
+ "Epoch 10\n",
+ "-------------------------------\n",
+ "loss: 1.919495 [ 0/60000]\n",
+ "loss: 1.967277 [20000/60000]\n",
+ "loss: 1.969415 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 37.1%, Avg loss: 1.908297 \n",
+ "\n",
+ "Epoch 11\n",
+ "-------------------------------\n",
+ "loss: 1.878579 [ 0/60000]\n",
+ "loss: 1.934277 [20000/60000]\n",
+ "loss: 1.940233 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 38.0%, Avg loss: 1.877542 \n",
+ "\n",
+ "Epoch 12\n",
+ "-------------------------------\n",
+ "loss: 1.843474 [ 0/60000]\n",
+ "loss: 1.905277 [20000/60000]\n",
+ "loss: 1.914412 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 38.9%, Avg loss: 1.850760 \n",
+ "\n",
+ "Epoch 13\n",
+ "-------------------------------\n",
+ "loss: 1.812682 [ 0/60000]\n",
+ "loss: 1.879232 [20000/60000]\n",
+ "loss: 1.890868 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 39.6%, Avg loss: 1.826784 \n",
+ "\n",
+ "Epoch 14\n",
+ "-------------------------------\n",
+ "loss: 1.785155 [ 0/60000]\n",
+ "loss: 1.855434 [20000/60000]\n",
+ "loss: 1.868896 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 40.1%, Avg loss: 1.804814 \n",
+ "\n",
+ "Epoch 15\n",
+ "-------------------------------\n",
+ "loss: 1.760130 [ 0/60000]\n",
+ "loss: 1.833383 [20000/60000]\n",
+ "loss: 1.848235 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 40.5%, Avg loss: 1.784439 \n",
+ "\n",
+ "Epoch 16\n",
+ "-------------------------------\n",
+ "loss: 1.737113 [ 0/60000]\n",
+ "loss: 1.812643 [20000/60000]\n",
+ "loss: 1.828667 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 40.9%, Avg loss: 1.765419 \n",
+ "\n",
+ "Epoch 17\n",
+ "-------------------------------\n",
+ "loss: 1.715825 [ 0/60000]\n",
+ "loss: 1.792752 [20000/60000]\n",
+ "loss: 1.809863 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 41.2%, Avg loss: 1.747591 \n",
+ "\n",
+ "Epoch 18\n",
+ "-------------------------------\n",
+ "loss: 1.695961 [ 0/60000]\n",
+ "loss: 1.773857 [20000/60000]\n",
+ "loss: 1.791943 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 41.5%, Avg loss: 1.723384 \n",
+ "\n",
+ "Epoch 19\n",
+ "-------------------------------\n",
+ "loss: 1.669920 [ 0/60000]\n",
+ "loss: 1.726300 [20000/60000]\n",
+ "loss: 1.758728 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 42.2%, Avg loss: 1.677625 \n",
+ "\n",
+ "Epoch 20\n",
+ "-------------------------------\n",
+ "loss: 1.626003 [ 0/60000]\n",
+ "loss: 1.674431 [20000/60000]\n",
+ "loss: 1.733464 [40000/60000]\n",
+ "Test Error: \n",
+ " Accuracy: 43.5%, Avg loss: 1.645380 \n",
"\n",
"Done!\n"
]
}
],
"source": [
- "epochs = 5\n",
+ "epochs = 20\n",
"for t in range(epochs):\n",
" print(f\"Epoch {t+1}\\n-------------------------------\")\n",
" train(train_dataloader, model, loss_fn, optimizer)\n",
@@ -504,7 +589,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 9,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -539,7 +624,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 10,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -553,7 +638,7 @@
"<All keys matched successfully>"
]
},
- "execution_count": 13,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -573,7 +658,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 22,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -585,7 +670,56 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
+ "Predicted: \"Sandal\", Actual: \"Ankle boot\"\n",
+ "Predicted: \"Shirt\", Actual: \"Pullover\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Shirt\", Actual: \"Shirt\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Shirt\", Actual: \"Coat\"\n",
+ "Predicted: \"Shirt\", Actual: \"Shirt\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sandal\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
+ "Predicted: \"T-shirt/top\", Actual: \"Coat\"\n",
+ "Predicted: \"Sandal\", Actual: \"Sandal\"\n",
+ "Predicted: \"Sandal\", Actual: \"Sneaker\"\n",
+ "Predicted: \"Dress\", Actual: \"Dress\"\n",
+ "Predicted: \"Shirt\", Actual: \"Coat\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Shirt\", Actual: \"Pullover\"\n",
+ "Predicted: \"Shirt\", Actual: \"Coat\"\n",
+ "Predicted: \"Bag\", Actual: \"Bag\"\n",
+ "Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n",
+ "Predicted: \"Shirt\", Actual: \"Pullover\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sandal\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Ankle boot\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Shirt\", Actual: \"Coat\"\n",
+ "Predicted: \"Shirt\", Actual: \"Shirt\"\n",
+ "Predicted: \"Dress\", Actual: \"T-shirt/top\"\n",
+ "Predicted: \"Sandal\", Actual: \"Ankle boot\"\n",
+ "Predicted: \"Dress\", Actual: \"Dress\"\n",
+ "Predicted: \"Bag\", Actual: \"Bag\"\n",
+ "Predicted: \"Bag\", Actual: \"Bag\"\n",
+ "Predicted: \"Dress\", Actual: \"Dress\"\n",
+ "Predicted: \"Dress\", Actual: \"Dress\"\n",
+ "Predicted: \"Bag\", Actual: \"Bag\"\n",
+ "Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
+ "Predicted: \"Sandal\", Actual: \"Sandal\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
+ "Predicted: \"T-shirt/top\", Actual: \"Ankle boot\"\n",
+ "Predicted: \"T-shirt/top\", Actual: \"Shirt\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Dress\", Actual: \"Dress\"\n",
+ "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n",
+ "Predicted: \"Shirt\", Actual: \"Shirt\"\n",
+ "Predicted: \"Sandal\", Actual: \"Sneaker\"\n",
+ "Predicted: \"Shirt\", Actual: \"Pullover\"\n",
+ "Predicted: \"Trouser\", Actual: \"Trouser\"\n",
+ "Predicted: \"Trouser\", Actual: \"Pullover\"\n",
+ "Predicted: \"Shirt\", Actual: \"Pullover\"\n"
]
}
],
@@ -604,21 +738,20 @@
"]\n",
"\n",
"model.eval()\n",
- "x, y = test_data[0][0], test_data[0][1]\n",
- "with torch.no_grad():\n",
- " pred = model(x)\n",
- " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
- " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
+ "for i in range(0, 50):\n",
+ " x, y = test_data[i][0], test_data[i][1]\n",
+ " with torch.no_grad():\n",
+ " pred = model(x)\n",
+ " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
+ " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
]
},
{
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
- "source": [
- "Read more about `Saving & Loading your model <saveloadrun_tutorial.html>`_.\n",
- "\n",
- "\n"
- ]
+ "outputs": [],
+ "source": []
}
],
"metadata": {