diff options
author | Anthony Wang | 2021-08-24 19:49:49 -0500 |
---|---|---|
committer | Anthony Wang | 2021-08-24 19:49:49 -0500 |
commit | 8816c54f44c615481a712fa14e494ae9c91c6f91 (patch) | |
tree | 77e59577a293f8ed5ed1e0fa962df10fecd116ce | |
parent | 44d82a9276a7a73b003068e73581cb714124ef7d (diff) |
Continue working on model.py
-rw-r--r-- | fashion-mnist-with-pytorch-93-accuracy.ipynb | 42 | ||||
-rw-r--r-- | model.py | 68 |
2 files changed, 93 insertions, 17 deletions
diff --git a/fashion-mnist-with-pytorch-93-accuracy.ipynb b/fashion-mnist-with-pytorch-93-accuracy.ipynb index 2b9f234..79f6923 100644 --- a/fashion-mnist-with-pytorch-93-accuracy.ipynb +++ b/fashion-mnist-with-pytorch-93-accuracy.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", @@ -53,7 +53,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", @@ -106,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -116,7 +116,27 @@ "id": "PV2g1_8qUDvA", "outputId": "b429a123-7574-413d-d15a-92f0481e6753" }, - "outputs": [], + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '../input/fashion-mnist_train.csv'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_45959/798855398.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain_csv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"../input/fashion-mnist_train.csv\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mtest_csv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"../input/fashion-mnist_test.csv\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/util/_decorators.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0mstacklevel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstacklevel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 310\u001b[0m )\n\u001b[0;32m--> 311\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 312\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, error_bad_lines, warn_bad_lines, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options)\u001b[0m\n\u001b[1;32m 584\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwds_defaults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 586\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_read\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 587\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 588\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 480\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0;31m# Create the parser.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 482\u001b[0;31m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextFileReader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 483\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 484\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mchunksize\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 810\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 811\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_engine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 812\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 813\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/readers.py\u001b[0m in \u001b[0;36m_make_engine\u001b[0;34m(self, engine)\u001b[0m\n\u001b[1;32m 1038\u001b[0m )\n\u001b[1;32m 1039\u001b[0m \u001b[0;31m# error: Too many arguments for \"ParserBase\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1040\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmapping\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[call-arg]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1041\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_failover_to_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/c_parser_wrapper.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, src, **kwds)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;31m# open handles\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_open_handles\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhandles\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/parsers/base_parser.py\u001b[0m in \u001b[0;36m_open_handles\u001b[0;34m(self, src, kwds)\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0mLet\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mreaders\u001b[0m \u001b[0mopen\u001b[0m \u001b[0mIOHandles\u001b[0m \u001b[0mafter\u001b[0m \u001b[0mthey\u001b[0m \u001b[0mare\u001b[0m \u001b[0mdone\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtheir\u001b[0m \u001b[0mpotential\u001b[0m \u001b[0mraises\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \"\"\"\n\u001b[0;32m--> 222\u001b[0;31m self.handles = get_handle(\n\u001b[0m\u001b[1;32m 223\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\"r\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/git/PyTorch/.venv/lib/python3.9/site-packages/pandas/io/common.py\u001b[0m in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 699\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoding\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m\"b\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 700\u001b[0m \u001b[0;31m# Encoding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 701\u001b[0;31m handle = open(\n\u001b[0m\u001b[1;32m 702\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 703\u001b[0m \u001b[0mioargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../input/fashion-mnist_train.csv'" + ] + } + ], "source": [ "train_csv = pd.read_csv(\"../input/fashion-mnist_train.csv\")\n", "test_csv = pd.read_csv(\"../input/fashion-mnist_test.csv\")" @@ -124,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", @@ -199,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", @@ -226,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", @@ -259,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", @@ -299,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -316,7 +336,7 @@ "torch.Size([100, 1, 28, 28])" ] }, - "execution_count": 8, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -505,7 +525,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", @@ -24,10 +24,66 @@ test_data = datasets.FashionMNIST( batch_size = 64 # Create data loaders. -train_dataloader = DataLoader(training_data, batch_size=batch_size) -test_dataloader = DataLoader(test_data, batch_size=batch_size) +train_loader = DataLoader(training_data, batch_size=batch_size) +test_loader = DataLoader(test_data, batch_size=batch_size) + + +def output_label(label): + output_mapping = { + 0: "T-shirt/Top", + 1: "Trouser", + 2: "Pullover", + 3: "Dress", + 4: "Coat", + 5: "Sandal", + 6: "Shirt", + 7: "Sneaker", + 8: "Bag", + 9: "Ankle Boot" + } + input = (label.item() if type(label) == torch.Tensor else label) + return output_mapping[input] + +class FashionCNN(nn.Module): + + def __init__(self): + super(FashionCNN, self).__init__() + + self.layer1 = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2) + ) + + self.layer2 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(2) + ) + + self.fc1 = nn.Linear(in_features=64*6*6, out_features=600) + self.drop = nn.Dropout2d(0.25) + self.fc2 = nn.Linear(in_features=600, out_features=120) + self.fc3 = nn.Linear(in_features=120, out_features=10) + +def forward(self, x): + out = self.layer1(x) + out = self.layer2(out) + out = out.view(out.size(0), -1) + out = self.fc1(out) + out = self.drop(out) + out = self.fc2(out) + out = self.fc3(out) + + return out + +model = FashionCNN() + +error = nn.CrossEntropyLoss() + +learning_rate = 0.001 +optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + -for X, y in test_dataloader: - print("Shape of X [N, C, H, W]: ", X.shape) - print("Shape of y: ", y.shape, y.dtype) - break
\ No newline at end of file |