From 1c43115cd6803b40d8509eba50983c8378782573 Mon Sep 17 00:00:00 2001
From: Anthony Wang
Date: Tue, 22 Feb 2022 17:43:29 -0600
Subject: Split data into chunks and save model

---
 train.py | 42 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/train.py b/train.py
index bbdc54a..9d38f44 100644
--- a/train.py
+++ b/train.py
@@ -1,21 +1,49 @@
 from argparse import ArgumentParser
+from itertools import chain
 
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator
 
 
 parser = ArgumentParser()
 parser.add_argument('-i', '--input', default='data',
                     help='training data input file')
+parser.add_argument('-o', '--output', default='model',
+                    help='output directory for trained model')
 args = parser.parse_args()
 
 
+# Load and tokenize dataset
 raw_dataset = load_dataset('text', data_files={'train': args.input}, keep_linebreaks=True)
-
-tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
-tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns=raw_dataset["train"].column_names)
-
+tokenizer = AutoTokenizer.from_pretrained('distilgpt2', use_fast=True)
+tokenized_dataset = raw_dataset.map(lambda examples : tokenizer(examples['text']), batched=True, remove_columns='text')
+
+
+# Generate chunks of block_size
+block_size = tokenizer.model_max_length
+
+# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+def group_texts(examples):
+    # Concatenate all texts.
+    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+    total_length = len(concatenated_examples[list(examples.keys())[0]])
+    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+    # customize this part to your needs.
+    if total_length >= block_size:
+        total_length = (total_length // block_size) * block_size
+    # Split by chunks of max_len.
+    result = {
+        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+        for k, t in concatenated_examples.items()
+    }
+    result["labels"] = result["input_ids"].copy()
+    return result
+
+lm_dataset = tokenized_dataset.map(group_texts, batched=True)
+
+
+# Create and train the model
 model = AutoModelForCausalLM.from_pretrained('distilgpt2')
-
-trainer = Trainer(model=model, train_dataset=tokenized_dataset['train'], tokenizer=tokenizer)
+trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator)
 trainer.train()
+trainer.save_model(args.output)
-- 
cgit v1.2.3-70-g09d2
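
Note (not part of the patch): a minimal standalone sketch of what the new group_texts step does, using hypothetical toy values. All tokenized examples are concatenated, the tail past the last full block is dropped, and the remainder is split into block_size-length chunks; labels are a copy of input_ids, since GPT-2-family models shift them internally when computing the causal LM loss. With the patch applied, the script would be invoked along the lines of `python train.py -i data -o model`, per the argparse defaults above.

    from itertools import chain

    block_size = 4  # hypothetical small value; the patch uses tokenizer.model_max_length (1024 for distilgpt2)
    examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8, 9]]}  # toy tokenized batch

    # Concatenate all examples, drop the tail past the last full block, then chunk.
    concatenated = list(chain(*examples["input_ids"]))             # [1, 2, ..., 9]
    total_length = (len(concatenated) // block_size) * block_size  # 8, so token 9 is dropped
    chunks = [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]
    labels = [chunk.copy() for chunk in chunks]  # causal LM: labels mirror input_ids
    print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]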