aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--train.py92
-rw-r--r--train_lstm.py97
2 files changed, 105 insertions, 84 deletions
diff --git a/train.py b/train.py
index b90f02b..bbdc54a 100644
--- a/train.py
+++ b/train.py
@@ -1,97 +1,21 @@
from argparse import ArgumentParser
-import torch
-from torch import nn
-from torch.utils.data import DataLoader
-
-from dataset import Dataset
-from model import Model
-from predict import predict
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
parser = ArgumentParser()
-parser.add_argument('-d', '--device', default='cpu',
- help='device to train with')
parser.add_argument('-i', '--input', default='data',
help='training data input file')
-parser.add_argument('-o', '--output', default='model.pt',
- help='trained model output file')
-parser.add_argument('-e', '--epochs', default=10, type=int,
- help='number of epochs to train for')
-parser.add_argument('-s', '--seq-size', default=32, type=int,
- help='sequence size')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
- help='size of each training batch')
-parser.add_argument('-m', '--embedding-dim', default=64, type=int,
- help='size of the embedding')
-parser.add_argument('-l', '--lstm-size', default=256, type=int,
- help='size of the LSTM hidden state')
-parser.add_argument('-a', '--layers', default=3, type=int,
- help='number of LSTM layers')
-parser.add_argument('-r', '--dropout', default=0.2, type=int,
- help='how much dropout to apply')
-parser.add_argument('-n', '--max-norm', default=5, type=int,
- help='maximum norm for gradient clipping')
args = parser.parse_args()
-# Prepare dataloader
-dataset = Dataset(args.input, args.seq_size)
-dataloader = DataLoader(dataset, args.batch_size)
-print(len(dataloader))
-
-
-# Prepare model
-device = torch.device(args.device)
-model = Model(dataset, args.embedding_dim, args.lstm_size,
- args.layers, args.dropout).to(device)
-print(model)
-
-
-loss_fn = nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-
-
-for t in range(args.epochs):
- state_h, state_c = model.zero_state(args.batch_size)
- state_h = state_h.to(device)
- state_c = state_c.to(device)
-
- iteration = 0
- for batch, (X, y) in enumerate(dataloader):
- iteration += 1
-
- model.train()
-
- optimizer.zero_grad()
-
- X = X.to(device)
- y = y.to(device)
-
- # Compute prediction error
- logits, (state_h, state_c) = model(X, (state_h, state_c))
- loss = loss_fn(logits.transpose(1, 2), y)
-
- loss_value = loss.item()
-
- # Backpropogation
- loss.backward()
-
- state_h = state_h.detach()
- state_c = state_c.detach()
-
- _ = torch.nn.utils.clip_grad_norm_(
- model.parameters(), args.max_norm)
-
- optimizer.step()
-
- if iteration % 1 == 0:
- print('Epoch: {}/{}'.format(t, args.epochs),
- 'Iteration: {}'.format(iteration),
- 'Loss: {}'.format(loss_value))
+raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)
- if iteration % 10 == 0:
- print(' '.join(predict(args.device, dataset, model, 'i am')))
+tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
+tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns=raw_dataset["train"].column_names)
+model = AutoModelForCausalLM.from_pretrained('distilgpt2')
-torch.save(model, args.output)
+trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'], tokenizer=tokenizer)
+trainer.train()
diff --git a/train_lstm.py b/train_lstm.py
new file mode 100644
index 0000000..b90f02b
--- /dev/null
+++ b/train_lstm.py
@@ -0,0 +1,97 @@
+from argparse import ArgumentParser
+
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+from dataset import Dataset
+from model import Model
+from predict import predict
+
+
+parser = ArgumentParser()
+parser.add_argument('-d', '--device', default='cpu',
+ help='device to train with')
+parser.add_argument('-i', '--input', default='data',
+ help='training data input file')
+parser.add_argument('-o', '--output', default='model.pt',
+ help='trained model output file')
+parser.add_argument('-e', '--epochs', default=10, type=int,
+ help='number of epochs to train for')
+parser.add_argument('-s', '--seq-size', default=32, type=int,
+ help='sequence size')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+ help='size of each training batch')
+parser.add_argument('-m', '--embedding-dim', default=64, type=int,
+ help='size of the embedding')
+parser.add_argument('-l', '--lstm-size', default=256, type=int,
+ help='size of the LSTM hidden state')
+parser.add_argument('-a', '--layers', default=3, type=int,
+ help='number of LSTM layers')
+parser.add_argument('-r', '--dropout', default=0.2, type=int,
+ help='how much dropout to apply')
+parser.add_argument('-n', '--max-norm', default=5, type=int,
+ help='maximum norm for gradient clipping')
+args = parser.parse_args()
+
+
+# Prepare dataloader
+dataset = Dataset(args.input, args.seq_size)
+dataloader = DataLoader(dataset, args.batch_size)
+print(len(dataloader))
+
+
+# Prepare model
+device = torch.device(args.device)
+model = Model(dataset, args.embedding_dim, args.lstm_size,
+ args.layers, args.dropout).to(device)
+print(model)
+
+
+loss_fn = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+
+for t in range(args.epochs):
+ state_h, state_c = model.zero_state(args.batch_size)
+ state_h = state_h.to(device)
+ state_c = state_c.to(device)
+
+ iteration = 0
+ for batch, (X, y) in enumerate(dataloader):
+ iteration += 1
+
+ model.train()
+
+ optimizer.zero_grad()
+
+ X = X.to(device)
+ y = y.to(device)
+
+ # Compute prediction error
+ logits, (state_h, state_c) = model(X, (state_h, state_c))
+ loss = loss_fn(logits.transpose(1, 2), y)
+
+ loss_value = loss.item()
+
+ # Backpropogation
+ loss.backward()
+
+ state_h = state_h.detach()
+ state_c = state_c.detach()
+
+ _ = torch.nn.utils.clip_grad_norm_(
+ model.parameters(), args.max_norm)
+
+ optimizer.step()
+
+ if iteration % 1 == 0:
+ print('Epoch: {}/{}'.format(t, args.epochs),
+ 'Iteration: {}'.format(iteration),
+ 'Loss: {}'.format(loss_value))
+
+ if iteration % 10 == 0:
+ print(' '.join(predict(args.device, dataset, model, 'i am')))
+
+
+torch.save(model, args.output)