diff options
Diffstat (limited to 'quickstart_tutorial.ipynb')
-rw-r--r-- | quickstart_tutorial.ipynb | 307 |
1 files changed, 220 insertions, 87 deletions
diff --git a/quickstart_tutorial.ipynb b/quickstart_tutorial.ipynb index 55d4021..4e847d9 100644 --- a/quickstart_tutorial.ipynb +++ b/quickstart_tutorial.ipynb @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 23, "metadata": { "collapsed": false, "jupyter": { @@ -138,13 +138,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n", - "Shape of y: torch.Size([64]) torch.int64\n" + "Shape of X [N, C, H, W]: torch.Size([200, 1, 28, 28])\n", + "Shape of y: torch.Size([200]) torch.int64\n" ] } ], "source": [ - "batch_size = 64\n", + "batch_size = 200\n", "\n", "# Create data loaders.\n", "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 24, "metadata": { "collapsed": false, "jupyter": { @@ -205,11 +205,11 @@ "NeuralNetwork(\n", " (flatten): Flatten(start_dim=1, end_dim=-1)\n", " (linear_relu_stack): Sequential(\n", - " (0): Linear(in_features=784, out_features=512, bias=True)\n", + " (0): Linear(in_features=784, out_features=1024, bias=True)\n", " (1): ReLU()\n", - " (2): Linear(in_features=512, out_features=512, bias=True)\n", + " (2): Linear(in_features=1024, out_features=1024, bias=True)\n", " (3): ReLU()\n", - " (4): Linear(in_features=512, out_features=10, bias=True)\n", + " (4): Linear(in_features=1024, out_features=10, bias=True)\n", " (5): ReLU()\n", " )\n", ")\n" @@ -227,11 +227,11 @@ " super(NeuralNetwork, self).__init__()\n", " self.flatten = nn.Flatten()\n", " self.linear_relu_stack = nn.Sequential(\n", - " nn.Linear(28*28, 512),\n", + " nn.Linear(28*28, 1024),\n", " nn.ReLU(),\n", - " nn.Linear(512, 512),\n", + " nn.Linear(1024, 1024),\n", " nn.ReLU(),\n", - " nn.Linear(512, 10),\n", + " nn.Linear(1024, 10),\n", " nn.ReLU()\n", " )\n", "\n", @@ -275,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 25, "metadata": { "collapsed": false, "jupyter": { @@ -299,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 27, "metadata": { "collapsed": false, "jupyter": { @@ -337,7 +337,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 28, "metadata": { "collapsed": false, "jupyter": { @@ -374,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 29, "metadata": { "collapsed": false, "jupyter": { @@ -388,85 +388,170 @@ "text": [ "Epoch 1\n", "-------------------------------\n", - "loss: 2.300270 [ 0/60000]\n", - "loss: 2.290948 [ 6400/60000]\n", - "loss: 2.280627 [12800/60000]\n", - "loss: 2.283042 [19200/60000]\n", - "loss: 2.268817 [25600/60000]\n", - "loss: 2.262104 [32000/60000]\n", - "loss: 2.248093 [38400/60000]\n", - "loss: 2.233517 [44800/60000]\n", - "loss: 2.234299 [51200/60000]\n", - "loss: 2.234841 [57600/60000]\n", + "loss: 2.301592 [ 0/60000]\n", + "loss: 2.289894 [20000/60000]\n", + "loss: 2.280160 [40000/60000]\n", "Test Error: \n", - " Accuracy: 51.1%, Avg loss: 2.221716 \n", + " Accuracy: 30.9%, Avg loss: 2.269258 \n", "\n", "Epoch 2\n", "-------------------------------\n", - "loss: 2.203925 [ 0/60000]\n", - "loss: 2.200477 [ 6400/60000]\n", - "loss: 2.180246 [12800/60000]\n", - "loss: 2.213840 [19200/60000]\n", - "loss: 2.155768 [25600/60000]\n", - "loss: 2.154045 [32000/60000]\n", - "loss: 2.130811 [38400/60000]\n", - "loss: 2.101326 [44800/60000]\n", - "loss: 2.118105 [51200/60000]\n", - "loss: 2.116674 [57600/60000]\n", + "loss: 2.268190 [ 0/60000]\n", + "loss: 2.259592 [20000/60000]\n", + "loss: 2.251709 [40000/60000]\n", "Test Error: \n", - " Accuracy: 51.0%, Avg loss: 2.093179 \n", + " Accuracy: 34.5%, Avg loss: 2.238048 \n", "\n", "Epoch 3\n", "-------------------------------\n", - "loss: 2.054594 [ 0/60000]\n", - "loss: 2.047193 [ 6400/60000]\n", - "loss: 2.009665 [12800/60000]\n", - "loss: 2.093936 [19200/60000]\n", - "loss: 1.965194 [25600/60000]\n", - "loss: 1.976750 [32000/60000]\n", - "loss: 1.938592 [38400/60000]\n", - "loss: 1.889513 [44800/60000]\n", - "loss: 1.942611 [51200/60000]\n", - "loss: 1.936963 [57600/60000]\n", + "loss: 2.236049 [ 0/60000]\n", + "loss: 2.230108 [20000/60000]\n", + "loss: 2.223023 [40000/60000]\n", "Test Error: \n", - " Accuracy: 51.2%, Avg loss: 1.900150 \n", + " Accuracy: 34.4%, Avg loss: 2.204923 \n", "\n", "Epoch 4\n", "-------------------------------\n", - "loss: 1.837173 [ 0/60000]\n", - "loss: 1.822518 [ 6400/60000]\n", - "loss: 1.775139 [12800/60000]\n", - "loss: 1.925843 [19200/60000]\n", - "loss: 1.731390 [25600/60000]\n", - "loss: 1.778743 [32000/60000]\n", - "loss: 1.714922 [38400/60000]\n", - "loss: 1.670009 [44800/60000]\n", - "loss: 1.755909 [51200/60000]\n", - "loss: 1.763030 [57600/60000]\n", + "loss: 2.201698 [ 0/60000]\n", + "loss: 2.199297 [20000/60000]\n", + "loss: 2.192280 [40000/60000]\n", "Test Error: \n", - " Accuracy: 52.5%, Avg loss: 1.711389 \n", + " Accuracy: 34.7%, Avg loss: 2.168218 \n", "\n", "Epoch 5\n", "-------------------------------\n", - "loss: 1.621799 [ 0/60000]\n", - "loss: 1.615258 [ 6400/60000]\n", - "loss: 1.567131 [12800/60000]\n", - "loss: 1.768921 [19200/60000]\n", - "loss: 1.539987 [25600/60000]\n", - "loss: 1.627408 [32000/60000]\n", - "loss: 1.533756 [38400/60000]\n", - "loss: 1.510703 [44800/60000]\n", - "loss: 1.603583 [51200/60000]\n", - "loss: 1.636089 [57600/60000]\n", + "loss: 2.163056 [ 0/60000]\n", + "loss: 2.165860 [20000/60000]\n", + "loss: 2.158360 [40000/60000]\n", "Test Error: \n", - " Accuracy: 53.2%, Avg loss: 1.568043 \n", + " Accuracy: 35.1%, Avg loss: 2.126790 \n", + "\n", + "Epoch 6\n", + "-------------------------------\n", + "loss: 2.118984 [ 0/60000]\n", + "loss: 2.128839 [20000/60000]\n", + "loss: 2.120741 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 35.5%, Avg loss: 2.080902 \n", + "\n", + "Epoch 7\n", + "-------------------------------\n", + "loss: 2.069693 [ 0/60000]\n", + "loss: 2.088246 [20000/60000]\n", + "loss: 2.080452 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 35.9%, Avg loss: 2.032795 \n", + "\n", + "Epoch 8\n", + "-------------------------------\n", + "loss: 2.017604 [ 0/60000]\n", + "loss: 2.045965 [20000/60000]\n", + "loss: 2.040246 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 36.2%, Avg loss: 1.986216 \n", + "\n", + "Epoch 9\n", + "-------------------------------\n", + "loss: 1.966350 [ 0/60000]\n", + "loss: 2.004781 [20000/60000]\n", + "loss: 2.002753 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 36.6%, Avg loss: 1.944349 \n", + "\n", + "Epoch 10\n", + "-------------------------------\n", + "loss: 1.919495 [ 0/60000]\n", + "loss: 1.967277 [20000/60000]\n", + "loss: 1.969415 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 37.1%, Avg loss: 1.908297 \n", + "\n", + "Epoch 11\n", + "-------------------------------\n", + "loss: 1.878579 [ 0/60000]\n", + "loss: 1.934277 [20000/60000]\n", + "loss: 1.940233 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 38.0%, Avg loss: 1.877542 \n", + "\n", + "Epoch 12\n", + "-------------------------------\n", + "loss: 1.843474 [ 0/60000]\n", + "loss: 1.905277 [20000/60000]\n", + "loss: 1.914412 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 38.9%, Avg loss: 1.850760 \n", + "\n", + "Epoch 13\n", + "-------------------------------\n", + "loss: 1.812682 [ 0/60000]\n", + "loss: 1.879232 [20000/60000]\n", + "loss: 1.890868 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 39.6%, Avg loss: 1.826784 \n", + "\n", + "Epoch 14\n", + "-------------------------------\n", + "loss: 1.785155 [ 0/60000]\n", + "loss: 1.855434 [20000/60000]\n", + "loss: 1.868896 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 40.1%, Avg loss: 1.804814 \n", + "\n", + "Epoch 15\n", + "-------------------------------\n", + "loss: 1.760130 [ 0/60000]\n", + "loss: 1.833383 [20000/60000]\n", + "loss: 1.848235 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 40.5%, Avg loss: 1.784439 \n", + "\n", + "Epoch 16\n", + "-------------------------------\n", + "loss: 1.737113 [ 0/60000]\n", + "loss: 1.812643 [20000/60000]\n", + "loss: 1.828667 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 40.9%, Avg loss: 1.765419 \n", + "\n", + "Epoch 17\n", + "-------------------------------\n", + "loss: 1.715825 [ 0/60000]\n", + "loss: 1.792752 [20000/60000]\n", + "loss: 1.809863 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 41.2%, Avg loss: 1.747591 \n", + "\n", + "Epoch 18\n", + "-------------------------------\n", + "loss: 1.695961 [ 0/60000]\n", + "loss: 1.773857 [20000/60000]\n", + "loss: 1.791943 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 41.5%, Avg loss: 1.723384 \n", + "\n", + "Epoch 19\n", + "-------------------------------\n", + "loss: 1.669920 [ 0/60000]\n", + "loss: 1.726300 [20000/60000]\n", + "loss: 1.758728 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 42.2%, Avg loss: 1.677625 \n", + "\n", + "Epoch 20\n", + "-------------------------------\n", + "loss: 1.626003 [ 0/60000]\n", + "loss: 1.674431 [20000/60000]\n", + "loss: 1.733464 [40000/60000]\n", + "Test Error: \n", + " Accuracy: 43.5%, Avg loss: 1.645380 \n", "\n", "Done!\n" ] } ], "source": [ - "epochs = 5\n", + "epochs = 20\n", "for t in range(epochs):\n", " print(f\"Epoch {t+1}\\n-------------------------------\")\n", " train(train_dataloader, model, loss_fn, optimizer)\n", @@ -504,7 +589,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "metadata": { "collapsed": false, "jupyter": { @@ -539,7 +624,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { @@ -553,7 +638,7 @@ "<All keys matched successfully>" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -573,7 +658,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "metadata": { "collapsed": false, "jupyter": { @@ -585,7 +670,56 @@ "name": "stdout", "output_type": "stream", "text": [ - "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" + "Predicted: \"Sandal\", Actual: \"Ankle boot\"\n", + "Predicted: \"Shirt\", Actual: \"Pullover\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Shirt\", Actual: \"Shirt\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Shirt\", Actual: \"Coat\"\n", + "Predicted: \"Shirt\", Actual: \"Shirt\"\n", + "Predicted: \"Sneaker\", Actual: \"Sandal\"\n", + "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n", + "Predicted: \"T-shirt/top\", Actual: \"Coat\"\n", + "Predicted: \"Sandal\", Actual: \"Sandal\"\n", + "Predicted: \"Sandal\", Actual: \"Sneaker\"\n", + "Predicted: \"Dress\", Actual: \"Dress\"\n", + "Predicted: \"Shirt\", Actual: \"Coat\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Shirt\", Actual: \"Pullover\"\n", + "Predicted: \"Shirt\", Actual: \"Coat\"\n", + "Predicted: \"Bag\", Actual: \"Bag\"\n", + "Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n", + "Predicted: \"Shirt\", Actual: \"Pullover\"\n", + "Predicted: \"Sneaker\", Actual: \"Sandal\"\n", + "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n", + "Predicted: \"Sneaker\", Actual: \"Ankle boot\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Shirt\", Actual: \"Coat\"\n", + "Predicted: \"Shirt\", Actual: \"Shirt\"\n", + "Predicted: \"Dress\", Actual: \"T-shirt/top\"\n", + "Predicted: \"Sandal\", Actual: \"Ankle boot\"\n", + "Predicted: \"Dress\", Actual: \"Dress\"\n", + "Predicted: \"Bag\", Actual: \"Bag\"\n", + "Predicted: \"Bag\", Actual: \"Bag\"\n", + "Predicted: \"Dress\", Actual: \"Dress\"\n", + "Predicted: \"Dress\", Actual: \"Dress\"\n", + "Predicted: \"Bag\", Actual: \"Bag\"\n", + "Predicted: \"T-shirt/top\", Actual: \"T-shirt/top\"\n", + "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n", + "Predicted: \"Sandal\", Actual: \"Sandal\"\n", + "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n", + "Predicted: \"T-shirt/top\", Actual: \"Ankle boot\"\n", + "Predicted: \"T-shirt/top\", Actual: \"Shirt\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Dress\", Actual: \"Dress\"\n", + "Predicted: \"Sneaker\", Actual: \"Sneaker\"\n", + "Predicted: \"Shirt\", Actual: \"Shirt\"\n", + "Predicted: \"Sandal\", Actual: \"Sneaker\"\n", + "Predicted: \"Shirt\", Actual: \"Pullover\"\n", + "Predicted: \"Trouser\", Actual: \"Trouser\"\n", + "Predicted: \"Trouser\", Actual: \"Pullover\"\n", + "Predicted: \"Shirt\", Actual: \"Pullover\"\n" ] } ], @@ -604,21 +738,20 @@ "]\n", "\n", "model.eval()\n", - "x, y = test_data[0][0], test_data[0][1]\n", - "with torch.no_grad():\n", - " pred = model(x)\n", - " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", - " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" + "for i in range(0, 50):\n", + " x, y = test_data[i][0], test_data[i][1]\n", + " with torch.no_grad():\n", + " pred = model(x)\n", + " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", + " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "Read more about `Saving & Loading your model <saveloadrun_tutorial.html>`_.\n", - "\n", - "\n" - ] + "outputs": [], + "source": [] } ], "metadata": { |