 finetune_bert.py | 58
 1 file changed, 53 insertions(+), 5 deletions(-)
diff --git a/finetune_bert.py b/finetune_bert.py
index fba4d1d..e9f2147 100644
--- a/finetune_bert.py
+++ b/finetune_bert.py
@@ -113,6 +113,42 @@ accelerate launch --mixed_precision bf16 finetune_bert.py \
     --block_size 128 \
     --num_train_epochs 4 \
     --weight_decay 1e-4
+
+for size in 35 19 11 6; do
+  for dir in ltr rtl; do
+    accelerate launch --mixed_precision bf16 finetune_bert.py \
+      --model_direction $dir \
+      --model_config "configs/bert_${size}M.json" \
+      --model_name bert-base-uncased \
+      --train_from_scratch \
+      --warmup_steps 500 \
+      --learning_rate 5e-5 \
+      --per_device_train_batch_size 128 \
+      --per_device_eval_batch_size 128 \
+      --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+      --eval_steps 899 \
+      --block_size 128 \
+      --num_train_epochs 4 \
+      --weight_decay 1e-4
+  done
+done
+
+size=35
+dir=ltr
+WANDB_MODE=offline accelerate launch --mixed_precision bf16 finetune_bert.py \
+    --model_direction $dir \
+    --model_config "configs/bert_${size}M.json" \
+    --model_name bert-base-uncased \
+    --train_from_scratch \
+    --warmup_steps 500 \
+    --learning_rate 5e-5 \
+    --per_device_train_batch_size 128 \
+    --per_device_eval_batch_size 128 \
+    --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+    --eval_steps 899 \
+    --block_size 128 \
+    --num_train_epochs 4 \
+    --weight_decay 1e-4
 """
 
 import argparse
@@ -126,6 +162,7 @@ import wandb
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
+from transformers import set_seed
 
 from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper
 
@@ -140,9 +177,12 @@ def parse_args():
     # Model
     parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
                         help="Whether to train a left-to-right or right-to-left LM.")
-    parser.add_argument("--model_name", type=str,
-                        help="Name of tokenizer to load. If not training from scratch, "
-                             "will also load model weights.")
+    parser.add_argument("--model_config", type=str,
+                        help="Path to model config json, from which to train_from_scratch.")
+    parser.add_argument("--model_name", type=str, required=True,
+                        help="Name of tokenizer to load. "
+                             "If model_config is not specified, will also load model architecture."
+                             "If not training from scratch, will also load model weights.")
 
     # Data
     parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext",
@@ -176,6 +216,7 @@ def parse_args():
     parser.add_argument("--dataloader_num_workers", type=int, default=8)
 
     args = parser.parse_args()
+
     return args
 
 
@@ -183,10 +224,17 @@ def main():
     args = parse_args()
     accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
                                          log_with="wandb", project_dir=args.output_dir)
+    set_seed(42)
 
+    # Will `add_attn_hooks` to `model` later
+    if args.model_config is not None:
+        assert args.train_from_scratch, "Expected to train from scratch when model_config is specified."
+        config = transformers.AutoConfig.from_pretrained(args.model_config)
+        model = transformers.AutoModelForMaskedLM.from_config(config)
+    else:
+        # Load model weights in both cases, but re-initialize if training from scratch
+        model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
-    # Load model weights in both cases, but re-initialize if training from scratch
-    model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
 
     if args.train_from_scratch:
         model.apply(model._initialize_weights)
         model.tie_weights()  # probably not applicable
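
Note (not part of the patch): the new --model_config path is read with transformers.AutoConfig.from_pretrained and handed to AutoModelForMaskedLM.from_config, so the configs/bert_${size}M.json files referenced by the launch commands above are presumably plain BertConfig JSON files. Below is a minimal sketch of one way such files could be generated; the hidden sizes, layer counts, and head counts are illustrative assumptions, not values taken from this repository.

# Hypothetical helper, not part of the patch: writes configs/bert_{6,11,19,35}M.json
# as standard BertConfig JSON files that AutoConfig.from_pretrained() can load.
from pathlib import Path
from transformers import BertConfig

# Assumed size -> architecture mapping; tune these to hit the intended parameter counts.
SIZES = {
    "6M":  dict(hidden_size=128, num_hidden_layers=4,  num_attention_heads=2, intermediate_size=512),
    "11M": dict(hidden_size=192, num_hidden_layers=6,  num_attention_heads=3, intermediate_size=768),
    "19M": dict(hidden_size=256, num_hidden_layers=8,  num_attention_heads=4, intermediate_size=1024),
    "35M": dict(hidden_size=384, num_hidden_layers=10, num_attention_heads=6, intermediate_size=1536),
}

Path("configs").mkdir(exist_ok=True)
for name, arch in SIZES.items():
    # Remaining fields (vocab_size=30522, max_position_embeddings=512, ...) keep the
    # bert-base-uncased defaults so the bert-base-uncased tokenizer still matches.
    BertConfig(**arch).to_json_file(f"configs/bert_{name}.json")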