// path: root/src/vae.rs
// blob: 2fc6ae88061943620bcbf5f038e2558c1d3e99c8
use tch::nn::{Module, OptimizerConfig};
use tch::{kind, nn, Device, Tensor};

/// Builds a stack of four fully connected layers whose variables are created on `vs`.
///
/// Layer widths run 100 -> 50 -> 10 -> 50 -> 100. A ReLU is applied between
/// each pair of consecutive linear layers; nothing is applied after the last one.
///
/// Every linear layer is created directly on `vs` (no per-layer sub-path), in
/// input-to-output order.
///
/// NOTE(review): only this deterministic stack is visible in this function;
/// no sampling step appears here despite the name — confirm whether that is
/// handled by the caller.
pub fn vae(vs: &nn::Path) -> impl Module {
    // Feature width at each stage, input side first.
    const WIDTHS: [i64; 5] = [100, 50, 10, 50, 100];

    let mut net = nn::seq();
    for (idx, dims) in WIDTHS.windows(2).enumerate() {
        // The activation sits between layers only, never ahead of the first one.
        if idx > 0 {
            net = net.add_fn(|t| t.relu());
        }
        net = net.add(nn::linear(vs, dims[0], dims[1], Default::default()));
    }
    net
}