diff options
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 4 |
1 files changed, 2 insertions, 2 deletions
@@ -47,7 +47,7 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
 
 # Create and train the model
 model = AutoModelForCausalLM.from_pretrained('gpt2-large', torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
-trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1,
-                  gradient_accumulation_steps=8), default_data_collator, lm_dataset['train'])
+trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1),
+                  default_data_collator, lm_dataset['train'])
 trainer.train()
 trainer.save_model()