aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train.py b/train.py
index 9d38f44..23422e0 100644
--- a/train.py
+++ b/train.py
@@ -2,7 +2,7 @@ from argparse import ArgumentParser
from itertools import chain
from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator
parser = ArgumentParser()
@@ -44,6 +44,6 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
# Create and train the model
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
-trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator)
+trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train'])
trainer.train()
-trainer.save_model(args.output)
+trainer.save_model()