aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAnthony Wang2022-02-21 13:02:37 -0600
committerAnthony Wang2022-02-21 13:02:37 -0600
commit441cae5c2cc230f5db4075a5379943123cc0aa11 (patch)
tree0c47c887e0386c99aae91314a583e31cfa77c00c /model.py
parentbcb2b5e45514e9920f1c3ad025ffac3e2745c539 (diff)
Add new multilayer model
Diffstat (limited to 'model.py')
-rw-r--r--model.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..9b08f47
--- /dev/null
+++ b/model.py
@@ -0,0 +1,32 @@
+import torch.nn as nn
+
+
+class Model(nn.Module):
+ def __init__(self, n_vocab, seq_size, embedding_size, lstm_size, num_layers, dropout):
+ super(Model, self).__init__()
+
+ self.seq_size = seq_size
+ self.lstm_size = lstm_size
+ self.num_layers = num_layers
+
+ self.embedding = nn.Embedding(n_vocab, embedding_size)
+
+ self.lstm = nn.LSTM(
+ input_size=embedding_size,
+ hidden_size=lstm_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout
+ )
+ self.dense = nn.Linear(lstm_size, n_vocab)
+
+ def forward(self, x, prev_state):
+ embed = self.embedding(x)
+ output, state = self.lstm(embed, prev_state)
+ logits = self.dense(output)
+
+ return logits, state
+
+ def zero_state(self, batch_size):
+ return (torch.zeros(self.num_layers, batch_size, self.lstm_size),
+ torch.zeros(self.num_layers, batch_size, self.lstm_size)) \ No newline at end of file