diff options
author | Anthony Wang | 2022-02-22 12:36:52 -0600 |
---|---|---|
committer | Anthony Wang | 2022-02-22 12:36:52 -0600 |
commit | fb9b81284b162c42d719b0381169996d691a27d8 (patch) | |
tree | 9e07c8152dcaa8b83253f2e99a61b2a8fe4a7762 | |
parent | 772f3c36928e68887ec1748ea859aa8aab71eb44 (diff) |
Optimize LSTM training
-rw-r--r-- | train.py | 12 |
1 files changed, 5 insertions, 7 deletions
```diff
@@ -16,13 +16,13 @@ parser.add_argument('-i', '--input', default='data',
                     help='training data input file')
 parser.add_argument('-o', '--output', default='model.pt',
                     help='trained model output file')
-parser.add_argument('-e', '--epochs', default=100, type=int,
+parser.add_argument('-e', '--epochs', default=10, type=int,
                     help='number of epochs to train for')
 parser.add_argument('-s', '--seq-size', default=32, type=int,
                     help='sequence size')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     help='size of each training batch')
-parser.add_argument('-m', '--embedding-dim', default=256, type=int,
+parser.add_argument('-m', '--embedding-dim', default=64, type=int,
                     help='size of the embedding')
 parser.add_argument('-l', '--lstm-size', default=256, type=int,
                     help='size of the LSTM hidden state')
@@ -58,8 +58,6 @@ for t in range(args.epochs):
     state_c = state_c.to(device)
     iteration = 0
-
-    print(len(dataloader))
     for batch, (X, y) in enumerate(dataloader):
         iteration += 1
@@ -67,8 +65,8 @@ for t in range(args.epochs):
         optimizer.zero_grad()
-        X = torch.tensor(X).to(device)
-        y = torch.tensor(y).to(device)
+        X = X.to(device)
+        y = y.to(device)
         # Compute prediction error
         logits, (state_h, state_c) = model(X, (state_h, state_c))
@@ -92,7 +90,7 @@ for t in range(args.epochs):
                   'Iteration: {}'.format(iteration),
                   'Loss: {}'.format(loss_value))
-        if iteration % 20 == 0:
+        if iteration % 10 == 0:
             print(' '.join(predict(args.device, dataset, model, 'i am')))
```