diff options
-rw-r--r-- | train.py | 42 |
1 files changed, 35 insertions, 7 deletions
@@ -1,21 +1,49 @@ from argparse import ArgumentParser +from itertools import chain from datasets import load_dataset -from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer +from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator parser = ArgumentParser() parser.add_argument('-i', '--input', default='data', help='training data input file') +parser.add_argument('-o', '--output', default='model', + help='output directory for trained model') args = parser.parse_args() +# Load and tokenize dataset raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True) - -tokenizer = AutoTokenizer.from_pretrained('distilgpt2') -tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns=raw_dataset["train"].column_names) - +tokenizer = AutoTokenizer.from_pretrained('distilgpt2', use_fast=True) +tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns='text') + + +# Generate chunks of block_size +block_size = tokenizer.model_max_length + +# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. +def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + +lm_dataset = tokenized_dataset.map(group_texts, batched=True) + + +# Create and train the model model = AutoModelForCausalLM.from_pretrained('distilgpt2') - -trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'], tokenizer=tokenizer) +trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator) trainer.train() +trainer.save_model(args.output) |