aboutsummaryrefslogtreecommitdiff
path: root/train.py
blob: bbdc54afa26926eb8312709a1638f0afa259e7c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from argparse import ArgumentParser

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
)


# Fine-tune a causal language model on a plain-text file with the
# Hugging Face Trainer. Run as a script: `python train.py -i FILE`.

DEFAULT_MODEL = 'distilgpt2'
# Upper bound for the default block size: some tokenizers report a huge
# sentinel model_max_length. 1024 matches distilgpt2's context length;
# adjust (or pass --block-size) for models with a different window.
MAX_DEFAULT_BLOCK_SIZE = 1024


def group_texts(examples, block_size):
    """Concatenate tokenized examples and re-split them into fixed-size blocks.

    Causal-LM training works on contiguous token windows rather than on the
    individual lines produced by the ``text`` loader, so every column
    (e.g. ``input_ids``, ``attention_mask``) is flattened and cut into
    chunks of ``block_size`` tokens. The last chunk keeps any remainder and
    may be shorter, so inputs smaller than one block still yield an example.

    Args:
        examples: Mapping of column name to a list of token sequences, as
            passed by ``Dataset.map(..., batched=True)``.
        block_size: Number of tokens per output chunk; must be positive.

    Returns:
        A dict with the same keys, each mapped to a list of chunks.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size < 1:
        raise ValueError(f'block_size must be positive, got {block_size}')
    grouped = {}
    for key in examples.keys():
        tokens = [token for sequence in examples[key] for token in sequence]
        grouped[key] = [tokens[start:start + block_size]
                        for start in range(0, len(tokens), block_size)]
    return grouped


def main(argv=None):
    """Parse CLI arguments, build the LM training set, and run training.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = ArgumentParser(description='Fine-tune a causal language model on a text file.')
    parser.add_argument('-i', '--input', default='data',
                        help='training data input file')
    parser.add_argument('-m', '--model', default=DEFAULT_MODEL,
                        help='pretrained model name or path (default: %(default)s)')
    parser.add_argument('--block-size', type=int, default=None,
                        help='tokens per training example (default: tokenizer limit, '
                             f'capped at {MAX_DEFAULT_BLOCK_SIZE})')
    args = parser.parse_args(argv)
    if args.block_size is not None and args.block_size < 1:
        parser.error('--block-size must be a positive integer')

    # keep_linebreaks=True keeps each line's trailing newline, so line
    # structure survives when lines are concatenated into blocks below.
    raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # GPT-2 tokenizers ship without a pad token, but the LM collator pads
    # each batch (the final block of a chunk may be short); reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    block_size = args.block_size
    if block_size is None:
        block_size = min(tokenizer.model_max_length, MAX_DEFAULT_BLOCK_SIZE)

    tokenized_dataset = raw_dataset.map(
        lambda examples: tokenizer(examples['text']),
        batched=True,
        remove_columns=raw_dataset['train'].column_names,
    )
    # Re-chunk into block_size windows so no example exceeds the model's
    # context length (a single long line previously could).
    lm_dataset = tokenized_dataset.map(
        lambda examples: group_texts(examples, block_size),
        batched=True,
    )

    model = AutoModelForCausalLM.from_pretrained(args.model)

    # mlm=False: the collator copies input_ids into labels and sets padded
    # positions to -100. Without labels the model returns no loss and
    # Trainer.train() fails. Note: with pad == eos, literal EOS tokens in the
    # text are masked out of the loss as well.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        train_dataset=lm_dataset['train'],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()


if __name__ == '__main__':
    main()