aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/train.py b/train.py
index 11819bf..2e7d6df 100644
--- a/train.py
+++ b/train.py
@@ -47,7 +47,7 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
# Create and train the model
model = AutoModelForCausalLM.from_pretrained('gpt2-large',
torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
-trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1,
- gradient_accumulation_steps=8), default_data_collator, lm_dataset['train'])
+trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1),
+ default_data_collator, lm_dataset['train'])
trainer.train()
trainer.save_model()