author     Kevin Zhao  2024-12-01 16:16:29 -0500
committer  Kevin Zhao  2024-12-01 16:16:29 -0500
commit     7c518a3fab1aab0246b62fd51ea0da89da684c33 (patch)
tree       c33f0542c654beb3f50455121de29b581939f2a2
parent     993359964202bc9cc643a98a7775c29a096a8bb6 (diff)
Add training from scratch
-rw-r--r--  finetune_bert.py | 58
1 file changed, 53 insertions, 5 deletions
diff --git a/finetune_bert.py b/finetune_bert.py
index fba4d1d..e9f2147 100644
--- a/finetune_bert.py
+++ b/finetune_bert.py
@@ -113,6 +113,42 @@ accelerate launch --mixed_precision bf16 finetune_bert.py \
--block_size 128 \
--num_train_epochs 4 \
--weight_decay 1e-4
+
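+# Sweep added below: four model sizes (35M, 19M, 11M, 6M) x two directions (ltr, rtl), each trained from scratch.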
+for size in 35 19 11 6; do
+ for dir in ltr rtl; do
+ accelerate launch --mixed_precision bf16 finetune_bert.py \
+ --model_direction $dir \
+ --model_config "configs/bert_${size}M.json" \
+ --model_name bert-base-uncased \
+ --train_from_scratch \
+ --warmup_steps 500 \
+ --learning_rate 5e-5 \
+ --per_device_train_batch_size 128 \
+ --per_device_eval_batch_size 128 \
+ --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+ --eval_steps 899 \
+ --block_size 128 \
+ --num_train_epochs 4 \
+ --weight_decay 1e-4
+ done
+done
+
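+# Single 35M ltr run with W&B logging kept offline via WANDB_MODE=offline.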
+size=35
+dir=ltr
+WANDB_MODE=offline accelerate launch --mixed_precision bf16 finetune_bert.py \
+ --model_direction $dir \
+ --model_config "configs/bert_${size}M.json" \
+ --model_name bert-base-uncased \
+ --train_from_scratch \
+ --warmup_steps 500 \
+ --learning_rate 5e-5 \
+ --per_device_train_batch_size 128 \
+ --per_device_eval_batch_size 128 \
+ --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+ --eval_steps 899 \
+ --block_size 128 \
+ --num_train_epochs 4 \
+ --weight_decay 1e-4
"""
import argparse
@@ -126,6 +162,7 @@ import wandb
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
+from transformers import set_seed
from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper
@@ -140,9 +177,12 @@ def parse_args():
# Model
parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
help="Whether to train a left-to-right or right-to-left LM.")
- parser.add_argument("--model_name", type=str,
- help="Name of tokenizer to load. If not training from scratch, "
- "will also load model weights.")
+ parser.add_argument("--model_config", type=str,
+ help="Path to model config json, from which to train_from_scratch.")
+ parser.add_argument("--model_name", type=str, required=True,
+ help="Name of tokenizer to load. "
+ "If model_config is not specified, will also load model architecture."
+ "If not training from scratch, will also load model weights.")
# Data
parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext",
@@ -176,6 +216,7 @@ def parse_args():
parser.add_argument("--dataloader_num_workers", type=int, default=8)
args = parser.parse_args()
+
return args
@@ -183,10 +224,17 @@ def main():
args = parse_args()
accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir)
+ set_seed(42)
+
# Will `add_attn_hooks` to `model` later
- # Load model weights in both cases, but re-initialize if training from scratch
- model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
+ if args.model_config is not None:
+     assert args.train_from_scratch, "Expected to train from scratch when model_config is specified."
+     config = transformers.AutoConfig.from_pretrained(args.model_config)
+     model = transformers.AutoModelForMaskedLM.from_config(config)
+ else:
+     # Load model weights in both cases, but re-initialize if training from scratch
+     model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
if args.train_from_scratch:
    model.apply(model._initialize_weights)
    model.tie_weights()  # probably not applicable
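
For reference, a minimal sketch (not part of the patch) contrasting the two construction paths the final hunk distinguishes; it assumes a local config file such as configs/bert_35M.json, as referenced in the docstring commands above.

import transformers

# From-scratch path: build the architecture from a config JSON; weights start randomly initialized.
config = transformers.AutoConfig.from_pretrained("configs/bert_35M.json")
scratch_model = transformers.AutoModelForMaskedLM.from_config(config)

# Pretrained path: load the architecture together with the published checkpoint weights.
pretrained_model = transformers.AutoModelForMaskedLM.from_pretrained(
    "bert-base-uncased", attn_implementation="sdpa"
)

# Compare parameter counts of the small from-scratch model and the full pretrained one.
print(f"{scratch_model.num_parameters():,} vs. {pretrained_model.num_parameters():,} parameters")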