blob: 88e04839912689a6624183b8c7a34681d1e39ed0 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
/* Some very simple models trained on the MNIST dataset.
The following 4 dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/
They should be extracted into the 'data' directory.
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
*/
use anyhow::Result;
mod mnist_conv;
mod mnist_linear;
mod mnist_nn;
/// Runs one of the simple MNIST models, selected by the first command-line argument.
///
/// Usage: `<binary> [linear|conv|nn]`.
/// - `linear` runs `mnist_linear::run`.
/// - `conv` runs `mnist_conv::run`.
/// - `nn`, or no argument at all, runs `mnist_nn::run`.
///
/// Any other value also falls back to `mnist_nn::run` (as before). A warning is printed
/// to stderr so that a mistyped model name does not go unnoticed.
///
/// # Errors
/// Propagates whatever error the selected model's `run` function returns.
fn main() -> Result<()> {
    // Collect every argument, as the original did. `std::env::args` panics on a
    // non-Unicode argument, and collecting keeps that behavior for all positions.
    let args: Vec<String> = std::env::args().collect();
    // Index 0 is the program name; the model selector is the first user argument.
    match args.get(1).map(String::as_str) {
        None | Some("nn") => mnist_nn::run(),
        Some("linear") => mnist_linear::run(),
        Some("conv") => mnist_conv::run(),
        Some(other) => {
            // Keep the historical fallback, but make it visible instead of silent.
            eprintln!(
                "unknown model '{}', expected one of: linear, conv, nn; falling back to nn",
                other
            );
            mnist_nn::run()
        }
    }
}
|