summaryrefslogtreecommitdiff
path: root/mnist/src/mnist_linear.rs
diff options
context:
space:
mode:
Diffstat (limited to 'mnist/src/mnist_linear.rs')
-rw-r--r--mnist/src/mnist_linear.rs42
1 files changed, 42 insertions, 0 deletions
diff --git a/mnist/src/mnist_linear.rs b/mnist/src/mnist_linear.rs
new file mode 100644
index 0000000..d3245f7
--- /dev/null
+++ b/mnist/src/mnist_linear.rs
@@ -0,0 +1,42 @@
+// This should rearch 91.5% accuracy.
+
+use anyhow::Result;
+use tch::{kind, no_grad, vision, Kind, Tensor};
+
+const IMAGE_DIM: i64 = 784;
+const LABELS: i64 = 10;
+
+pub fn run() -> Result<()> {
+ let m = vision::mnist::load_dir("data")?;
+ println!("train-images: {:?}", m.train_images.size());
+ println!("train-labels: {:?}", m.train_labels.size());
+ println!("test-images: {:?}", m.test_images.size());
+ println!("test-labels: {:?}", m.test_labels.size());
+ let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
+ let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);
+ for epoch in 1..200 {
+ let logits = m.train_images.mm(&ws) + &bs;
+ let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);
+ ws.zero_grad();
+ bs.zero_grad();
+ loss.backward();
+ no_grad(|| {
+ ws += ws.grad() * (-1);
+ bs += bs.grad() * (-1);
+ });
+ let test_logits = m.test_images.mm(&ws) + &bs;
+ let test_accuracy = test_logits
+ .argmax(Some(-1), false)
+ .eq_tensor(&m.test_labels)
+ .to_kind(Kind::Float)
+ .mean(Kind::Float)
+ .double_value(&[]);
+ println!(
+ "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+ epoch,
+ loss.double_value(&[]),
+ 100. * test_accuracy
+ );
+ }
+ Ok(())
+}