diff options
Diffstat (limited to 'mnist.py')
-rwxr-xr-x | mnist.py | 10 |
1 files changed, 9 insertions, 1 deletions
@@ -61,7 +61,11 @@ class CNN(nn.Module): return out + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + model = CNN() +model.to(device) error = nn.CrossEntropyLoss() learning_rate = 0.001 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) @@ -79,6 +83,8 @@ labels_list = [] for epoch in range(num_epochs): for images, labels in train_loader: + images, labels = images.to(device), labels.to(device) + train = Variable(images.view(batch_size, 1, 28, 28)) labels = Variable(labels) @@ -95,12 +101,14 @@ for epoch in range(num_epochs): total = 0 correct = 0 for images, labels in test_loader: + images, labels = images.to(device), labels.to(device) + labels_list.append(labels) test = Variable(images.view(batch_size, 1, 28, 28)) outputs = model(test) - predictions = torch.max(outputs, 1)[1] + predictions = torch.max(outputs, 1)[1].to(device) predictions_list.append(predictions) correct += (predictions == labels).sum() |