diff options
-rw-r--r-- | bot.py   | 2 | +-
-rw-r--r-- | train.py | 4 | +-
2 files changed, 3 insertions, 3 deletions
@@ -19,7 +19,7 @@
 args = parser.parse_args()
 tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
-model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=float16).to('cuda')
+model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')
 if args.input is None:
@@ -47,7 +47,7 @@
 lm_dataset = tokenized_dataset.map(group_texts, batched=True)
 # Create and train the model
 model = AutoModelForCausalLM.from_pretrained('gpt2-large', torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
-trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1,
-                  gradient_accumulation_steps=8), default_data_collator, lm_dataset['train'])
+trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1),
+                  default_data_collator, lm_dataset['train'])
 trainer.train()
 trainer.save_model()