summaryrefslogtreecommitdiff
path: root/mnist/src/main.rs
blob: 88e04839912689a6624183b8c7a34681d1e39ed0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/* Some very simple models trained on the MNIST dataset.
   The following 4 dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/
   These files should be extracted into the 'data' directory.
     train-images-idx3-ubyte.gz
     train-labels-idx1-ubyte.gz
     t10k-images-idx3-ubyte.gz
     t10k-labels-idx1-ubyte.gz
*/

use anyhow::Result;

mod mnist_conv;
mod mnist_linear;
mod mnist_nn;

/// Entry point: runs one of the MNIST models, selected by the first
/// command-line argument.
///
/// Recognized values are `linear`, `conv` and `nn`. When no argument is given,
/// the `nn` model runs. Any other value also falls back to `nn` (kept for
/// backward compatibility), but a warning is printed to stderr. Without the
/// warning, a typo such as `lineer` would silently train a different model.
///
/// # Errors
///
/// Propagates whatever error the selected model's `run` function returns.
fn main() -> Result<()> {
    // `nth(1)` skips the program name (argument 0). This avoids collecting every
    // argument and removes the manual length check and the indexing panic path.
    let model = std::env::args().nth(1);
    match model.as_deref() {
        Some("linear") => mnist_linear::run(),
        Some("conv") => mnist_conv::run(),
        None | Some("nn") => mnist_nn::run(),
        Some(other) => {
            eprintln!(
                "unknown model '{}', expected one of: linear, conv, nn; falling back to nn",
                other
            );
            mnist_nn::run()
        }
    }
}