| author | Anthony Wang | 2021-08-23 21:32:41 -0500 |
| committer | Anthony Wang | 2021-08-23 21:32:41 -0500 |
| commit | cc2d7d02ad92edf8e67a0de66edd3217edb25574 (patch) | |
| tree | 91303418fc990cc4fb79664adfebe07006e02c41 | |
| parent | ac768636c632e82ef4ce07a5bbdf2245ef52140c (diff) | |
Start working on model.py
| -rw-r--r-- | .ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb | 232 |
| -rw-r--r-- | model.py | 33 |
| -rw-r--r-- | quickstart_tutorial.ipynb | 232 |
3 files changed, 163 insertions, 334 deletions
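The loss lines replaced in both notebook diffs below now begin near 2.30 on the first batch of epoch 1, which is consistent with the notebook having been re-run from freshly initialized weights: for a 10-class classifier predicting roughly uniform probabilities, the expected cross-entropy is ln(10). A quick sanity check in plain Python (no project code involved; the connection to a fresh re-run is an inference, not something stated in the commit):

```python
import math

# Cross-entropy of a uniform prediction over the 10 FashionMNIST classes:
# -log(1/10) = log(10) ≈ 2.3026, matching the ~2.30 losses at the start of epoch 1.
print(math.log(10))  # 2.302585092994046
```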
diff --git a/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb b/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb index 3ccf563..55d4021 100644 --- a/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb +++ b/.ipynb_checkpoints/quickstart_tutorial-checkpoint.ipynb @@ -79,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { @@ -88,112 +88,10 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "47918efb82854fc7a269ce73230391b0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/26421880 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", - "\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9abecd52d9144d53bd028f14a2cfd60b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/29515 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", - "\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "df61f428b0c44a818d2ab0f64420d9b3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/4422102 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", - "\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "418ca86b3df24c84979a54ca66cebe56", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/5148 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", - "\n" - ] - }, - { "name": "stderr", "output_type": "stream", "text": [ - "/home/ta180m/.local/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is 
not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/python-pytorch/src/pytorch-1.9.0-opt/torch/csrc/utils/tensor_numpy.cpp:174.)\n", + "/home/ta180m/git/PyTorch/.venv/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n", " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" ] } @@ -228,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { @@ -291,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { @@ -377,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { @@ -401,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -439,7 +337,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -476,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -490,78 +388,78 @@ "text": [ "Epoch 1\n", "-------------------------------\n", - "loss: 1.758146 [ 0/60000]\n", - "loss: 1.820034 [ 6400/60000]\n", - "loss: 1.846449 [12800/60000]\n", - "loss: 1.975245 [19200/60000]\n", - "loss: 1.612495 [25600/60000]\n", - "loss: 1.748993 [32000/60000]\n", - "loss: 1.628008 [38400/60000]\n", - "loss: 1.655061 [44800/60000]\n", - "loss: 1.770255 [51200/60000]\n", - "loss: 1.654287 [57600/60000]\n", + "loss: 2.300270 [ 0/60000]\n", + "loss: 2.290948 [ 6400/60000]\n", + "loss: 2.280627 [12800/60000]\n", + "loss: 2.283042 [19200/60000]\n", + "loss: 2.268817 [25600/60000]\n", + "loss: 2.262104 [32000/60000]\n", + "loss: 2.248093 [38400/60000]\n", + "loss: 2.233517 [44800/60000]\n", + "loss: 2.234299 [51200/60000]\n", + "loss: 2.234841 [57600/60000]\n", "Test Error: \n", - " Accuracy: 37.7%, Avg loss: 1.749445 \n", + " Accuracy: 51.1%, Avg loss: 2.221716 \n", "\n", "Epoch 2\n", "-------------------------------\n", - "loss: 1.670408 [ 0/60000]\n", - "loss: 1.743051 [ 6400/60000]\n", - "loss: 1.773547 [12800/60000]\n", - "loss: 1.924395 [19200/60000]\n", - "loss: 1.529726 [25600/60000]\n", - "loss: 1.692361 [32000/60000]\n", - "loss: 1.559834 [38400/60000]\n", - "loss: 1.593531 [44800/60000]\n", - "loss: 1.712157 [51200/60000]\n", - "loss: 1.605115 [57600/60000]\n", + "loss: 2.203925 [ 0/60000]\n", + "loss: 2.200477 [ 6400/60000]\n", + "loss: 2.180246 [12800/60000]\n", + "loss: 2.213840 [19200/60000]\n", + "loss: 2.155768 [25600/60000]\n", + "loss: 2.154045 [32000/60000]\n", + "loss: 2.130811 [38400/60000]\n", + "loss: 2.101326 [44800/60000]\n", + "loss: 2.118105 
[51200/60000]\n", + "loss: 2.116674 [57600/60000]\n", "Test Error: \n", - " Accuracy: 38.1%, Avg loss: 1.694516 \n", + " Accuracy: 51.0%, Avg loss: 2.093179 \n", "\n", "Epoch 3\n", "-------------------------------\n", - "loss: 1.607648 [ 0/60000]\n", - "loss: 1.684907 [ 6400/60000]\n", - "loss: 1.716139 [12800/60000]\n", - "loss: 1.888849 [19200/60000]\n", - "loss: 1.474264 [25600/60000]\n", - "loss: 1.652733 [32000/60000]\n", - "loss: 1.514825 [38400/60000]\n", - "loss: 1.549373 [44800/60000]\n", - "loss: 1.670293 [51200/60000]\n", - "loss: 1.571395 [57600/60000]\n", + "loss: 2.054594 [ 0/60000]\n", + "loss: 2.047193 [ 6400/60000]\n", + "loss: 2.009665 [12800/60000]\n", + "loss: 2.093936 [19200/60000]\n", + "loss: 1.965194 [25600/60000]\n", + "loss: 1.976750 [32000/60000]\n", + "loss: 1.938592 [38400/60000]\n", + "loss: 1.889513 [44800/60000]\n", + "loss: 1.942611 [51200/60000]\n", + "loss: 1.936963 [57600/60000]\n", "Test Error: \n", - " Accuracy: 39.0%, Avg loss: 1.653676 \n", + " Accuracy: 51.2%, Avg loss: 1.900150 \n", "\n", "Epoch 4\n", "-------------------------------\n", - "loss: 1.561757 [ 0/60000]\n", - "loss: 1.640771 [ 6400/60000]\n", - "loss: 1.669458 [12800/60000]\n", - "loss: 1.862879 [19200/60000]\n", - "loss: 1.435348 [25600/60000]\n", - "loss: 1.623189 [32000/60000]\n", - "loss: 1.482370 [38400/60000]\n", - "loss: 1.515045 [44800/60000]\n", - "loss: 1.638349 [51200/60000]\n", - "loss: 1.545919 [57600/60000]\n", + "loss: 1.837173 [ 0/60000]\n", + "loss: 1.822518 [ 6400/60000]\n", + "loss: 1.775139 [12800/60000]\n", + "loss: 1.925843 [19200/60000]\n", + "loss: 1.731390 [25600/60000]\n", + "loss: 1.778743 [32000/60000]\n", + "loss: 1.714922 [38400/60000]\n", + "loss: 1.670009 [44800/60000]\n", + "loss: 1.755909 [51200/60000]\n", + "loss: 1.763030 [57600/60000]\n", "Test Error: \n", - " Accuracy: 39.9%, Avg loss: 1.621615 \n", + " Accuracy: 52.5%, Avg loss: 1.711389 \n", "\n", "Epoch 5\n", "-------------------------------\n", - "loss: 1.525517 [ 0/60000]\n", - "loss: 1.604991 [ 6400/60000]\n", - "loss: 1.630397 [12800/60000]\n", - "loss: 1.841878 [19200/60000]\n", - "loss: 1.406707 [25600/60000]\n", - "loss: 1.599460 [32000/60000]\n", - "loss: 1.456716 [38400/60000]\n", - "loss: 1.485950 [44800/60000]\n", - "loss: 1.612476 [51200/60000]\n", - "loss: 1.525381 [57600/60000]\n", + "loss: 1.621799 [ 0/60000]\n", + "loss: 1.615258 [ 6400/60000]\n", + "loss: 1.567131 [12800/60000]\n", + "loss: 1.768921 [19200/60000]\n", + "loss: 1.539987 [25600/60000]\n", + "loss: 1.627408 [32000/60000]\n", + "loss: 1.533756 [38400/60000]\n", + "loss: 1.510703 [44800/60000]\n", + "loss: 1.603583 [51200/60000]\n", + "loss: 1.636089 [57600/60000]\n", "Test Error: \n", - " Accuracy: 40.7%, Avg loss: 1.595456 \n", + " Accuracy: 53.2%, Avg loss: 1.568043 \n", "\n", "Done!\n" ] @@ -725,9 +623,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "PyTorch", "language": "python", - "name": "python3" + "name": "pytorch" }, "language_info": { "codemirror_mode": { diff --git a/model.py b/model.py new file mode 100644 index 0000000..93cd47a --- /dev/null +++ b/model.py @@ -0,0 +1,33 @@ +import torch +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets +from torchvision.transforms import ToTensor, Lambda, Compose +import matplotlib.pyplot as plt + +# Download training data from open datasets. 
+training_data = datasets.FashionMNIST(
+    root="data",
+    train=True,
+    download=True,
+    transform=ToTensor(),
+)
+
+# Download test data from open datasets.
+test_data = datasets.FashionMNIST(
+    root="data",
+    train=False,
+    download=True,
+    transform=ToTensor(),
+)
+
+batch_size = 64
+
+# Create data loaders.
+train_dataloader = DataLoader(training_data, batch_size=batch_size)
+test_dataloader = DataLoader(test_data, batch_size=batch_size)
+
+for X, y in test_dataloader:
+    print("Shape of X [N, C, H, W]: ", X.shape)
+    print("Shape of y: ", y.shape, y.dtype)
+    break
\ No newline at end of file diff --git a/quickstart_tutorial.ipynb b/quickstart_tutorial.ipynb index 3ccf563..55d4021 100644 --- a/quickstart_tutorial.ipynb +++ b/quickstart_tutorial.ipynb @@ -79,7 +79,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { @@ -88,112 +88,10 @@ }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "47918efb82854fc7a269ce73230391b0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/26421880 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", - "\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9abecd52d9144d53bd028f14a2cfd60b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/29515 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", - "\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "df61f428b0c44a818d2ab0f64420d9b3", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/4422102 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw\n", - "\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", - "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "418ca86b3df24c84979a54ca66cebe56", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/5148 [00:00<?, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw\n", - "\n" - ] - }, - { "name": "stderr", "output_type": "stream", "text": [ - "/home/ta180m/.local/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. 
This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/python-pytorch/src/pytorch-1.9.0-opt/torch/csrc/utils/tensor_numpy.cpp:174.)\n", + "/home/ta180m/git/PyTorch/.venv/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n", " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" ] } @@ -228,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { @@ -291,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { @@ -377,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { @@ -401,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { @@ -439,7 +337,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -476,7 +374,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": { "collapsed": false, "jupyter": { @@ -490,78 +388,78 @@ "text": [ "Epoch 1\n", "-------------------------------\n", - "loss: 1.758146 [ 0/60000]\n", - "loss: 1.820034 [ 6400/60000]\n", - "loss: 1.846449 [12800/60000]\n", - "loss: 1.975245 [19200/60000]\n", - "loss: 1.612495 [25600/60000]\n", - "loss: 1.748993 [32000/60000]\n", - "loss: 1.628008 [38400/60000]\n", - "loss: 1.655061 [44800/60000]\n", - "loss: 1.770255 [51200/60000]\n", - "loss: 1.654287 [57600/60000]\n", + "loss: 2.300270 [ 0/60000]\n", + "loss: 2.290948 [ 6400/60000]\n", + "loss: 2.280627 [12800/60000]\n", + "loss: 2.283042 [19200/60000]\n", + "loss: 2.268817 [25600/60000]\n", + "loss: 2.262104 [32000/60000]\n", + "loss: 2.248093 [38400/60000]\n", + "loss: 2.233517 [44800/60000]\n", + "loss: 2.234299 [51200/60000]\n", + "loss: 2.234841 [57600/60000]\n", "Test Error: \n", - " Accuracy: 37.7%, Avg loss: 1.749445 \n", + " Accuracy: 51.1%, Avg loss: 2.221716 \n", "\n", "Epoch 2\n", "-------------------------------\n", - "loss: 1.670408 [ 0/60000]\n", - "loss: 1.743051 [ 6400/60000]\n", - "loss: 1.773547 [12800/60000]\n", - "loss: 1.924395 [19200/60000]\n", - "loss: 1.529726 [25600/60000]\n", - "loss: 1.692361 [32000/60000]\n", - "loss: 1.559834 [38400/60000]\n", - "loss: 1.593531 [44800/60000]\n", - "loss: 1.712157 [51200/60000]\n", - "loss: 1.605115 [57600/60000]\n", + "loss: 2.203925 [ 0/60000]\n", + "loss: 2.200477 [ 6400/60000]\n", + "loss: 2.180246 [12800/60000]\n", + "loss: 2.213840 [19200/60000]\n", + "loss: 2.155768 [25600/60000]\n", + "loss: 2.154045 [32000/60000]\n", + "loss: 2.130811 [38400/60000]\n", + "loss: 2.101326 [44800/60000]\n", + "loss: 2.118105 [51200/60000]\n", + "loss: 2.116674 [57600/60000]\n", "Test Error: 
\n", - " Accuracy: 38.1%, Avg loss: 1.694516 \n", + " Accuracy: 51.0%, Avg loss: 2.093179 \n", "\n", "Epoch 3\n", "-------------------------------\n", - "loss: 1.607648 [ 0/60000]\n", - "loss: 1.684907 [ 6400/60000]\n", - "loss: 1.716139 [12800/60000]\n", - "loss: 1.888849 [19200/60000]\n", - "loss: 1.474264 [25600/60000]\n", - "loss: 1.652733 [32000/60000]\n", - "loss: 1.514825 [38400/60000]\n", - "loss: 1.549373 [44800/60000]\n", - "loss: 1.670293 [51200/60000]\n", - "loss: 1.571395 [57600/60000]\n", + "loss: 2.054594 [ 0/60000]\n", + "loss: 2.047193 [ 6400/60000]\n", + "loss: 2.009665 [12800/60000]\n", + "loss: 2.093936 [19200/60000]\n", + "loss: 1.965194 [25600/60000]\n", + "loss: 1.976750 [32000/60000]\n", + "loss: 1.938592 [38400/60000]\n", + "loss: 1.889513 [44800/60000]\n", + "loss: 1.942611 [51200/60000]\n", + "loss: 1.936963 [57600/60000]\n", "Test Error: \n", - " Accuracy: 39.0%, Avg loss: 1.653676 \n", + " Accuracy: 51.2%, Avg loss: 1.900150 \n", "\n", "Epoch 4\n", "-------------------------------\n", - "loss: 1.561757 [ 0/60000]\n", - "loss: 1.640771 [ 6400/60000]\n", - "loss: 1.669458 [12800/60000]\n", - "loss: 1.862879 [19200/60000]\n", - "loss: 1.435348 [25600/60000]\n", - "loss: 1.623189 [32000/60000]\n", - "loss: 1.482370 [38400/60000]\n", - "loss: 1.515045 [44800/60000]\n", - "loss: 1.638349 [51200/60000]\n", - "loss: 1.545919 [57600/60000]\n", + "loss: 1.837173 [ 0/60000]\n", + "loss: 1.822518 [ 6400/60000]\n", + "loss: 1.775139 [12800/60000]\n", + "loss: 1.925843 [19200/60000]\n", + "loss: 1.731390 [25600/60000]\n", + "loss: 1.778743 [32000/60000]\n", + "loss: 1.714922 [38400/60000]\n", + "loss: 1.670009 [44800/60000]\n", + "loss: 1.755909 [51200/60000]\n", + "loss: 1.763030 [57600/60000]\n", "Test Error: \n", - " Accuracy: 39.9%, Avg loss: 1.621615 \n", + " Accuracy: 52.5%, Avg loss: 1.711389 \n", "\n", "Epoch 5\n", "-------------------------------\n", - "loss: 1.525517 [ 0/60000]\n", - "loss: 1.604991 [ 6400/60000]\n", - "loss: 1.630397 [12800/60000]\n", - "loss: 1.841878 [19200/60000]\n", - "loss: 1.406707 [25600/60000]\n", - "loss: 1.599460 [32000/60000]\n", - "loss: 1.456716 [38400/60000]\n", - "loss: 1.485950 [44800/60000]\n", - "loss: 1.612476 [51200/60000]\n", - "loss: 1.525381 [57600/60000]\n", + "loss: 1.621799 [ 0/60000]\n", + "loss: 1.615258 [ 6400/60000]\n", + "loss: 1.567131 [12800/60000]\n", + "loss: 1.768921 [19200/60000]\n", + "loss: 1.539987 [25600/60000]\n", + "loss: 1.627408 [32000/60000]\n", + "loss: 1.533756 [38400/60000]\n", + "loss: 1.510703 [44800/60000]\n", + "loss: 1.603583 [51200/60000]\n", + "loss: 1.636089 [57600/60000]\n", "Test Error: \n", - " Accuracy: 40.7%, Avg loss: 1.595456 \n", + " Accuracy: 53.2%, Avg loss: 1.568043 \n", "\n", "Done!\n" ] @@ -725,9 +623,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "PyTorch", "language": "python", - "name": "python3" + "name": "pytorch" }, "language_info": { "codemirror_mode": { |