summaryrefslogtreecommitdiff
path: root/mnist/src/mnist_linear.rs
blob: d3245f7b7dfa937c3646b2fd5cbe7996feb92e14 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// This should reach 91.5% accuracy.

use anyhow::Result;
use tch::{kind, no_grad, vision, Kind, Tensor};

/// Number of input features per image; used as the row count of the weight
/// matrix in `run`. Presumably 28 * 28 flattened pixels — confirm against the
/// MNIST loader.
const IMAGE_DIM: i64 = 784;
/// Number of output classes; used as the column count of the weight matrix and
/// the length of the bias vector in `run`.
const LABELS: i64 = 10;

/// Trains a linear (softmax-regression) classifier on MNIST and prints progress.
///
/// The dataset is loaded from the `data` directory. The model is a single
/// `[IMAGE_DIM, LABELS]` weight matrix plus a `[LABELS]` bias, both initialised
/// to zero. Each of the 199 epochs (`1..200`) performs one full-batch
/// gradient-descent step over the whole training set, then reports the training
/// loss and the accuracy on the test set.
///
/// # Errors
///
/// Returns an error if the MNIST files cannot be loaded from `data`.
pub fn run() -> Result<()> {
    let m = vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.size());
    println!("train-labels: {:?}", m.train_labels.size());
    println!("test-images: {:?}", m.test_images.size());
    println!("test-labels: {:?}", m.test_labels.size());

    // Trainable parameters: gradients are accumulated on these two tensors.
    let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
    let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);

    // The range is half-open, so this runs epochs 1 through 199.
    for epoch in 1..200 {
        // Forward pass over the entire training set (no mini-batching).
        let logits = m.train_images.mm(&ws) + &bs;
        // log-softmax followed by NLL is the cross-entropy loss.
        let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);

        // Clear gradients left over from the previous epoch before back-propagating.
        ws.zero_grad();
        bs.zero_grad();
        loss.backward();

        // Plain gradient-descent step with an implicit learning rate of 1:
        // param += -1 * grad. Done under `no_grad` so the in-place update is
        // not itself recorded by autograd.
        no_grad(|| {
            ws += ws.grad() * (-1);
            bs += bs.grad() * (-1);
        });

        // Evaluation is never back-propagated through, so run it under
        // `no_grad` as well to avoid recording a graph for the test-set
        // forward pass on every epoch.
        let test_accuracy = no_grad(|| {
            let test_logits = m.test_images.mm(&ws) + &bs;
            test_logits
                .argmax(Some(-1), false)
                .eq_tensor(&m.test_labels)
                .to_kind(Kind::Float)
                .mean(Kind::Float)
                .double_value(&[])
        });

        // Note: `loss` was computed before this epoch's update, while the test
        // accuracy reflects the parameters after it.
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            loss.double_value(&[]),
            100. * test_accuracy
        );
    }
    Ok(())
}