import torch
import torch.nn as nn


class Model(nn.Module):
    """Word-level language model: embedding -> multi-layer LSTM -> vocab logits.

    Expects ``dataset`` to expose ``seq_size`` (int) and ``uniq_words``
    (sized collection of vocabulary words) — TODO confirm against the
    dataset class, which is defined elsewhere.
    """

    def __init__(self, dataset, embedding_dim, lstm_size, num_layers, dropout):
        """Build the embedding, LSTM stack, and output projection.

        Args:
            dataset: object with ``seq_size`` and ``uniq_words`` attributes;
                the vocabulary size is taken as ``len(dataset.uniq_words)``.
            embedding_dim: size of each word embedding vector.
            lstm_size: LSTM hidden-state size per layer.
            num_layers: number of stacked LSTM layers.
            dropout: inter-layer LSTM dropout probability.
        """
        super().__init__()
        self.seq_size = dataset.seq_size
        self.lstm_size = lstm_size
        self.num_layers = num_layers
        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(n_vocab, embedding_dim)
        # NOTE: nn.LSTM silently ignores `dropout` (and warns) when
        # num_layers == 1 — dropout only applies *between* layers.
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: LongTensor of token indices, shape (batch, seq) —
                batch_first=True on the LSTM fixes this layout.
            prev_state: (h, c) tuple, each of shape
                (num_layers, batch, lstm_size), e.g. from ``zero_state``.

        Returns:
            Tuple of per-position vocabulary logits with shape
            (batch, seq, n_vocab) and the new (h, c) state.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state

    def zero_state(self, batch_size, device=None):
        """Return a zero-initialized (h, c) state for ``batch_size`` sequences.

        Args:
            batch_size: number of sequences in the batch.
            device: device for the state tensors. Defaults to the model's
                own parameter device — the original always allocated on
                CPU, which raised a device-mismatch error once the model
                was moved to GPU. Behavior is unchanged for CPU models.

        Returns:
            Tuple (h0, c0), each of shape (num_layers, batch, lstm_size).
        """
        if device is None:
            device = self.embedding.weight.device
        shape = (self.num_layers, batch_size, self.lstm_size)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))