summaryrefslogtreecommitdiff
path: root/mnist/src
diff options
context:
space:
mode:
Diffstat (limited to 'mnist/src')
-rw-r--r--mnist/src/main.rs25
-rw-r--r--mnist/src/mnist_conv.rs54
-rw-r--r--mnist/src/mnist_linear.rs42
-rw-r--r--mnist/src/mnist_nn.rs34
4 files changed, 155 insertions, 0 deletions
diff --git a/mnist/src/main.rs b/mnist/src/main.rs
new file mode 100644
index 0000000..88e0483
--- /dev/null
+++ b/mnist/src/main.rs
@@ -0,0 +1,25 @@
+/* Some very simple models trained on the MNIST dataset.
+ The 4 following dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/
+ These files should be extracted in the 'data' directory.
+ train-images-idx3-ubyte.gz
+ train-labels-idx1-ubyte.gz
+ t10k-images-idx3-ubyte.gz
+ t10k-labels-idx1-ubyte.gz
+*/
+
+use anyhow::Result;
+
+mod mnist_conv;
+mod mnist_linear;
+mod mnist_nn;
+
+fn main() -> Result<()> {
+ let args: Vec<String> = std::env::args().collect();
+ let model = if args.len() < 2 { None } else { Some(args[1].as_str()) };
+ match model {
+ None => mnist_nn::run(),
+ Some("linear") => mnist_linear::run(),
+ Some("conv") => mnist_conv::run(),
+ Some(_) => mnist_nn::run(),
+ }
+}
diff --git a/mnist/src/mnist_conv.rs b/mnist/src/mnist_conv.rs
new file mode 100644
index 0000000..3cb5ee1
--- /dev/null
+++ b/mnist/src/mnist_conv.rs
@@ -0,0 +1,54 @@
+// CNN model. This should reach 99.1% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
+
// Convolutional network: two conv layers followed by two fully-connected
// layers (layer sizes are fixed in `Net::new`).
#[derive(Debug)]
struct Net {
    conv1: nn::Conv2D, // 1 -> 32 channels, 5x5 kernel
    conv2: nn::Conv2D, // 32 -> 64 channels, 5x5 kernel
    fc1: nn::Linear,   // 1024 -> 1024
    fc2: nn::Linear,   // 1024 -> 10 class logits
}
+
+impl Net {
+ fn new(vs: &nn::Path) -> Net {
+ let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
+ let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
+ let fc1 = nn::linear(vs, 1024, 1024, Default::default());
+ let fc2 = nn::linear(vs, 1024, 10, Default::default());
+ Net { conv1, conv2, fc1, fc2 }
+ }
+}
+
+impl nn::ModuleT for Net {
+ fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
+ xs.view([-1, 1, 28, 28])
+ .apply(&self.conv1)
+ .max_pool2d_default(2)
+ .apply(&self.conv2)
+ .max_pool2d_default(2)
+ .view([-1, 1024])
+ .apply(&self.fc1)
+ .relu()
+ .dropout_(0.5, train)
+ .apply(&self.fc2)
+ }
+}
+
+pub fn run() -> Result<()> {
+ let m = tch::vision::mnist::load_dir("data")?;
+ let vs = nn::VarStore::new(Device::cuda_if_available());
+ let net = Net::new(&vs.root());
+ let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
+ for epoch in 1..100 {
+ for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) {
+ let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels);
+ opt.backward_step(&loss);
+ }
+ let test_accuracy =
+ net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024);
+ println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,);
+ }
+ Ok(())
+}
diff --git a/mnist/src/mnist_linear.rs b/mnist/src/mnist_linear.rs
new file mode 100644
index 0000000..d3245f7
--- /dev/null
+++ b/mnist/src/mnist_linear.rs
@@ -0,0 +1,42 @@
+// This should reach 91.5% accuracy.
+
+use anyhow::Result;
+use tch::{kind, no_grad, vision, Kind, Tensor};
+
+const IMAGE_DIM: i64 = 784;
+const LABELS: i64 = 10;
+
pub fn run() -> Result<()> {
    // Load MNIST from the local "data" directory; images come flattened
    // to IMAGE_DIM(=784)-element float rows, as required by `mm` below.
    let m = vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.size());
    println!("train-labels: {:?}", m.train_labels.size());
    println!("test-images: {:?}", m.test_images.size());
    println!("test-labels: {:?}", m.test_labels.size());
    // Weights and bias of a single linear layer, trained with a
    // hand-rolled gradient-descent loop instead of an optimizer.
    let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
    let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);
    for epoch in 1..200 {
        // Full-batch forward pass: logits = X * W + b.
        let logits = m.train_images.mm(&ws) + &bs;
        let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);
        // Clear gradients accumulated by the previous backward pass.
        ws.zero_grad();
        bs.zero_grad();
        loss.backward();
        // Manual SGD step (effective learning rate 1.0); `no_grad` keeps
        // the in-place parameter updates out of the autograd graph.
        no_grad(|| {
            ws += ws.grad() * (-1);
            bs += bs.grad() * (-1);
        });
        // Accuracy = fraction of test rows whose argmax matches the label.
        let test_logits = m.test_images.mm(&ws) + &bs;
        let test_accuracy = test_logits
            .argmax(Some(-1), false)
            .eq_tensor(&m.test_labels)
            .to_kind(Kind::Float)
            .mean(Kind::Float)
            .double_value(&[]);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            loss.double_value(&[]),
            100. * test_accuracy
        );
    }
    Ok(())
}
diff --git a/mnist/src/mnist_nn.rs b/mnist/src/mnist_nn.rs
new file mode 100644
index 0000000..ca1bc4e
--- /dev/null
+++ b/mnist/src/mnist_nn.rs
@@ -0,0 +1,34 @@
+// This should reach 97% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
+
+const IMAGE_DIM: i64 = 784;
+const HIDDEN_NODES: i64 = 128;
+const LABELS: i64 = 10;
+
+fn net(vs: &nn::Path) -> impl Module {
+ nn::seq()
+ .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
+ .add_fn(|xs| xs.relu())
+ .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
+}
+
+pub fn run() -> Result<()> {
+ let m = tch::vision::mnist::load_dir("data")?;
+ let vs = nn::VarStore::new(Device::Cpu);
+ let net = net(&vs.root());
+ let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
+ for epoch in 1..1000 {
+ let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
+ opt.backward_step(&loss);
+ let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
+ println!(
+ "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+ epoch,
+ f64::from(&loss),
+ 100. * f64::from(&test_accuracy),
+ );
+ }
+ Ok(())
+}