import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
        super(Model, self).__init__()
        self.seq_size = dataset.seq_size
        self.lstm_size = lstm_size
        self.num_layers = num_layers
        n_vocab = len(dataset.uniq_words)

        # Map word indices to dense vectors.
        self.embedding = nn.Embedding(n_vocab, embedding_dim)
        # Stacked LSTM; dropout is applied between layers, so it has no
        # effect when num_layers == 1.
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        # Project each timestep's LSTM output back onto the vocabulary.
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)                     # (batch, seq, embedding_dim)
        output, state = self.lstm(embed, prev_state)  # (batch, seq, lstm_size)
        logits = self.dense(output)                   # (batch, seq, n_vocab)
        return logits, state

    def zero_state(self, batch_size):
        # Fresh (hidden, cell) state, one tensor pair covering all layers.
        return (torch.zeros(self.num_layers, batch_size, self.lstm_size),
                torch.zeros(self.num_layers, batch_size, self.lstm_size))
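
To see the shapes flow through, here is a minimal usage sketch. `DummyDataset` is a hypothetical stand-in for the dataset object the constructor expects (anything exposing `seq_size` and `uniq_words` works), and the hyperparameter values are illustrative, not prescribed by the model.

import torch

class DummyDataset:
    # Hypothetical stand-in: a toy vocabulary and sequence length.
    seq_size = 32
    uniq_words = ["the", "quick", "brown", "fox"]

model = Model(DummyDataset(), embedding_dim=64, lstm_size=64,
              num_layers=2, dropout=0.2)

batch_size = 4
# Random word indices in place of real training batches.
x = torch.randint(0, len(DummyDataset.uniq_words),
                  (batch_size, DummyDataset.seq_size))
state = model.zero_state(batch_size)
logits, state = model(x, state)
print(logits.shape)  # torch.Size([4, 32, 4]) -> (batch, seq, vocab)

Because the LSTM returns its final `(hidden, cell)` state, the same call can be chained across consecutive batches during training or fed one token at a time when sampling text.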