aboutsummaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorAnthony Wang2022-02-21 14:54:57 -0600
committerAnthony Wang2022-02-21 14:54:57 -0600
commit6e354215c1f3fb9f5f24aa118ff0a9561a31f108 (patch)
tree0a3dadcba038530b94a7fa4bf652335e3737092e /model.py
parent8f5bcc04d23da436dcde3eadf8a44198f1871cdb (diff)
Use the dataset class to initialize model parameters
Diffstat (limited to 'model.py')
-rw-r--r--model.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/model.py b/model.py
index 9b08f47..cfa1deb 100644
--- a/model.py
+++ b/model.py
@@ -2,13 +2,14 @@ import torch.nn as nn
class Model(nn.Module):
- def __init__(self, n_vocab, seq_size, embedding_size, lstm_size, num_layers, dropout):
+ def __init__(self, dataset, embedding_size, lstm_size, num_layers, dropout):
super(Model, self).__init__()
- self.seq_size = seq_size
+ self.seq_size = dataset.seq_size
self.lstm_size = lstm_size
self.num_layers = num_layers
+ n_vocab = len(dataset.uniq_words)
self.embedding = nn.Embedding(n_vocab, embedding_size)
self.lstm = nn.LSTM(