summaryrefslogtreecommitdiff
path: root/mnist
diff options
context:
space:
mode:
Diffstat (limited to 'mnist')
-rw-r--r--mnist/Cargo.lock459
-rw-r--r--mnist/Cargo.toml10
-rw-r--r--mnist/src/main.rs25
-rw-r--r--mnist/src/mnist_conv.rs54
-rw-r--r--mnist/src/mnist_linear.rs42
-rw-r--r--mnist/src/mnist_nn.rs34
6 files changed, 624 insertions, 0 deletions
diff --git a/mnist/Cargo.lock b/mnist/Cargo.lock
new file mode 100644
index 0000000..4fa61d1
--- /dev/null
+++ b/mnist/Cargo.lock
@@ -0,0 +1,459 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "adler"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
+
+[[package]]
+name = "anyhow"
+version = "1.0.52"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "84450d0b4a8bd1ba4144ce8ce718fbc5d071358b1e5384bace6536b3d1f2d5b3"
+
+[[package]]
+name = "autocfg"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
+
+[[package]]
+name = "byteorder"
+version = "1.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
+
+[[package]]
+name = "bzip2"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6afcd980b5f3a45017c57e57a2fcccbb351cc43a356ce117ef760ef8052b89b0"
+dependencies = [
+ "bzip2-sys",
+ "libc",
+]
+
+[[package]]
+name = "bzip2-sys"
+version = "0.1.11+1.0.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
+dependencies = [
+ "cc",
+ "libc",
+ "pkg-config",
+]
+
+[[package]]
+name = "cc"
+version = "1.0.72"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22a9137b95ea06864e018375b72adfb7db6e6f68cfc8df5a04d00288050485ee"
+
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
+[[package]]
+name = "crc32fast"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "738c290dfaea84fc1ca15ad9c168d083b05a714e1efddd8edaab678dc28d2836"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
+name = "curl"
+version = "0.4.42"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7de97b894edd5b5bcceef8b78d7da9b75b1d2f2f9a910569d0bde3dd31d84939"
+dependencies = [
+ "curl-sys",
+ "libc",
+ "openssl-probe",
+ "openssl-sys",
+ "schannel",
+ "socket2",
+ "winapi",
+]
+
+[[package]]
+name = "curl-sys"
+version = "0.4.52+curl-7.81.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "14b8c2d1023ea5fded5b7b892e4b8e95f70038a421126a056761a84246a28971"
+dependencies = [
+ "cc",
+ "libc",
+ "libz-sys",
+ "openssl-sys",
+ "pkg-config",
+ "vcpkg",
+ "winapi",
+]
+
+[[package]]
+name = "flate2"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f"
+dependencies = [
+ "cfg-if",
+ "crc32fast",
+ "libc",
+ "miniz_oxide",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
+[[package]]
+name = "half"
+version = "1.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
+
+[[package]]
+name = "lazy_static"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+
+[[package]]
+name = "libc"
+version = "0.2.112"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125"
+
+[[package]]
+name = "libz-sys"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "de5435b8549c16d423ed0c03dbaafe57cf6c3344744f1242520d59c9d8ecec66"
+dependencies = [
+ "cc",
+ "libc",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "matrixmultiply"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
+dependencies = [
+ "rawpointer",
+]
+
+[[package]]
+name = "miniz_oxide"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b"
+dependencies = [
+ "adler",
+ "autocfg",
+]
+
+[[package]]
+name = "mnist"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "tch",
+]
+
+[[package]]
+name = "ndarray"
+version = "0.15.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dec23e6762830658d2b3d385a75aa212af2f67a4586d4442907144f3bb6a1ca8"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
+[[package]]
+name = "num-complex"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.44"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
+dependencies = [
+ "autocfg",
+ "num-traits",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "openssl-probe"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a"
+
+[[package]]
+name = "openssl-sys"
+version = "0.9.72"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7e46109c383602735fa0a2e48dd2b7c892b048e1bf69e5c3b1d804b7d9c203cb"
+dependencies = [
+ "autocfg",
+ "cc",
+ "libc",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "pkg-config"
+version = "0.3.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe"
+
+[[package]]
+name = "ppv-lite86"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.36"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029"
+dependencies = [
+ "unicode-xid",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "47aa80447ce4daf1717500037052af176af5d38cc3e571d9ec1c7353fc10c87d"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "rand"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8"
+dependencies = [
+ "libc",
+ "rand_chacha",
+ "rand_core",
+ "rand_hc",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+dependencies = [
+ "ppv-lite86",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7"
+dependencies = [
+ "getrandom",
+]
+
+[[package]]
+name = "rand_hc"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7"
+dependencies = [
+ "rand_core",
+]
+
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
+name = "schannel"
+version = "0.1.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75"
+dependencies = [
+ "lazy_static",
+ "winapi",
+]
+
+[[package]]
+name = "socket2"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
+[[package]]
+name = "syn"
+version = "1.0.85"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a684ac3dcd8913827e18cd09a68384ee66c1de24157e3c556c9ab16d85695fb7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-xid",
+]
+
+[[package]]
+name = "tch"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b73f876b186599e22b01fa6ebfeea2dee2f11e8083463ab3572933d8201436b"
+dependencies = [
+ "half",
+ "lazy_static",
+ "libc",
+ "ndarray",
+ "rand",
+ "thiserror",
+ "torch-sys",
+ "zip",
+]
+
+[[package]]
+name = "thiserror"
+version = "1.0.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "time"
+version = "0.1.43"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
+[[package]]
+name = "torch-sys"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34cc0f21b1aad5d71d529e9fe4dbbbdbf53918d7b4bde946f523839aa32cffae"
+dependencies = [
+ "anyhow",
+ "cc",
+ "curl",
+ "libc",
+ "zip",
+]
+
+[[package]]
+name = "unicode-xid"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
+
+[[package]]
+name = "vcpkg"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
+
+[[package]]
+name = "wasi"
+version = "0.10.2+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
+
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
+[[package]]
+name = "zip"
+version = "0.5.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815"
+dependencies = [
+ "byteorder",
+ "bzip2",
+ "crc32fast",
+ "flate2",
+ "thiserror",
+ "time",
+]
diff --git a/mnist/Cargo.toml b/mnist/Cargo.toml
new file mode 100644
index 0000000..ea3c1f7
--- /dev/null
+++ b/mnist/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "mnist"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+anyhow = "1.0"
+tch = "0.6.1"
diff --git a/mnist/src/main.rs b/mnist/src/main.rs
new file mode 100644
index 0000000..88e0483
--- /dev/null
+++ b/mnist/src/main.rs
@@ -0,0 +1,25 @@
+/* Some very simple models trained on the MNIST dataset.
+ The 4 following dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/
+ These files should be extracted in the 'data' directory.
+ train-images-idx3-ubyte.gz
+ train-labels-idx1-ubyte.gz
+ t10k-images-idx3-ubyte.gz
+ t10k-labels-idx1-ubyte.gz
+*/
+
+use anyhow::Result;
+
+mod mnist_conv;
+mod mnist_linear;
+mod mnist_nn;
+
+fn main() -> Result<()> {
+ let args: Vec<String> = std::env::args().collect();
+ let model = if args.len() < 2 { None } else { Some(args[1].as_str()) };
+ match model {
+ None => mnist_nn::run(),
+ Some("linear") => mnist_linear::run(),
+ Some("conv") => mnist_conv::run(),
+ Some(_) => mnist_nn::run(),
+ }
+}
diff --git a/mnist/src/mnist_conv.rs b/mnist/src/mnist_conv.rs
new file mode 100644
index 0000000..3cb5ee1
--- /dev/null
+++ b/mnist/src/mnist_conv.rs
@@ -0,0 +1,54 @@
+// CNN model. This should rearch 99.1% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
+
+#[derive(Debug)]
+struct Net {
+ conv1: nn::Conv2D,
+ conv2: nn::Conv2D,
+ fc1: nn::Linear,
+ fc2: nn::Linear,
+}
+
+impl Net {
+ fn new(vs: &nn::Path) -> Net {
+ let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
+ let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
+ let fc1 = nn::linear(vs, 1024, 1024, Default::default());
+ let fc2 = nn::linear(vs, 1024, 10, Default::default());
+ Net { conv1, conv2, fc1, fc2 }
+ }
+}
+
+impl nn::ModuleT for Net {
+ fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
+ xs.view([-1, 1, 28, 28])
+ .apply(&self.conv1)
+ .max_pool2d_default(2)
+ .apply(&self.conv2)
+ .max_pool2d_default(2)
+ .view([-1, 1024])
+ .apply(&self.fc1)
+ .relu()
+ .dropout_(0.5, train)
+ .apply(&self.fc2)
+ }
+}
+
+pub fn run() -> Result<()> {
+ let m = tch::vision::mnist::load_dir("data")?;
+ let vs = nn::VarStore::new(Device::cuda_if_available());
+ let net = Net::new(&vs.root());
+ let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
+ for epoch in 1..100 {
+ for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) {
+ let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels);
+ opt.backward_step(&loss);
+ }
+ let test_accuracy =
+ net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024);
+ println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,);
+ }
+ Ok(())
+}
diff --git a/mnist/src/mnist_linear.rs b/mnist/src/mnist_linear.rs
new file mode 100644
index 0000000..d3245f7
--- /dev/null
+++ b/mnist/src/mnist_linear.rs
@@ -0,0 +1,42 @@
+// This should rearch 91.5% accuracy.
+
+use anyhow::Result;
+use tch::{kind, no_grad, vision, Kind, Tensor};
+
+const IMAGE_DIM: i64 = 784;
+const LABELS: i64 = 10;
+
+pub fn run() -> Result<()> {
+ let m = vision::mnist::load_dir("data")?;
+ println!("train-images: {:?}", m.train_images.size());
+ println!("train-labels: {:?}", m.train_labels.size());
+ println!("test-images: {:?}", m.test_images.size());
+ println!("test-labels: {:?}", m.test_labels.size());
+ let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
+ let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);
+ for epoch in 1..200 {
+ let logits = m.train_images.mm(&ws) + &bs;
+ let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);
+ ws.zero_grad();
+ bs.zero_grad();
+ loss.backward();
+ no_grad(|| {
+ ws += ws.grad() * (-1);
+ bs += bs.grad() * (-1);
+ });
+ let test_logits = m.test_images.mm(&ws) + &bs;
+ let test_accuracy = test_logits
+ .argmax(Some(-1), false)
+ .eq_tensor(&m.test_labels)
+ .to_kind(Kind::Float)
+ .mean(Kind::Float)
+ .double_value(&[]);
+ println!(
+ "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+ epoch,
+ loss.double_value(&[]),
+ 100. * test_accuracy
+ );
+ }
+ Ok(())
+}
diff --git a/mnist/src/mnist_nn.rs b/mnist/src/mnist_nn.rs
new file mode 100644
index 0000000..ca1bc4e
--- /dev/null
+++ b/mnist/src/mnist_nn.rs
@@ -0,0 +1,34 @@
+// This should rearch 97% accuracy.
+
+use anyhow::Result;
+use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
+
+const IMAGE_DIM: i64 = 784;
+const HIDDEN_NODES: i64 = 128;
+const LABELS: i64 = 10;
+
+fn net(vs: &nn::Path) -> impl Module {
+ nn::seq()
+ .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
+ .add_fn(|xs| xs.relu())
+ .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
+}
+
+pub fn run() -> Result<()> {
+ let m = tch::vision::mnist::load_dir("data")?;
+ let vs = nn::VarStore::new(Device::Cpu);
+ let net = net(&vs.root());
+ let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
+ for epoch in 1..1000 {
+ let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
+ opt.backward_step(&loss);
+ let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
+ println!(
+ "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
+ epoch,
+ f64::from(&loss),
+ 100. * f64::from(&test_accuracy),
+ );
+ }
+ Ok(())
+}