aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Wang2022-02-22 17:51:52 -0600
committerAnthony Wang2022-02-22 17:51:52 -0600
commit8dab77d61bd0e8cf53f8e786191e026c724a6500 (patch)
tree3025b110d5f85aa8dd9cb5a592a94e7d7b990b6c
parent1c43115cd6803b40d8509eba50983c8378782573 (diff)
Set output directory correctly
-rw-r--r--train.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train.py b/train.py
index 9d38f44..23422e0 100644
--- a/train.py
+++ b/train.py
@@ -2,7 +2,7 @@ from argparse import ArgumentParser
from itertools import chain
from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator
+from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator
parser = ArgumentParser()
@@ -44,6 +44,6 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
# Create and train the model
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
-trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator)
+trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train'])
trainer.train()
-trainer.save_model(args.output)
+trainer.save_model()