diff options
Diffstat (limited to 'mnist.py')
-rwxr-xr-x | mnist.py | 32 |
1 files changed, 11 insertions, 21 deletions
@@ -7,7 +7,6 @@ from torchvision import datasets from torchvision.transforms import ToTensor, Lambda, Compose import matplotlib.pyplot as plt - training_data = datasets.MNIST( root=".data", train=True, @@ -22,7 +21,6 @@ test_data = datasets.MNIST( transform=ToTensor(), ) - batch_size = 100 train_loader = DataLoader(training_data, batch_size=batch_size) @@ -34,18 +32,13 @@ class CNN(nn.Module): super(CNN, self).__init__() self.layer1 = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1), - nn.BatchNorm2d(32), - nn.ReLU(), - nn.MaxPool2d(kernel_size=2, stride=2) - ) + nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, + padding=1), nn.BatchNorm2d(32), nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2)) self.layer2 = nn.Sequential( nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.MaxPool2d(2) - ) - self.fc1 = nn.Linear(in_features=64*6*6, out_features=600) + nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2)) + self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=600) self.drop = nn.Dropout2d(0.25) self.fc2 = nn.Linear(in_features=600, out_features=120) self.fc3 = nn.Linear(in_features=120, out_features=10) @@ -61,7 +54,6 @@ class CNN(nn.Module): return out - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') model = CNN() @@ -70,7 +62,6 @@ error = nn.CrossEntropyLoss() learning_rate = 0.001 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) - num_epochs = 5 count = 0 @@ -104,28 +95,27 @@ for epoch in range(num_epochs): images, labels = images.to(device), labels.to(device) labels_list.append(labels) - + test = Variable(images.view(batch_size, 1, 28, 28)) outputs = model(test) - + predictions = torch.max(outputs, 1)[1].to(device) predictions_list.append(predictions) correct += (predictions == labels).sum() - + total += len(labels) - + accuracy = correct * batch_size / total loss_list.append(loss.data) iteration_list.append(count) accuracy_list.append(accuracy) - print("Iteration: {}, Loss: {}, Accuracy: {}%".format(count, loss.data, accuracy)) - + print("Iteration: {}, Loss: {}, Accuracy: {}%".format( + count, loss.data, accuracy)) torch.save(model.state_dict(), "model.pth") print("Saved PyTorch Model State to model.pth") - plt.plot(iteration_list, loss_list) plt.xlabel("No. of Iteration") plt.ylabel("Loss") |