diff options
Diffstat (limited to 'mnist')
-rw-r--r-- | mnist/Cargo.lock | 459 | ||||
-rw-r--r-- | mnist/Cargo.toml | 10 | ||||
-rw-r--r-- | mnist/src/main.rs | 25 | ||||
-rw-r--r-- | mnist/src/mnist_conv.rs | 54 | ||||
-rw-r--r-- | mnist/src/mnist_linear.rs | 42 | ||||
-rw-r--r-- | mnist/src/mnist_nn.rs | 34 |
6 files changed, 624 insertions, 0 deletions
diff --git a/mnist/Cargo.lock b/mnist/Cargo.lock new file mode 100644 index 0000000..4fa61d1 --- /dev/null +++ b/mnist/Cargo.lock @@ -0,0 +1,459 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "anyhow" +version = "1.0.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84450d0b4a8bd1ba4144ce8ce718fbc5d071358b1e5384bace6536b3d1f2d5b3" + +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "bzip2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6afcd980b5f3a45017c57e57a2fcccbb351cc43a356ce117ef760ef8052b89b0" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "cc" +version = "1.0.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22a9137b95ea06864e018375b72adfb7db6e6f68cfc8df5a04d00288050485ee" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crc32fast" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "738c290dfaea84fc1ca15ad9c168d083b05a714e1efddd8edaab678dc28d2836" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "curl" +version = "0.4.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7de97b894edd5b5bcceef8b78d7da9b75b1d2f2f9a910569d0bde3dd31d84939" +dependencies = [ + "curl-sys", + "libc", + "openssl-probe", + "openssl-sys", + "schannel", + "socket2", + "winapi", +] + +[[package]] +name = "curl-sys" +version = "0.4.52+curl-7.81.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14b8c2d1023ea5fded5b7b892e4b8e95f70038a421126a056761a84246a28971" +dependencies = [ + "cc", + "libc", + "libz-sys", + "openssl-sys", + "pkg-config", + "vcpkg", + "winapi", +] + +[[package]] +name = "flate2" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f" +dependencies = [ + "cfg-if", + "crc32fast", + "libc", + "miniz_oxide", +] + +[[package]] +name = "getrandom" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.112" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125" + +[[package]] +name = "libz-sys" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de5435b8549c16d423ed0c03dbaafe57cf6c3344744f1242520d59c9d8ecec66" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" +dependencies = [ + "rawpointer", +] + +[[package]] +name = "miniz_oxide" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +dependencies = [ + "adler", + "autocfg", +] + +[[package]] +name = "mnist" +version = "0.1.0" +dependencies = [ + "anyhow", + "tch", +] + +[[package]] +name = "ndarray" +version = "0.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec23e6762830658d2b3d385a75aa212af2f67a4586d4442907144f3bb6a1ca8" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + +[[package]] +name = "num-complex" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" +dependencies = [ + "autocfg", +] + +[[package]] +name = "openssl-probe" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a" + +[[package]] +name = "openssl-sys" +version = "0.9.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e46109c383602735fa0a2e48dd2b7c892b048e1bf69e5c3b1d804b7d9c203cb" +dependencies = [ + "autocfg", + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "pkg-config" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe" + +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + +[[package]] +name = "proc-macro2" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029" +dependencies = [ + "unicode-xid", +] + +[[package]] +name = "quote" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47aa80447ce4daf1717500037052af176af5d38cc3e571d9ec1c7353fc10c87d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", + "rand_hc", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_hc" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "schannel" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75" +dependencies = [ + "lazy_static", + "winapi", +] + +[[package]] +name = "socket2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "syn" +version = "1.0.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a684ac3dcd8913827e18cd09a68384ee66c1de24157e3c556c9ab16d85695fb7" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "tch" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b73f876b186599e22b01fa6ebfeea2dee2f11e8083463ab3572933d8201436b" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand", + "thiserror", + "torch-sys", + "zip", +] + +[[package]] +name = "thiserror" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "time" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "torch-sys" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34cc0f21b1aad5d71d529e9fe4dbbbdbf53918d7b4bde946f523839aa32cffae" +dependencies = [ + "anyhow", + "cc", + "curl", + "libc", + "zip", +] + +[[package]] +name = "unicode-xid" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "wasi" +version = "0.10.2+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "zip" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815" +dependencies = [ + "byteorder", + "bzip2", + "crc32fast", + "flate2", + "thiserror", + "time", +] diff --git a/mnist/Cargo.toml b/mnist/Cargo.toml new file mode 100644 index 0000000..ea3c1f7 --- /dev/null +++ b/mnist/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "mnist" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0" +tch = "0.6.1" diff --git a/mnist/src/main.rs b/mnist/src/main.rs new file mode 100644 index 0000000..88e0483 --- /dev/null +++ b/mnist/src/main.rs @@ -0,0 +1,25 @@ +/* Some very simple models trained on the MNIST dataset. + The 4 following dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/ + These files should be extracted in the 'data' directory. + train-images-idx3-ubyte.gz + train-labels-idx1-ubyte.gz + t10k-images-idx3-ubyte.gz + t10k-labels-idx1-ubyte.gz +*/ + +use anyhow::Result; + +mod mnist_conv; +mod mnist_linear; +mod mnist_nn; + +fn main() -> Result<()> { + let args: Vec<String> = std::env::args().collect(); + let model = if args.len() < 2 { None } else { Some(args[1].as_str()) }; + match model { + None => mnist_nn::run(), + Some("linear") => mnist_linear::run(), + Some("conv") => mnist_conv::run(), + Some(_) => mnist_nn::run(), + } +} diff --git a/mnist/src/mnist_conv.rs b/mnist/src/mnist_conv.rs new file mode 100644 index 0000000..3cb5ee1 --- /dev/null +++ b/mnist/src/mnist_conv.rs @@ -0,0 +1,54 @@ +// CNN model. This should rearch 99.1% accuracy. + +use anyhow::Result; +use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor}; + +#[derive(Debug)] +struct Net { + conv1: nn::Conv2D, + conv2: nn::Conv2D, + fc1: nn::Linear, + fc2: nn::Linear, +} + +impl Net { + fn new(vs: &nn::Path) -> Net { + let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default()); + let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default()); + let fc1 = nn::linear(vs, 1024, 1024, Default::default()); + let fc2 = nn::linear(vs, 1024, 10, Default::default()); + Net { conv1, conv2, fc1, fc2 } + } +} + +impl nn::ModuleT for Net { + fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor { + xs.view([-1, 1, 28, 28]) + .apply(&self.conv1) + .max_pool2d_default(2) + .apply(&self.conv2) + .max_pool2d_default(2) + .view([-1, 1024]) + .apply(&self.fc1) + .relu() + .dropout_(0.5, train) + .apply(&self.fc2) + } +} + +pub fn run() -> Result<()> { + let m = tch::vision::mnist::load_dir("data")?; + let vs = nn::VarStore::new(Device::cuda_if_available()); + let net = Net::new(&vs.root()); + let mut opt = nn::Adam::default().build(&vs, 1e-4)?; + for epoch in 1..100 { + for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) { + let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels); + opt.backward_step(&loss); + } + let test_accuracy = + net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024); + println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,); + } + Ok(()) +} diff --git a/mnist/src/mnist_linear.rs b/mnist/src/mnist_linear.rs new file mode 100644 index 0000000..d3245f7 --- /dev/null +++ b/mnist/src/mnist_linear.rs @@ -0,0 +1,42 @@ +// This should rearch 91.5% accuracy. + +use anyhow::Result; +use tch::{kind, no_grad, vision, Kind, Tensor}; + +const IMAGE_DIM: i64 = 784; +const LABELS: i64 = 10; + +pub fn run() -> Result<()> { + let m = vision::mnist::load_dir("data")?; + println!("train-images: {:?}", m.train_images.size()); + println!("train-labels: {:?}", m.train_labels.size()); + println!("test-images: {:?}", m.test_images.size()); + println!("test-labels: {:?}", m.test_labels.size()); + let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true); + let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true); + for epoch in 1..200 { + let logits = m.train_images.mm(&ws) + &bs; + let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels); + ws.zero_grad(); + bs.zero_grad(); + loss.backward(); + no_grad(|| { + ws += ws.grad() * (-1); + bs += bs.grad() * (-1); + }); + let test_logits = m.test_images.mm(&ws) + &bs; + let test_accuracy = test_logits + .argmax(Some(-1), false) + .eq_tensor(&m.test_labels) + .to_kind(Kind::Float) + .mean(Kind::Float) + .double_value(&[]); + println!( + "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%", + epoch, + loss.double_value(&[]), + 100. * test_accuracy + ); + } + Ok(()) +} diff --git a/mnist/src/mnist_nn.rs b/mnist/src/mnist_nn.rs new file mode 100644 index 0000000..ca1bc4e --- /dev/null +++ b/mnist/src/mnist_nn.rs @@ -0,0 +1,34 @@ +// This should rearch 97% accuracy. + +use anyhow::Result; +use tch::{nn, nn::Module, nn::OptimizerConfig, Device}; + +const IMAGE_DIM: i64 = 784; +const HIDDEN_NODES: i64 = 128; +const LABELS: i64 = 10; + +fn net(vs: &nn::Path) -> impl Module { + nn::seq() + .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default())) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default())) +} + +pub fn run() -> Result<()> { + let m = tch::vision::mnist::load_dir("data")?; + let vs = nn::VarStore::new(Device::Cpu); + let net = net(&vs.root()); + let mut opt = nn::Adam::default().build(&vs, 1e-3)?; + for epoch in 1..1000 { + let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels); + opt.backward_step(&loss); + let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels); + println!( + "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%", + epoch, + f64::from(&loss), + 100. * f64::from(&test_accuracy), + ); + } + Ok(()) +} |