Diffstat (limited to 'finetune_bert.py')
-rw-r--r--   finetune_bert.py   22
1 file changed, 13 insertions(+), 9 deletions(-)
diff --git a/finetune_bert.py b/finetune_bert.py
index e9f2147..da29af9 100644
--- a/finetune_bert.py
+++ b/finetune_bert.py
@@ -118,8 +118,8 @@ for size in 35 19 11 6; do
   for dir in ltr rtl; do
     accelerate launch --mixed_precision bf16 finetune_bert.py \
         --model_direction $dir \
-        --model_config "configs/bert_${size}M.json" \
         --model_name bert-base-uncased \
+        --model_config "configs/bert_${size}M.json" \
         --train_from_scratch \
         --warmup_steps 500 \
         --learning_rate 5e-5 \
@@ -133,22 +133,25 @@ for size in 35 19 11 6; do
   done
 done
 
-size=35
-dir=ltr
-WANDB_MODE=offline accelerate launch --mixed_precision bf16 finetune_bert.py \
+for seed in 0 1 2 3 4; do
+  for dir in ltr rtl; do
+    accelerate launch --mixed_precision bf16 finetune_bert.py \
         --model_direction $dir \
-        --model_config "configs/bert_${size}M.json" \
         --model_name bert-base-uncased \
+        --model_config "configs/bert_${size}M.json" \
         --train_from_scratch \
         --warmup_steps 500 \
         --learning_rate 5e-5 \
         --per_device_train_batch_size 128 \
         --per_device_eval_batch_size 128 \
-        --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+        --output_dir "checkpoints/overwritable_temp/" \
         --eval_steps 899 \
         --block_size 128 \
-        --num_train_epochs 4 \
-        --weight_decay 1e-4
+        --num_train_epochs 1 \
+        --weight_decay 1e-4 \
+        --seed $seed
+  done
+done
 """
 
 import argparse
@@ -214,6 +217,7 @@ def parse_args():
     parser.add_argument("--eval_steps", type=int, default=20000,
                         help="Number of update steps between two logs.")
     parser.add_argument("--dataloader_num_workers", type=int, default=8)
+    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
 
     args = parser.parse_args()
 
@@ -224,7 +228,7 @@ def main():
     args = parse_args()
     accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
                                          log_with="wandb", project_dir=args.output_dir)
-    set_seed(42)
+    set_seed(args.seed)
 
     # Will `add_attn_hooks` to `model` later
     if args.model_config is not None:
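
For reference, a minimal sketch of the seeding path this change introduces, based only on the hunks above: the --seed argument added in parse_args() replaces the hard-coded set_seed(42) in main(), so each iteration of the seed 0-4 shell loop trains with a distinct, reproducible seed. The set_seed import is assumed to come from accelerate.utils, since the import itself is outside these hunks.

# Sketch (assumptions noted above); not the full script.
import argparse
from accelerate.utils import set_seed  # assumed import location

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
args = parser.parse_args()

set_seed(args.seed)  # previously a fixed set_seed(42); now varies with --seed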