Diffstat (limited to 'mnist/src')
-rw-r--r--  mnist/src/main.rs         | 25
-rw-r--r--  mnist/src/mnist_conv.rs   | 54
-rw-r--r--  mnist/src/mnist_linear.rs | 42
-rw-r--r--  mnist/src/mnist_nn.rs     | 34
4 files changed, 155 insertions, 0 deletions
diff --git a/mnist/src/main.rs b/mnist/src/main.rs
new file mode 100644
index 0000000..88e0483
--- /dev/null
+++ b/mnist/src/main.rs
@@ -0,0 +1,25 @@
+/* Some very simple models trained on the MNIST dataset.
+   The following 4 dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/
+   These files should be extracted into the 'data' directory:
+     train-images-idx3-ubyte.gz
+     train-labels-idx1-ubyte.gz
+     t10k-images-idx3-ubyte.gz
+     t10k-labels-idx1-ubyte.gz
+*/
+
+use anyhow::Result;
+
+mod mnist_conv;
+mod mnist_linear;
+mod mnist_nn;
+
+fn main() -> Result<()> {
+    let args: Vec<String> = std::env::args().collect();
+    let model = if args.len() < 2 { None } else { Some(args[1].as_str()) };
+    match model {
+        None => mnist_nn::run(),
+        Some("linear") => mnist_linear::run(),
+        Some("conv") => mnist_conv::run(),
+        Some(_) => mnist_nn::run(),
+    }
+}
diff --git a/mnist/src/mnist_conv.rs b/mnist/src/mnist_conv.rs
new file mode 100644
index 0000000..3cb5ee1
--- /dev/null
+++ b/mnist/src/mnist_conv.rs
@@ -0,0 +1,54 @@
+// CNN model. This should reach 99.1% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
+
+#[derive(Debug)]
+struct Net {
+    conv1: nn::Conv2D,
+    conv2: nn::Conv2D,
+    fc1: nn::Linear,
+    fc2: nn::Linear,
+}
+
+impl Net {
+    fn new(vs: &nn::Path) -> Net {
+        let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
+        let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
+        let fc1 = nn::linear(vs, 1024, 1024, Default::default());
+        let fc2 = nn::linear(vs, 1024, 10, Default::default());
+        Net { conv1, conv2, fc1, fc2 }
+    }
+}
+
+impl nn::ModuleT for Net {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
+        xs.view([-1, 1, 28, 28])
+            .apply(&self.conv1)
+            .max_pool2d_default(2)
+            .apply(&self.conv2)
+            .max_pool2d_default(2)
+            .view([-1, 1024])
+            .apply(&self.fc1)
+            .relu()
+            .dropout_(0.5, train)
+            .apply(&self.fc2)
+    }
+}
+
+pub fn run() -> Result<()> {
+    let m = tch::vision::mnist::load_dir("data")?;
+    let vs = nn::VarStore::new(Device::cuda_if_available());
+    let net = Net::new(&vs.root());
+    let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
+    for epoch in 1..100 {
+        for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) {
+            let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels);
+            opt.backward_step(&loss);
+        }
+        let test_accuracy =
+            net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024);
+        println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy);
+    }
+    Ok(())
+}
diff --git a/mnist/src/mnist_linear.rs b/mnist/src/mnist_linear.rs
new file mode 100644
index 0000000..d3245f7
--- /dev/null
+++ b/mnist/src/mnist_linear.rs
@@ -0,0 +1,42 @@
+// This should reach 91.5% accuracy.
+
+use anyhow::Result;
+use tch::{kind, no_grad, vision, Kind, Tensor};
+
+const IMAGE_DIM: i64 = 784;
+const LABELS: i64 = 10;
+
+pub fn run() -> Result<()> {
+    let m = vision::mnist::load_dir("data")?;
+    println!("train-images: {:?}", m.train_images.size());
+    println!("train-labels: {:?}", m.train_labels.size());
+    println!("test-images: {:?}", m.test_images.size());
+    println!("test-labels: {:?}", m.test_labels.size());
+    let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
+    let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);
+    for epoch in 1..200 {
+        let logits = m.train_images.mm(&ws) + &bs;
+        let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);
+        ws.zero_grad();
+        bs.zero_grad();
+        loss.backward();
+        no_grad(|| {
+            ws += ws.grad() * (-1);
+            bs += bs.grad() * (-1);
+        });
+        let test_logits = m.test_images.mm(&ws) + &bs;
+        let test_accuracy = test_logits
+            .argmax(Some(-1), false)
+            .eq_tensor(&m.test_labels)
+            .to_kind(Kind::Float)
+            .mean(Kind::Float)
+            .double_value(&[]);
+        println!(
+            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+            epoch,
+            loss.double_value(&[]),
+            100. * test_accuracy
+        );
+    }
+    Ok(())
+}
diff --git a/mnist/src/mnist_nn.rs b/mnist/src/mnist_nn.rs
new file mode 100644
index 0000000..ca1bc4e
--- /dev/null
+++ b/mnist/src/mnist_nn.rs
@@ -0,0 +1,34 @@
+// This should reach 97% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
+
+const IMAGE_DIM: i64 = 784;
+const HIDDEN_NODES: i64 = 128;
+const LABELS: i64 = 10;
+
+fn net(vs: &nn::Path) -> impl Module {
+    nn::seq()
+        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
+        .add_fn(|xs| xs.relu())
+        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
+}
+
+pub fn run() -> Result<()> {
+    let m = tch::vision::mnist::load_dir("data")?;
+    let vs = nn::VarStore::new(Device::Cpu);
+    let net = net(&vs.root());
+    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
+    for epoch in 1..1000 {
+        let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
+        opt.backward_step(&loss);
+        let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
+        println!(
+            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+            epoch,
+            f64::from(&loss),
+            100. * f64::from(&test_accuracy),
+        );
+    }
+    Ok(())
+}
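
Usage note (a sketch, not part of the diff): main.rs dispatches on the first
command-line argument, falling back to mnist_nn when no argument (or an
unrecognized one) is given. Assuming this crate builds as an ordinary cargo
binary and the four MNIST files have been extracted into data/, the three
models can be run as:

    cargo run --release            # mnist_nn, multilayer perceptron, ~97%
    cargo run --release -- linear  # mnist_linear, softmax regression, ~91.5%
    cargo run --release -- conv    # mnist_conv, convolutional net, ~99.1%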