aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bot.py2
-rw-r--r--train.py4
2 files changed, 3 insertions, 3 deletions
diff --git a/bot.py b/bot.py
index c3d5274..5bcac18 100644
--- a/bot.py
+++ b/bot.py
@@ -19,7 +19,7 @@ args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
-model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=float16).to('cuda')
+model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')
if args.input is None:
diff --git a/train.py b/train.py
index 11819bf..2e7d6df 100644
--- a/train.py
+++ b/train.py
@@ -47,7 +47,7 @@ lm_dataset = tokenized_dataset.map(group_texts, batched=True)
# Create and train the model
model = AutoModelForCausalLM.from_pretrained('gpt2-large',
torch_dtype=float16, low_cpu_mem_usage=True).to('cuda')
-trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1,
- gradient_accumulation_steps=8), default_data_collator, lm_dataset['train'])
+trainer = Trainer(model, TrainingArguments(output_dir=args.output, per_device_train_batch_size=1),
+ default_data_collator, lm_dataset['train'])
trainer.train()
trainer.save_model()