diff options
author | Anthony Wang | 2022-02-22 17:51:52 -0600 |
---|---|---|
committer | Anthony Wang | 2022-02-22 17:51:52 -0600 |
commit | 8dab77d61bd0e8cf53f8e786191e026c724a6500 (patch) | |
tree | 3025b110d5f85aa8dd9cb5a592a94e7d7b990b6c | |
parent | 1c43115cd6803b40d8509eba50983c8378782573 (diff) |
Set output directory correctly
-rw-r--r-- | train.py | 6 |
1 files changed, 3 insertions, 3 deletions
@@ -2,7 +2,7 @@ from argparse import ArgumentParser from itertools import chain from datasets import load_dataset -from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, default_data_collator +from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, default_data_collator parser = ArgumentParser() @@ -44,6 +44,6 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True) # Create and train the model model = AutoModelForCausalLM.from_pretrained('distilgpt2') -trainer = Trainer(model=model, train_dataset=lm_dataset['train'], tokenizer=tokenizer, data_collator=default_data_collator) +trainer = Trainer(model, TrainingArguments(output_dir=args.output), default_data_collator, lm_dataset['train']) trainer.train() -trainer.save_model(args.output) +trainer.save_model() |