diff options
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 6 |
1 files changed, 3 insertions, 3 deletions
@@ -2,7 +2,7 @@ from argparse import ArgumentParser from itertools import chain from datasets import load_dataset -from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator +from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator parser = ArgumentParser() @@ -44,6 +44,6 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True) # Create and train the model model = AutoModelForCausalLM.from_pretrained('distilgpt2') -trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator) +trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train']) trainer.train() -trainer.save_model(args.output) +trainer.save_model() |