author    | Anthony Wang | 2022-01-11 12:24:42 -0600
committer | Anthony Wang | 2022-01-11 12:24:42 -0600
commit    | 3c47bb63daa20dc2fa6de6c0fbc7f6f612fcb0da (patch)
tree      | 00fb7a92e3c1fffc0774150ba6c58dd23371fee4 /mnist/src/mnist_nn.rs
parent    | 6be866b0b8b5ec676159eaa56c09371c02cbd7e1 (diff)
Diffstat (limited to 'mnist/src/mnist_nn.rs')
-rw-r--r-- | mnist/src/mnist_nn.rs | 34
1 file changed, 34 insertions, 0 deletions
diff --git a/mnist/src/mnist_nn.rs b/mnist/src/mnist_nn.rs
new file mode 100644
index 0000000..ca1bc4e
--- /dev/null
+++ b/mnist/src/mnist_nn.rs
@@ -0,0 +1,34 @@
+// This should reach 97% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
+
+const IMAGE_DIM: i64 = 784;
+const HIDDEN_NODES: i64 = 128;
+const LABELS: i64 = 10;
+
+fn net(vs: &nn::Path) -> impl Module {
+    nn::seq()
+        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
+        .add_fn(|xs| xs.relu())
+        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
+}
+
+pub fn run() -> Result<()> {
+    let m = tch::vision::mnist::load_dir("data")?;
+    let vs = nn::VarStore::new(Device::Cpu);
+    let net = net(&vs.root());
+    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
+    for epoch in 1..1000 {
+        let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
+        opt.backward_step(&loss);
+        let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
+        println!(
+            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+            epoch,
+            f64::from(&loss),
+            100. * f64::from(&test_accuracy),
+        );
+    }
+    Ok(())
+}
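The file added in this commit only defines `run()`; it is not an entry point on its own. Below is a minimal sketch of how the module might be wired up from the crate's entry point. The file name `main.rs` and the module declaration are assumptions for illustration, not part of this commit.

```rust
// Hypothetical mnist/src/main.rs (not part of this commit).
// Declares the module added above and invokes its training loop.
mod mnist_nn;

fn main() -> anyhow::Result<()> {
    // run() loads the MNIST files from ./data, so that directory
    // must exist and contain the dataset before this is executed.
    mnist_nn::run()
}
```

This sketch assumes the crate's Cargo.toml already depends on `tch` and `anyhow`, since the committed module uses both.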