author      Anthony Wang    2021-08-23 21:32:41 -0500
committer   Anthony Wang    2021-08-23 21:32:41 -0500
commit      cc2d7d02ad92edf8e67a0de66edd3217edb25574 (patch)
tree        91303418fc990cc4fb79664adfebe07006e02c41
parent      ac768636c632e82ef4ce07a5bbdf2245ef52140c (diff)
Start working on model.py
-rw-r--r--  .ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb   232
-rw-r--r--  model.py                                                    33
-rw-r--r--  quickstart_tutorial.ipynb                                  232
3 files changed, 163 insertions, 334 deletions
diff --git a/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb b/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb
index 3ccf563..55d4021 100644
--- a/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb
+++ b/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb
@@ -79,7 +79,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -88,112 +88,10 @@
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "47918efb82854fc7a269ce73230391b0",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/26421880 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
- "\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "9abecd52d9144d53bd028f14a2cfd60b",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/29515 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
- "\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "df61f428b0c44a818d2ab0f64420d9b3",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/4422102 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
- "\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "418ca86b3df24c84979a54ca66cebe56",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/5148 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
- "\n"
- ]
- },
- {
"name": "stderr",
"output_type": "stream",
"text": [
- "/home/ta180m/.local/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/python-pytorch/src/pytorch-1.9.0-opt/torch/csrc/utils/tensor_numpy.cpp:174.)\n",
+ "/home/ta180m/git/PyTorch/.venv/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
]
}
@@ -228,7 +126,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -291,7 +189,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -377,7 +275,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -401,7 +299,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -439,7 +337,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -476,7 +374,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 8,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -490,78 +388,78 @@
"text": [
"Epoch 1\n",
"-------------------------------\n",
- "loss: 1.758146 [ 0/60000]\n",
- "loss: 1.820034 [ 6400/60000]\n",
- "loss: 1.846449 [12800/60000]\n",
- "loss: 1.975245 [19200/60000]\n",
- "loss: 1.612495 [25600/60000]\n",
- "loss: 1.748993 [32000/60000]\n",
- "loss: 1.628008 [38400/60000]\n",
- "loss: 1.655061 [44800/60000]\n",
- "loss: 1.770255 [51200/60000]\n",
- "loss: 1.654287 [57600/60000]\n",
+ "loss: 2.300270 [ 0/60000]\n",
+ "loss: 2.290948 [ 6400/60000]\n",
+ "loss: 2.280627 [12800/60000]\n",
+ "loss: 2.283042 [19200/60000]\n",
+ "loss: 2.268817 [25600/60000]\n",
+ "loss: 2.262104 [32000/60000]\n",
+ "loss: 2.248093 [38400/60000]\n",
+ "loss: 2.233517 [44800/60000]\n",
+ "loss: 2.234299 [51200/60000]\n",
+ "loss: 2.234841 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 37.7%, Avg loss: 1.749445 \n",
+ " Accuracy: 51.1%, Avg loss: 2.221716 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
- "loss: 1.670408 [ 0/60000]\n",
- "loss: 1.743051 [ 6400/60000]\n",
- "loss: 1.773547 [12800/60000]\n",
- "loss: 1.924395 [19200/60000]\n",
- "loss: 1.529726 [25600/60000]\n",
- "loss: 1.692361 [32000/60000]\n",
- "loss: 1.559834 [38400/60000]\n",
- "loss: 1.593531 [44800/60000]\n",
- "loss: 1.712157 [51200/60000]\n",
- "loss: 1.605115 [57600/60000]\n",
+ "loss: 2.203925 [ 0/60000]\n",
+ "loss: 2.200477 [ 6400/60000]\n",
+ "loss: 2.180246 [12800/60000]\n",
+ "loss: 2.213840 [19200/60000]\n",
+ "loss: 2.155768 [25600/60000]\n",
+ "loss: 2.154045 [32000/60000]\n",
+ "loss: 2.130811 [38400/60000]\n",
+ "loss: 2.101326 [44800/60000]\n",
+ "loss: 2.118105 [51200/60000]\n",
+ "loss: 2.116674 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 38.1%, Avg loss: 1.694516 \n",
+ " Accuracy: 51.0%, Avg loss: 2.093179 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
- "loss: 1.607648 [ 0/60000]\n",
- "loss: 1.684907 [ 6400/60000]\n",
- "loss: 1.716139 [12800/60000]\n",
- "loss: 1.888849 [19200/60000]\n",
- "loss: 1.474264 [25600/60000]\n",
- "loss: 1.652733 [32000/60000]\n",
- "loss: 1.514825 [38400/60000]\n",
- "loss: 1.549373 [44800/60000]\n",
- "loss: 1.670293 [51200/60000]\n",
- "loss: 1.571395 [57600/60000]\n",
+ "loss: 2.054594 [ 0/60000]\n",
+ "loss: 2.047193 [ 6400/60000]\n",
+ "loss: 2.009665 [12800/60000]\n",
+ "loss: 2.093936 [19200/60000]\n",
+ "loss: 1.965194 [25600/60000]\n",
+ "loss: 1.976750 [32000/60000]\n",
+ "loss: 1.938592 [38400/60000]\n",
+ "loss: 1.889513 [44800/60000]\n",
+ "loss: 1.942611 [51200/60000]\n",
+ "loss: 1.936963 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 39.0%, Avg loss: 1.653676 \n",
+ " Accuracy: 51.2%, Avg loss: 1.900150 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
- "loss: 1.561757 [ 0/60000]\n",
- "loss: 1.640771 [ 6400/60000]\n",
- "loss: 1.669458 [12800/60000]\n",
- "loss: 1.862879 [19200/60000]\n",
- "loss: 1.435348 [25600/60000]\n",
- "loss: 1.623189 [32000/60000]\n",
- "loss: 1.482370 [38400/60000]\n",
- "loss: 1.515045 [44800/60000]\n",
- "loss: 1.638349 [51200/60000]\n",
- "loss: 1.545919 [57600/60000]\n",
+ "loss: 1.837173 [ 0/60000]\n",
+ "loss: 1.822518 [ 6400/60000]\n",
+ "loss: 1.775139 [12800/60000]\n",
+ "loss: 1.925843 [19200/60000]\n",
+ "loss: 1.731390 [25600/60000]\n",
+ "loss: 1.778743 [32000/60000]\n",
+ "loss: 1.714922 [38400/60000]\n",
+ "loss: 1.670009 [44800/60000]\n",
+ "loss: 1.755909 [51200/60000]\n",
+ "loss: 1.763030 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 39.9%, Avg loss: 1.621615 \n",
+ " Accuracy: 52.5%, Avg loss: 1.711389 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
- "loss: 1.525517 [ 0/60000]\n",
- "loss: 1.604991 [ 6400/60000]\n",
- "loss: 1.630397 [12800/60000]\n",
- "loss: 1.841878 [19200/60000]\n",
- "loss: 1.406707 [25600/60000]\n",
- "loss: 1.599460 [32000/60000]\n",
- "loss: 1.456716 [38400/60000]\n",
- "loss: 1.485950 [44800/60000]\n",
- "loss: 1.612476 [51200/60000]\n",
- "loss: 1.525381 [57600/60000]\n",
+ "loss: 1.621799 [ 0/60000]\n",
+ "loss: 1.615258 [ 6400/60000]\n",
+ "loss: 1.567131 [12800/60000]\n",
+ "loss: 1.768921 [19200/60000]\n",
+ "loss: 1.539987 [25600/60000]\n",
+ "loss: 1.627408 [32000/60000]\n",
+ "loss: 1.533756 [38400/60000]\n",
+ "loss: 1.510703 [44800/60000]\n",
+ "loss: 1.603583 [51200/60000]\n",
+ "loss: 1.636089 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 40.7%, Avg loss: 1.595456 \n",
+ " Accuracy: 53.2%, Avg loss: 1.568043 \n",
"\n",
"Done!\n"
]
@@ -725,9 +623,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "PyTorch",
"language": "python",
- "name": "python3"
+ "name": "pytorch"
},
"language_info": {
"codemirror_mode": {
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..93cd47a
--- /dev/null
+++ b/model.py
@@ -0,0 +1,33 @@
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision.transforms import ToTensor, Lambda, Compose
+import matplotlib.pyplot as plt
+
+# Download training data from open datasets.
+training_data = datasets.FashionMNIST(
+ root="data",
+ train=True,
+ download=True,
+ transform=ToTensor(),
+)
+
+# Download test data from open datasets.
+test_data = datasets.FashionMNIST(
+ root="data",
+ train=False,
+ download=True,
+ transform=ToTensor(),
+)
+
+batch_size = 64
+
+# Create data loaders.
+train_dataloader = DataLoader(training_data, batch_size=batch_size)
+test_dataloader = DataLoader(test_data, batch_size=batch_size)
+
+for X, y in test_dataloader:
+ print("Shape of X [N, C, H, W]: ", X.shape)
+ print("Shape of y: ", y.shape, y.dtype)
+ break \ No newline at end of file
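
Note: model.py stops after inspecting one test batch. A minimal sketch of how it might continue, mirroring the model definition in quickstart_tutorial.ipynb; the class name NeuralNetwork, the layer sizes, and the device selection follow the standard PyTorch quickstart tutorial and are not part of this commit:

# Hypothetical continuation of model.py (not in this commit); mirrors the
# quickstart tutorial's model definition. Imports repeated here so the
# sketch is self-contained; model.py already has them at the top.
import torch
from torch import nn

# Pick GPU if available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)               # (N, 1, 28, 28) -> (N, 784)
        return self.linear_relu_stack(x)  # (N, 784) -> (N, 10) logits

model = NeuralNetwork().to(device)
print(model)
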
diff --git a/quickstart_tutorial.ipynb b/quickstart_tutorial.ipynb
index 3ccf563..55d4021 100644
--- a/quickstart_tutorial.ipynb
+++ b/quickstart_tutorial.ipynb
@@ -79,7 +79,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -88,112 +88,10 @@
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "47918efb82854fc7a269ce73230391b0",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/26421880 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
- "\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "9abecd52d9144d53bd028f14a2cfd60b",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/29515 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
- "\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "df61f428b0c44a818d2ab0f64420d9b3",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/4422102 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n",
- "\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
- "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "418ca86b3df24c84979a54ca66cebe56",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/5148 [00:00<?, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n",
- "\n"
- ]
- },
- {
"name": "stderr",
"output_type": "stream",
"text": [
- "/home/ta180m/.local/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/python-pytorch/src/pytorch-1.9.0-opt/torch/csrc/utils/tensor_numpy.cpp:174.)\n",
+ "/home/ta180m/git/PyTorch/.venv/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n",
" return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
]
}
@@ -228,7 +126,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -291,7 +189,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -377,7 +275,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -401,7 +299,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -439,7 +337,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -476,7 +374,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 8,
"metadata": {
"collapsed": false,
"jupyter": {
@@ -490,78 +388,78 @@
"text": [
"Epoch 1\n",
"-------------------------------\n",
- "loss: 1.758146 [ 0/60000]\n",
- "loss: 1.820034 [ 6400/60000]\n",
- "loss: 1.846449 [12800/60000]\n",
- "loss: 1.975245 [19200/60000]\n",
- "loss: 1.612495 [25600/60000]\n",
- "loss: 1.748993 [32000/60000]\n",
- "loss: 1.628008 [38400/60000]\n",
- "loss: 1.655061 [44800/60000]\n",
- "loss: 1.770255 [51200/60000]\n",
- "loss: 1.654287 [57600/60000]\n",
+ "loss: 2.300270 [ 0/60000]\n",
+ "loss: 2.290948 [ 6400/60000]\n",
+ "loss: 2.280627 [12800/60000]\n",
+ "loss: 2.283042 [19200/60000]\n",
+ "loss: 2.268817 [25600/60000]\n",
+ "loss: 2.262104 [32000/60000]\n",
+ "loss: 2.248093 [38400/60000]\n",
+ "loss: 2.233517 [44800/60000]\n",
+ "loss: 2.234299 [51200/60000]\n",
+ "loss: 2.234841 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 37.7%, Avg loss: 1.749445 \n",
+ " Accuracy: 51.1%, Avg loss: 2.221716 \n",
"\n",
"Epoch 2\n",
"-------------------------------\n",
- "loss: 1.670408 [ 0/60000]\n",
- "loss: 1.743051 [ 6400/60000]\n",
- "loss: 1.773547 [12800/60000]\n",
- "loss: 1.924395 [19200/60000]\n",
- "loss: 1.529726 [25600/60000]\n",
- "loss: 1.692361 [32000/60000]\n",
- "loss: 1.559834 [38400/60000]\n",
- "loss: 1.593531 [44800/60000]\n",
- "loss: 1.712157 [51200/60000]\n",
- "loss: 1.605115 [57600/60000]\n",
+ "loss: 2.203925 [ 0/60000]\n",
+ "loss: 2.200477 [ 6400/60000]\n",
+ "loss: 2.180246 [12800/60000]\n",
+ "loss: 2.213840 [19200/60000]\n",
+ "loss: 2.155768 [25600/60000]\n",
+ "loss: 2.154045 [32000/60000]\n",
+ "loss: 2.130811 [38400/60000]\n",
+ "loss: 2.101326 [44800/60000]\n",
+ "loss: 2.118105 [51200/60000]\n",
+ "loss: 2.116674 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 38.1%, Avg loss: 1.694516 \n",
+ " Accuracy: 51.0%, Avg loss: 2.093179 \n",
"\n",
"Epoch 3\n",
"-------------------------------\n",
- "loss: 1.607648 [ 0/60000]\n",
- "loss: 1.684907 [ 6400/60000]\n",
- "loss: 1.716139 [12800/60000]\n",
- "loss: 1.888849 [19200/60000]\n",
- "loss: 1.474264 [25600/60000]\n",
- "loss: 1.652733 [32000/60000]\n",
- "loss: 1.514825 [38400/60000]\n",
- "loss: 1.549373 [44800/60000]\n",
- "loss: 1.670293 [51200/60000]\n",
- "loss: 1.571395 [57600/60000]\n",
+ "loss: 2.054594 [ 0/60000]\n",
+ "loss: 2.047193 [ 6400/60000]\n",
+ "loss: 2.009665 [12800/60000]\n",
+ "loss: 2.093936 [19200/60000]\n",
+ "loss: 1.965194 [25600/60000]\n",
+ "loss: 1.976750 [32000/60000]\n",
+ "loss: 1.938592 [38400/60000]\n",
+ "loss: 1.889513 [44800/60000]\n",
+ "loss: 1.942611 [51200/60000]\n",
+ "loss: 1.936963 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 39.0%, Avg loss: 1.653676 \n",
+ " Accuracy: 51.2%, Avg loss: 1.900150 \n",
"\n",
"Epoch 4\n",
"-------------------------------\n",
- "loss: 1.561757 [ 0/60000]\n",
- "loss: 1.640771 [ 6400/60000]\n",
- "loss: 1.669458 [12800/60000]\n",
- "loss: 1.862879 [19200/60000]\n",
- "loss: 1.435348 [25600/60000]\n",
- "loss: 1.623189 [32000/60000]\n",
- "loss: 1.482370 [38400/60000]\n",
- "loss: 1.515045 [44800/60000]\n",
- "loss: 1.638349 [51200/60000]\n",
- "loss: 1.545919 [57600/60000]\n",
+ "loss: 1.837173 [ 0/60000]\n",
+ "loss: 1.822518 [ 6400/60000]\n",
+ "loss: 1.775139 [12800/60000]\n",
+ "loss: 1.925843 [19200/60000]\n",
+ "loss: 1.731390 [25600/60000]\n",
+ "loss: 1.778743 [32000/60000]\n",
+ "loss: 1.714922 [38400/60000]\n",
+ "loss: 1.670009 [44800/60000]\n",
+ "loss: 1.755909 [51200/60000]\n",
+ "loss: 1.763030 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 39.9%, Avg loss: 1.621615 \n",
+ " Accuracy: 52.5%, Avg loss: 1.711389 \n",
"\n",
"Epoch 5\n",
"-------------------------------\n",
- "loss: 1.525517 [ 0/60000]\n",
- "loss: 1.604991 [ 6400/60000]\n",
- "loss: 1.630397 [12800/60000]\n",
- "loss: 1.841878 [19200/60000]\n",
- "loss: 1.406707 [25600/60000]\n",
- "loss: 1.599460 [32000/60000]\n",
- "loss: 1.456716 [38400/60000]\n",
- "loss: 1.485950 [44800/60000]\n",
- "loss: 1.612476 [51200/60000]\n",
- "loss: 1.525381 [57600/60000]\n",
+ "loss: 1.621799 [ 0/60000]\n",
+ "loss: 1.615258 [ 6400/60000]\n",
+ "loss: 1.567131 [12800/60000]\n",
+ "loss: 1.768921 [19200/60000]\n",
+ "loss: 1.539987 [25600/60000]\n",
+ "loss: 1.627408 [32000/60000]\n",
+ "loss: 1.533756 [38400/60000]\n",
+ "loss: 1.510703 [44800/60000]\n",
+ "loss: 1.603583 [51200/60000]\n",
+ "loss: 1.636089 [57600/60000]\n",
"Test Error: \n",
- " Accuracy: 40.7%, Avg loss: 1.595456 \n",
+ " Accuracy: 53.2%, Avg loss: 1.568043 \n",
"\n",
"Done!\n"
]
@@ -725,9 +623,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "PyTorch",
"language": "python",
- "name": "python3"
+ "name": "pytorch"
},
"language_info": {
"codemirror_mode": {