author     Anthony Wang   2022-01-11 12:24:42 -0600
committer  Anthony Wang   2022-01-11 12:24:42 -0600
commit     3c47bb63daa20dc2fa6de6c0fbc7f6f612fcb0da (patch)
tree       00fb7a92e3c1fffc0774150ba6c58dd23371fee4 /mnist/src/mnist_conv.rs
parent     6be866b0b8b5ec676159eaa56c09371c02cbd7e1 (diff)
Add Rust MNIST code (HEAD, master)
Diffstat (limited to 'mnist/src/mnist_conv.rs')
-rw-r--r--  mnist/src/mnist_conv.rs  54
1 file changed, 54 insertions, 0 deletions
diff --git a/mnist/src/mnist_conv.rs b/mnist/src/mnist_conv.rs
new file mode 100644
index 0000000..3cb5ee1
--- /dev/null
+++ b/mnist/src/mnist_conv.rs
@@ -0,0 +1,54 @@
+// CNN model. This should reach 99.1% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
+
+#[derive(Debug)]
+struct Net {
+ conv1: nn::Conv2D,
+ conv2: nn::Conv2D,
+ fc1: nn::Linear,
+ fc2: nn::Linear,
+}
+
+impl Net {
+ fn new(vs: &nn::Path) -> Net {
+ let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
+ let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
+ let fc1 = nn::linear(vs, 1024, 1024, Default::default());
+ let fc2 = nn::linear(vs, 1024, 10, Default::default());
+ Net { conv1, conv2, fc1, fc2 }
+ }
+}
+
+impl nn::ModuleT for Net {
+ fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
+ xs.view([-1, 1, 28, 28])
+ .apply(&self.conv1)
+ .max_pool2d_default(2)
+ .apply(&self.conv2)
+ .max_pool2d_default(2)
+ .view([-1, 1024])
+ .apply(&self.fc1)
+ .relu()
+ .dropout_(0.5, train)
+ .apply(&self.fc2)
+ }
+}
+
+pub fn run() -> Result<()> {
+ let m = tch::vision::mnist::load_dir("data")?;
+ let vs = nn::VarStore::new(Device::cuda_if_available());
+ let net = Net::new(&vs.root());
+ let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
+ for epoch in 1..100 {
+ for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) {
+ let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels);
+ opt.backward_step(&loss);
+ }
+ let test_accuracy =
+ net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024);
+ println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,);
+ }
+ Ok(())
+}
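
A minimal sketch of how this module might be wired into the crate: the `mod mnist_conv;` declaration and the dispatch from main() are assumptions based on the file path in the diff above, not part of the commit itself.

// main.rs (hypothetical): declare the module and run the CNN training loop.
// Assumes the MNIST data files are present under ./data, matching
// load_dir("data") in run() above.
mod mnist_conv;

use anyhow::Result;

fn main() -> Result<()> {
    mnist_conv::run()
}

Note that fc1 takes 1024 inputs because the two 5x5 convolutions and two 2x2 max-pools reduce each 28x28 image to 64 channels of 4x4 (64 * 4 * 4 = 1024), which is why forward_t reshapes with view([-1, 1024]) before the fully connected layers.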