summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/vae.rs13
1 files changed, 13 insertions, 0 deletions
diff --git a/src/vae.rs b/src/vae.rs
new file mode 100644
index 0000000..2fc6ae8
--- /dev/null
+++ b/src/vae.rs
@@ -0,0 +1,13 @@
+use tch::nn::{Module, OptimizerConfig};
+use tch::{kind, nn, Device, Tensor};
+
+pub fn vae(vs: &nn::Path) -> impl Module {
+ nn::seq()
+ .add(nn::linear(vs, 100, 50, Default::default()))
+ .add_fn(|xs| xs.relu())
+ .add(nn::linear(vs, 50, 10, Default::default()))
+ .add_fn(|xs| xs.relu())
+ .add(nn::linear(vs, 10, 50, Default::default()))
+ .add_fn(|xs| xs.relu())
+ .add(nn::linear(vs, 50, 100, Default::default()))
+}