aboutsummaryrefslogtreecommitdiff
path: root/finetune_bert.py
diff options
context:
space:
mode:
Diffstat (limited to 'finetune_bert.py')
-rw-r--r--finetune_bert.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/finetune_bert.py b/finetune_bert.py
index e9f2147..da29af9 100644
--- a/finetune_bert.py
+++ b/finetune_bert.py
@@ -118,8 +118,8 @@ for size in 35 19 11 6; do
for dir in ltr rtl; do
accelerate launch --mixed_precision bf16 finetune_bert.py \
--model_direction $dir \
- --model_config "configs/bert_${size}M.json" \
--model_name bert-base-uncased \
+ --model_config "configs/bert_${size}M.json" \
--train_from_scratch \
--warmup_steps 500 \
--learning_rate 5e-5 \
@@ -133,22 +133,25 @@ for size in 35 19 11 6; do
done
done
-size=35
-dir=ltr
-WANDB_MODE=offline accelerate launch --mixed_precision bf16 finetune_bert.py \
+for seed in 0 1 2 3 4; do
+ for dir in ltr rtl; do
+ accelerate launch --mixed_precision bf16 finetune_bert.py \
--model_direction $dir \
- --model_config "configs/bert_${size}M.json" \
--model_name bert-base-uncased \
+ --model_config "configs/bert_${size}M.json" \
--train_from_scratch \
--warmup_steps 500 \
--learning_rate 5e-5 \
--per_device_train_batch_size 128 \
--per_device_eval_batch_size 128 \
- --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+ --output_dir "checkpoints/overwritable_temp/" \
--eval_steps 899 \
--block_size 128 \
- --num_train_epochs 4 \
- --weight_decay 1e-4
+ --num_train_epochs 1 \
+ --weight_decay 1e-4 \
+ --seed $seed
+ done
+done
"""
import argparse
@@ -214,6 +217,7 @@ def parse_args():
parser.add_argument("--eval_steps", type=int, default=20000,
help="Number of update steps between two logs.")
parser.add_argument("--dataloader_num_workers", type=int, default=8)
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
args = parser.parse_args()
@@ -224,7 +228,7 @@ def main():
args = parse_args()
accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir)
- set_seed(42)
+ set_seed(args.seed)
# Will `add_attn_hooks` to `model` later
if args.model_config is not None: