summaryrefslogtreecommitdiff
path: root/mnist.py
diff options
context:
space:
mode:
authorAnthony Wang2021-09-03 11:32:55 -0500
committerAnthony Wang2021-09-03 11:32:55 -0500
commita4f072d08a1258fe4a4bf33b73bac30ac2e0ca4d (patch)
tree76515be649be891245ca8756fd0937f6e79fc1f6 /mnist.py
parent20c8739f9979a7e76aa047c58ff7114f39fc4793 (diff)
Save model to file and generate images
Diffstat (limited to 'mnist.py')
-rw-r--r--mnist.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/mnist.py b/mnist.py
index 4466052..8ab8614 100644
--- a/mnist.py
+++ b/mnist.py
@@ -113,14 +113,18 @@ for epoch in range(num_epochs):
print("Iteration: {}, Loss: {}, Accuracy: {}%".format(count, loss.data, accuracy))
+torch.save(model.state_dict(), "model.pth")
+print("Saved PyTorch Model State to model.pth")
+
+
plt.plot(iteration_list, loss_list)
plt.xlabel("No. of Iteration")
plt.ylabel("Loss")
plt.title("Iterations vs Loss")
-plt.show()
+plt.savefig("loss.png")
plt.plot(iteration_list, accuracy_list)
plt.xlabel("No. of Iteration")
plt.ylabel("Accuracy")
plt.title("Iterations vs Accuracy")
-plt.show()
+plt.savefig("accuracy.png")