diff options
-rw-r--r-- | src/vae.rs | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/src/vae.rs b/src/vae.rs new file mode 100644 index 0000000..2fc6ae8 --- /dev/null +++ b/src/vae.rs @@ -0,0 +1,13 @@ +use tch::nn::{Module, OptimizerConfig}; +use tch::{kind, nn, Device, Tensor}; + +pub fn vae(vs: &nn::Path) -> impl Module { + nn::seq() + .add(nn::linear(vs, 100, 50, Default::default())) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, 50, 10, Default::default())) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, 10, 50, Default::default())) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, 50, 100, Default::default())) +} |