diff options
author | Anthony Wang | 2022-01-11 12:24:42 -0600 |
---|---|---|
committer | Anthony Wang | 2022-01-11 12:24:42 -0600 |
commit | 3c47bb63daa20dc2fa6de6c0fbc7f6f612fcb0da (patch) | |
tree | 00fb7a92e3c1fffc0774150ba6c58dd23371fee4 /mnist/src/mnist_conv.rs | |
parent | 6be866b0b8b5ec676159eaa56c09371c02cbd7e1 (diff) |
Diffstat (limited to 'mnist/src/mnist_conv.rs')
-rw-r--r-- | mnist/src/mnist_conv.rs | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/mnist/src/mnist_conv.rs b/mnist/src/mnist_conv.rs new file mode 100644 index 0000000..3cb5ee1 --- /dev/null +++ b/mnist/src/mnist_conv.rs @@ -0,0 +1,54 @@ +// CNN model. This should rearch 99.1% accuracy. + +use anyhow::Result; +use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor}; + +#[derive(Debug)] +struct Net { + conv1: nn::Conv2D, + conv2: nn::Conv2D, + fc1: nn::Linear, + fc2: nn::Linear, +} + +impl Net { + fn new(vs: &nn::Path) -> Net { + let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default()); + let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default()); + let fc1 = nn::linear(vs, 1024, 1024, Default::default()); + let fc2 = nn::linear(vs, 1024, 10, Default::default()); + Net { conv1, conv2, fc1, fc2 } + } +} + +impl nn::ModuleT for Net { + fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor { + xs.view([-1, 1, 28, 28]) + .apply(&self.conv1) + .max_pool2d_default(2) + .apply(&self.conv2) + .max_pool2d_default(2) + .view([-1, 1024]) + .apply(&self.fc1) + .relu() + .dropout_(0.5, train) + .apply(&self.fc2) + } +} + +pub fn run() -> Result<()> { + let m = tch::vision::mnist::load_dir("data")?; + let vs = nn::VarStore::new(Device::cuda_if_available()); + let net = Net::new(&vs.root()); + let mut opt = nn::Adam::default().build(&vs, 1e-4)?; + for epoch in 1..100 { + for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) { + let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels); + opt.backward_step(&loss); + } + let test_accuracy = + net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024); + println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,); + } + Ok(()) +} |