summaryrefslogtreecommitdiff
path: root/mnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'mnist.py')
-rwxr-xr-xmnist.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/mnist.py b/mnist.py
index ea99633..95ca6ad 100755
--- a/mnist.py
+++ b/mnist.py
@@ -61,7 +61,11 @@ class CNN(nn.Module):
return out
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
model = CNN()
+model.to(device)
error = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
@@ -79,6 +83,8 @@ labels_list = []
for epoch in range(num_epochs):
for images, labels in train_loader:
+ images, labels = images.to(device), labels.to(device)
+
train = Variable(images.view(batch_size, 1, 28, 28))
labels = Variable(labels)
@@ -95,12 +101,14 @@ for epoch in range(num_epochs):
total = 0
correct = 0
for images, labels in test_loader:
+ images, labels = images.to(device), labels.to(device)
+
labels_list.append(labels)
test = Variable(images.view(batch_size, 1, 28, 28))
outputs = model(test)
- predictions = torch.max(outputs, 1)[1]
+ predictions = torch.max(outputs, 1)[1].to(device)
predictions_list.append(predictions)
correct += (predictions == labels).sum()