author     Anthony Wang  2022-02-21 13:02:37 -0600
committer  Anthony Wang  2022-02-21 13:02:37 -0600
commit     441cae5c2cc230f5db4075a5379943123cc0aa11
tree       0c47c887e0386c99aae91314a583e31cfa77c00c /model.py
parent     bcb2b5e45514e9920f1c3ad025ffac3e2745c539
Add new multilayer model
Diffstat (limited to 'model.py')
-rw-r--r--  model.py  37
1 file changed, 37 insertions, 0 deletions
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..9b08f47
--- /dev/null
+++ b/model.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size, num_layers, dropout):
+        super(Model, self).__init__()
+
+        self.seq_size = seq_size
+        self.lstm_size = lstm_size
+        self.num_layers = num_layers
+
+        # Map token indices to dense embedding vectors
+        self.embedding = nn.Embedding(n_vocab, embedding_size)
+
+        # Multilayer LSTM; dropout is applied between stacked layers
+        self.lstm = nn.LSTM(
+            input_size=embedding_size,
+            hidden_size=lstm_size,
+            num_layers=num_layers,
+            batch_first=True,
+            dropout=dropout
+        )
+        # Project LSTM outputs back to vocabulary-sized logits
+        self.dense = nn.Linear(lstm_size, n_vocab)
+
+    def forward(self, x, prev_state):
+        embed = self.embedding(x)
+        output, state = self.lstm(embed, prev_state)
+        logits = self.dense(output)
+
+        return logits, state
+
+    def zero_state(self, batch_size):
+        # Zeroed (hidden, cell) states, one per LSTM layer
+        return (torch.zeros(self.num_layers, batch_size, self.lstm_size),
+                torch.zeros(self.num_layers, batch_size, self.lstm_size))
\ No newline at end of file
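
For context, a minimal sketch of how the new model might be driven. Only the Model class comes from this commit; the hyperparameter values, the batch of random token indices, and the from-model-import entry point are illustrative assumptions.

import torch

from model import Model

# All values below are illustrative assumptions, not part of the commit
n_vocab, seq_size, embedding_size, lstm_size = 1000, 32, 64, 64
num_layers, dropout, batch_size = 2, 0.2, 16

model = Model(n_vocab, seq_size, embedding_size, lstm_size, num_layers, dropout)
state = model.zero_state(batch_size)

# Batch of token-index sequences, shape (batch, seq)
x = torch.randint(0, n_vocab, (batch_size, seq_size))
logits, state = model(x, state)
print(logits.shape)  # torch.Size([16, 32, 1000])

Because the LSTM is built with batch_first=True, inputs and logits are laid out as (batch, seq, feature), while zero_state keeps the (num_layers, batch, hidden) layout that nn.LSTM expects for its state tuple.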