diff options
Diffstat (limited to 'mnist/src/main.rs')
-rw-r--r-- | mnist/src/main.rs | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/mnist/src/main.rs b/mnist/src/main.rs new file mode 100644 index 0000000..88e0483 --- /dev/null +++ b/mnist/src/main.rs @@ -0,0 +1,25 @@ +/* Some very simple models trained on the MNIST dataset. + The 4 following dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/ + These files should be extracted in the 'data' directory. + train-images-idx3-ubyte.gz + train-labels-idx1-ubyte.gz + t10k-images-idx3-ubyte.gz + t10k-labels-idx1-ubyte.gz +*/ + +use anyhow::Result; + +mod mnist_conv; +mod mnist_linear; +mod mnist_nn; + +fn main() -> Result<()> { + let args: Vec<String> = std::env::args().collect(); + let model = if args.len() < 2 { None } else { Some(args[1].as_str()) }; + match model { + None => mnist_nn::run(), + Some("linear") => mnist_linear::run(), + Some("conv") => mnist_conv::run(), + Some(_) => mnist_nn::run(), + } +} |