diff options
author | Anthony Wang | 2022-02-21 14:54:57 -0600 |
---|---|---|
committer | Anthony Wang | 2022-02-21 14:54:57 -0600 |
commit | 6e354215c1f3fb9f5f24aa118ff0a9561a31f108 (patch) | |
tree | 0a3dadcba038530b94a7fa4bf652335e3737092e /model.py | |
parent | 8f5bcc04d23da436dcde3eadf8a44198f1871cdb (diff) |
Use the dataset class to initialize model parameters
Diffstat (limited to 'model.py')
-rw-r--r-- | model.py | 5 |
1 files changed, 3 insertions, 2 deletions
@@ -2,13 +2,14 @@ import torch.nn as nn class Model(nn.Module): - def __init__(self, n_vocab, seq_size, embedding_size, lstm_size, num_layers, dropout): + def __init__(self, dataset, embedding_size, lstm_size, num_layers, dropout): super(Model, self).__init__() - self.seq_size = seq_size + self.seq_size = dataset.seq_size self.lstm_size = lstm_size self.num_layers = num_layers + n_vocab = len(dataset.uniq_words) self.embedding = nn.Embedding(n_vocab, embedding_size) self.lstm = nn.LSTM( |