about summary refs log tree commit diff
diff options
context:
space:
mode:
author Anthony Wang 2022-02-22 12:36:52 -0600
committer Anthony Wang 2022-02-22 12:36:52 -0600
commit fb9b81284b162c42d719b0381169996d691a27d8 (patch)
tree 9e07c8152dcaa8b83253f2e99a61b2a8fe4a7762
parent 772f3c36928e68887ec1748ea859aa8aab71eb44 (diff)
Optimize LSTM training
-rw-r--r--train.py12
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/train.py b/train.py
index fcc4e19..b90f02b 100644
--- a/train.py
+++ b/train.py
@@ -16,13 +16,13 @@ parser.add_argument('-i', '--input', default='data',
help='training data input file')
parser.add_argument('-o', '--output', default='model.pt',
help='trained model output file')
-parser.add_argument('-e', '--epochs', default=100, type=int,
+parser.add_argument('-e', '--epochs', default=10, type=int,
help='number of epochs to train for')
parser.add_argument('-s', '--seq-size', default=32, type=int,
help='sequence size')
parser.add_argument('-b', '--batch-size', default=256, type=int,
help='size of each training batch')
-parser.add_argument('-m', '--embedding-dim', default=256, type=int,
+parser.add_argument('-m', '--embedding-dim', default=64, type=int,
help='size of the embedding')
parser.add_argument('-l', '--lstm-size', default=256, type=int,
help='size of the LSTM hidden state')
@@ -58,8 +58,6 @@ for t in range(args.epochs):
state_c = state_c.to(device)
iteration = 0
-
- print(len(dataloader))
for batch, (X, y) in enumerate(dataloader):
iteration += 1
@@ -67,8 +65,8 @@ for t in range(args.epochs):
optimizer.zero_grad()
- X = torch.tensor(X).to(device)
- y = torch.tensor(y).to(device)
+ X = X.to(device)
+ y = y.to(device)
# Compute prediction error
logits, (state_h, state_c) = model(X, (state_h, state_c))
@@ -92,7 +90,7 @@ for t in range(args.epochs):
'Iteration: {}'.format(iteration),
'Loss: {}'.format(loss_value))
- if iteration % 20 == 0:
+ if iteration % 10 == 0:
print(' '.join(predict(args.device, dataset, model, 'i am')))