summaryrefslogtreecommitdiff
path: root/mnist/src/mnist_nn.rs
diff options
context:
space:
mode:
authorAnthony Wang2022-01-11 12:24:42 -0600
committerAnthony Wang2022-01-11 12:24:42 -0600
commit3c47bb63daa20dc2fa6de6c0fbc7f6f612fcb0da (patch)
tree00fb7a92e3c1fffc0774150ba6c58dd23371fee4 /mnist/src/mnist_nn.rs
parent6be866b0b8b5ec676159eaa56c09371c02cbd7e1 (diff)
Add Rust MNIST codeHEADmaster
Diffstat (limited to 'mnist/src/mnist_nn.rs')
-rw-r--r--mnist/src/mnist_nn.rs34
1 files changed, 34 insertions, 0 deletions
diff --git a/mnist/src/mnist_nn.rs b/mnist/src/mnist_nn.rs
new file mode 100644
index 0000000..ca1bc4e
--- /dev/null
+++ b/mnist/src/mnist_nn.rs
@@ -0,0 +1,34 @@
+// This should rearch 97% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
+
+const IMAGE_DIM: i64 = 784;
+const HIDDEN_NODES: i64 = 128;
+const LABELS: i64 = 10;
+
+fn net(vs: &nn::Path) -> impl Module {
+ nn::seq()
+ .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
+ .add_fn(|xs| xs.relu())
+ .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
+}
+
+pub fn run() -> Result<()> {
+ let m = tch::vision::mnist::load_dir("data")?;
+ let vs = nn::VarStore::new(Device::Cpu);
+ let net = net(&vs.root());
+ let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
+ for epoch in 1..1000 {
+ let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
+ opt.backward_step(&loss);
+ let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
+ println!(
+ "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+ epoch,
+ f64::from(&loss),
+ 100. * f64::from(&test_accuracy),
+ );
+ }
+ Ok(())
+}