"""Train the LSTM language model on a text corpus and save it to disk."""
from argparse import ArgumentParser

import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import Dataset
from model import Model
from predict import predict

# ---------------------------------------------------------------------------
# Command-line configuration
# ---------------------------------------------------------------------------
parser = ArgumentParser()
parser.add_argument('-d', '--device', default='cpu', help='device to train with')
parser.add_argument('-i', '--input', default='data', help='training data input file')
parser.add_argument('-o', '--output', default='model.pt', help='trained model output file')
parser.add_argument('-e', '--epochs', default=10, type=int, help='number of epochs to train for')
parser.add_argument('-s', '--seq-size', default=32, type=int, help='sequence size')
parser.add_argument('-b', '--batch-size', default=256, type=int, help='size of each training batch')
parser.add_argument('-m', '--embedding-dim', default=64, type=int, help='size of the embedding')
parser.add_argument('-l', '--lstm-size', default=256, type=int, help='size of the LSTM hidden state')
parser.add_argument('-a', '--layers', default=3, type=int, help='number of LSTM layers')
# Dropout is a fraction; with type=int a value such as "-r 0.3" was rejected.
parser.add_argument('-r', '--dropout', default=0.2, type=float, help='how much dropout to apply')
# Gradient-clipping norms are real numbers, so accept fractional values too.
parser.add_argument('-n', '--max-norm', default=5, type=float, help='maximum norm for gradient clipping')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate for the Adam optimizer')
parser.add_argument('--log-every', default=1, type=int, help='print the loss every N iterations')
parser.add_argument('--sample-every', default=10, type=int,
                    help='print a generated sample every N iterations')
parser.add_argument('--prompt', default='i am', help='seed text used when printing generated samples')
args = parser.parse_args()

# Both intervals are used as modulo divisors below, so they must be positive.
if args.log_every < 1 or args.sample_every < 1:
    parser.error('--log-every and --sample-every must be at least 1')

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dataset = Dataset(args.input, args.seq_size)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
print(f'Batches per epoch: {len(dataloader)}')

# ---------------------------------------------------------------------------
# Model, loss and optimizer
# ---------------------------------------------------------------------------
device = torch.device(args.device)
model = Model(dataset, args.embedding_dim, args.lstm_size, args.layers, args.dropout).to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)


def _zero_state(rows):
    """Return a fresh ``(h, c)`` recurrent state for ``rows`` sequences, moved to ``device``."""
    state_h, state_c = model.zero_state(rows)
    return state_h.to(device), state_c.to(device)


# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
for epoch in range(args.epochs):
    # The recurrent state is carried from batch to batch within an epoch and
    # reset at the start of every epoch.
    state_rows = args.batch_size
    state_h, state_c = _zero_state(state_rows)

    for iteration, (X, y) in enumerate(dataloader, start=1):
        # Re-enter training mode on every step: the sampling call below hands the
        # model to predict(), which presumably switches it to eval mode — confirm
        # in predict.py.
        model.train()
        optimizer.zero_grad()

        X = X.to(device)
        y = y.to(device)

        # zero_state() is built from the batch size. The last batch is shorter
        # whenever len(dataset) is not a multiple of --batch-size, and carrying a
        # full-size state into it would not match. Give that batch a fresh state
        # sized to its own row count instead.
        if X.size(0) != state_rows:
            state_rows = X.size(0)
            state_h, state_c = _zero_state(state_rows)

        # Compute prediction error. CrossEntropyLoss expects the class dimension
        # second, hence the transpose. Presumably logits are (batch, seq, vocab)
        # — confirm against Model.forward.
        logits, (state_h, state_c) = model(X, (state_h, state_c))
        loss = loss_fn(logits.transpose(1, 2), y)

        # Backpropagation
        loss.backward()
        # Keep the state values for the next batch, but cut the autograd graph so
        # gradients do not flow back into earlier batches.
        state_h = state_h.detach()
        state_c = state_c.detach()
        nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()

        if iteration % args.log_every == 0:
            print(f'Epoch: {epoch + 1}/{args.epochs}',
                  f'Iteration: {iteration}',
                  f'Loss: {loss.item()}')

        if iteration % args.sample_every == 0:
            print(' '.join(predict(args.device, dataset, model, args.prompt)))

# Saves the whole pickled module (not just its state_dict). Loading it back
# therefore requires the Model class to be importable.
torch.save(model, args.output)