aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--finetune_bert.py126
-rw-r--r--utils.py63
2 files changed, 162 insertions, 27 deletions
diff --git a/finetune_bert.py b/finetune_bert.py
index e4dc5aa..fba4d1d 100644
--- a/finetune_bert.py
+++ b/finetune_bert.py
@@ -1,6 +1,9 @@
"""
+
+# BERT base
accelerate launch --mixed_precision bf16 finetune_bert.py \
--model_direction rtl \
+--model_name bert-base-uncased \
--warmup_steps 500 \
--learning_rate 5e-5 \
--per_device_train_batch_size 128 \
@@ -13,6 +16,7 @@ accelerate launch --mixed_precision bf16 finetune_bert.py \
accelerate launch --mixed_precision bf16 finetune_bert.py \
--model_direction ltr \
+--model_name bert-base-uncased \
--warmup_steps 500 \
--learning_rate 5e-5 \
--per_device_train_batch_size 128 \
@@ -22,6 +26,93 @@ accelerate launch --mixed_precision bf16 finetune_bert.py \
--block_size 128 \
--num_train_epochs 4 \
--weight_decay 1e-4
+
+# DistilBERT scratch
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name distilbert/distilbert-base-uncased \
+--train_from_scratch \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_rtl_scratch/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name distilbert/distilbert-base-uncased \
+--train_from_scratch \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_ltr_scratch/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+# DistilBERT base
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name distilbert/distilbert-base-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_rtl/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name distilbert/distilbert-base-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_ltr/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+# BERT large
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name bert-large-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 64 \
+--gradient_accumulation_steps 2 \
+--per_device_eval_batch_size 64 \
+--output_dir checkpoints/bert_large_rtl/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name bert-large-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 64 \
+--gradient_accumulation_steps 2 \
+--per_device_eval_batch_size 64 \
+--output_dir checkpoints/bert_large_ltr/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
"""
import argparse
@@ -49,8 +140,9 @@ def parse_args():
# Model
parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
help="Whether to train a left-to-right or right-to-left LM.")
- parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased",
- help="Checkpoint to initialize weights from.") # TODO: option for training from scratch w/ conf
+ parser.add_argument("--model_name", type=str,
+ help="Name of tokenizer to load. If not training from scratch, "
+ "will also load model weights.")
# Data
parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext",
@@ -67,6 +159,7 @@ def parse_args():
)
# Training
+ parser.add_argument("--train_from_scratch", action="store_true")
parser.add_argument("--output_dir", type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--per_device_train_batch_size", type=int, default=8)
@@ -81,7 +174,9 @@ def parse_args():
parser.add_argument("--eval_steps", type=int, default=20000,
help="Number of update steps between two logs.")
parser.add_argument("--dataloader_num_workers", type=int, default=8)
- return parser.parse_args()
+
+ args = parser.parse_args()
+ return args
def main():
@@ -89,8 +184,14 @@ def main():
accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir)
# Will `add_attn_hooks` to `model` later
- model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name_or_path, attn_implementation="sdpa")
- tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)
+
+ # Load model weights in both cases, but re-initialize if training from scratch
+ model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
+ if args.train_from_scratch:
+ model.apply(model._initialize_weights)
+ model.tie_weights() # probably not applicable
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
# Data
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
@@ -119,13 +220,17 @@ def main():
model.train()
optimizer.zero_grad()
- accelerator.init_trackers(project_name="NLP-Class-Project", config=vars(args),
- init_kwargs={"wandb": {"entity": "frostbyte"}})
+ wandb.require("core")
+ accelerator.init_trackers(
+ project_name="NLP-Class-Project",
+ config=vars(args) | {"model_parameters": sum(p.numel() for p in model.parameters())},
+ init_kwargs={"wandb": {"entity": "frostbyte"}}
+ )
global_step = 0 # unaccumulated steps
past_losses = []
for epoch in tqdm(range(args.num_train_epochs), position=0, leave=True, desc="Epoch"):
- for batch in tqdm(train_loader, position=1, leave=False, desc="Train Iteration"):
+ for step, batch in enumerate(tqdm(train_loader, position=1, leave=False, desc="Train Iteration")):
with accelerator.accumulate(model):
labels = batch.pop("labels")
outputs = model(**batch)
@@ -137,7 +242,7 @@ def main():
optimizer.zero_grad()
past_losses.append(loss.item())
- if (global_step + 1) % args.logging_steps == 1:
+ if (global_step + 1) % args.logging_steps == 0:
avg_train_loss = torch.tensor(past_losses).mean().item() # Assuming 1 GPU
accelerator.log({
"train_loss": avg_train_loss,
@@ -163,7 +268,8 @@ def main():
log_kwargs={"wandb": {"commit": False}})
model.train()
- global_step += 1
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or step == (len(train_loader) - 1):
+ global_step += 1
model.save_pretrained(os.path.join(args.output_dir, f"epoch_{epoch}_checkpt"))
diff --git a/utils.py b/utils.py
index a306987..4c90023 100644
--- a/utils.py
+++ b/utils.py
@@ -4,7 +4,13 @@ import torch
import torch.nn as nn
import transformers
from datasets import DatasetDict
-from transformers import PreTrainedTokenizer
+from transformers.models.bert.modeling_bert import BERT_SELF_ATTENTION_CLASSES
+from transformers.models.distilbert.modeling_distilbert import DISTILBERT_ATTENTION_CLASSES
+
+
+BERT_ATTENTIONS = tuple(BERT_SELF_ATTENTION_CLASSES.values())
+DISTILBERT_ATTENTIONS = tuple(DISTILBERT_ATTENTION_CLASSES.values())
+IMPLEMENTED_ATTENTIONS = tuple(BERT_ATTENTIONS + DISTILBERT_ATTENTIONS)
def ltr_mask(seq_len: int) -> torch.Tensor:
@@ -16,7 +22,7 @@ def rtl_mask(seq_len: int) -> torch.Tensor:
return ltr_mask(seq_len).T
-def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None:
+def add_attn_hooks(model: transformers.PreTrainedModel, model_direction: str) -> None:
"""
Forces bidirectional `model` into a unidirectional one based on `model_direction`.
Adds hooks to `model`'s self-attention blocks, in-place.
@@ -29,29 +35,48 @@ def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None:
mask_func = ltr_mask if model_direction.lower() == "ltr" else rtl_mask
model.register_buffer("attention_mask", mask_func(model.config.max_position_embeddings).to(model.device))
- def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):
+ def get_attention_mask(seq_len: int) -> torch.Tensor:
"""
- Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515
- so no `kwargs` and `attention_mask` is second positional arg.
-
- Uses nonlocal `model.attention_mask` to save memory.
+ Returns `model.attention_mask` if `seq_len` is the max length, generate new attention mask otherwise.
"""
- assert not kwargs
-
- args = list(args)
- seq_len = args[0].size(1)
# During training, we should always be padding to max length, so we can always use `model.attention_mask`.
if seq_len != model.config.max_position_embeddings:
assert not torch.is_grad_enabled()
- attention_mask = ltr_mask(seq_len).to(model.device)
+ return ltr_mask(seq_len).to(model.device) # TODO: should this be mask_func?
+ # TODO: should we just have a different function to "prepare" model for inference?
+ else:
+ return model.attention_mask
+
+ def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):
+ """
+ Uses nonlocal `model.attention_mask` to save memory.
+ """
+ if isinstance(attn_module, BERT_ATTENTIONS):
+ """
+ Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515
+ so no `kwargs` and `attention_mask` is second positional arg.
+ """
+ assert not kwargs
+
+ args = list(args)
+ seq_len = args[0].size(1)
+ args[1] = get_attention_mask(seq_len)
+ args = tuple(args)
+ elif isinstance(attn_module, DISTILBERT_ATTENTIONS):
+ """
+ Assuming https://github.com/huggingface/transformers/blob/33eef992503689ba1af98090e26d3e98865b2a9b/src/transformers/models/distilbert/modeling_distilbert.py#L481
+ so "mask" in `kwargs`.
+ """
+ assert not args and "mask" in kwargs and "query" in kwargs, f"{args=} {kwargs=}"
+ seq_len = kwargs["query"].size(1)
+ kwargs["mask"] = get_attention_mask(seq_len)
else:
- attention_mask = model.attention_mask
+ raise NotImplementedError(f"{attn_module=}")
- args[1] = attention_mask
- return tuple(args), kwargs
+ return args, kwargs
for name, module in model.named_modules():
- if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
+ if isinstance(module, IMPLEMENTED_ATTENTIONS):
module._forward_pre_hooks.clear() # in case we run multiple times
module.register_forward_pre_hook(attn_hook, with_kwargs=True)
@@ -75,7 +100,11 @@ def causal_loss_wrapper(model_direction: str):
return loss_fn
-def preprocess_datasets(raw_datasets: DatasetDict, tokenizer: PreTrainedTokenizer, block_size: int) -> DatasetDict:
+def preprocess_datasets(
+ raw_datasets: DatasetDict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ block_size: int
+) -> DatasetDict:
"""
Preprocess datasets.
Closely follows https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L449