summaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAnthony Wang2021-08-24 19:49:49 -0500
committerAnthony Wang2021-08-24 19:49:49 -0500
commit8816c54f44c615481a712fa14e494ae9c91c6f91 (patch)
tree77e59577a293f8ed5ed1e0fa962df10fecd116ce /model.py
parent44d82a9276a7a73b003068e73581cb714124ef7d (diff)
Continue working on model.py
Diffstat (limited to 'model.py')
-rw-r--r--model.py68
1 files changed, 62 insertions, 6 deletions
diff --git a/model.py b/model.py
index 93cd47a..f090cd7 100644
--- a/model.py
+++ b/model.py
@@ -24,10 +24,66 @@ test_data = datasets.FashionMNIST(
batch_size = 64
# Create data loaders.
-train_dataloader = DataLoader(training_data, batch_size=batch_size)
-test_dataloader = DataLoader(test_data, batch_size=batch_size)
+train_loader = DataLoader(training_data, batch_size=batch_size)
+test_loader = DataLoader(test_data, batch_size=batch_size)
+
+
+def output_label(label):
+ output_mapping = {
+ 0: "T-shirt/Top",
+ 1: "Trouser",
+ 2: "Pullover",
+ 3: "Dress",
+ 4: "Coat",
+ 5: "Sandal",
+ 6: "Shirt",
+ 7: "Sneaker",
+ 8: "Bag",
+ 9: "Ankle Boot"
+ }
+ input = (label.item() if type(label) == torch.Tensor else label)
+ return output_mapping[input]
+
+class FashionCNN(nn.Module):
+
+ def __init__(self):
+ super(FashionCNN, self).__init__()
+
+ self.layer1 = nn.Sequential(
+ nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
+ nn.BatchNorm2d(32),
+ nn.ReLU(),
+ nn.MaxPool2d(kernel_size=2, stride=2)
+ )
+
+ self.layer2 = nn.Sequential(
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
+ nn.BatchNorm2d(64),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+
+ self.fc1 = nn.Linear(in_features=64*6*6, out_features=600)
+ self.drop = nn.Dropout2d(0.25)
+ self.fc2 = nn.Linear(in_features=600, out_features=120)
+ self.fc3 = nn.Linear(in_features=120, out_features=10)
+
+def forward(self, x):
+ out = self.layer1(x)
+ out = self.layer2(out)
+ out = out.view(out.size(0), -1)
+ out = self.fc1(out)
+ out = self.drop(out)
+ out = self.fc2(out)
+ out = self.fc3(out)
+
+ return out
+
+model = FashionCNN()
+
+error = nn.CrossEntropyLoss()
+
+learning_rate = 0.001
+optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
-for X, y in test_dataloader:
- print("Shape of X [N, C, H, W]: ", X.shape)
- print("Shape of y: ", y.shape, y.dtype)
- break \ No newline at end of file