diff options
-rw-r--r-- | finetune_bert.py | 126 | ||||
-rw-r--r-- | utils.py | 63 |
2 files changed, 162 insertions, 27 deletions
diff --git a/finetune_bert.py b/finetune_bert.py index e4dc5aa..fba4d1d 100644 --- a/finetune_bert.py +++ b/finetune_bert.py @@ -1,6 +1,9 @@ """ + +# BERT base accelerate launch --mixed_precision bf16 finetune_bert.py \ --model_direction rtl \ +--model_name bert-base-uncased \ --warmup_steps 500 \ --learning_rate 5e-5 \ --per_device_train_batch_size 128 \ @@ -13,6 +16,7 @@ accelerate launch --mixed_precision bf16 finetune_bert.py \ accelerate launch --mixed_precision bf16 finetune_bert.py \ --model_direction ltr \ +--model_name bert-base-uncased \ --warmup_steps 500 \ --learning_rate 5e-5 \ --per_device_train_batch_size 128 \ @@ -22,6 +26,93 @@ accelerate launch --mixed_precision bf16 finetune_bert.py \ --block_size 128 \ --num_train_epochs 4 \ --weight_decay 1e-4 + +# DistilBERT scratch +accelerate launch --mixed_precision bf16 finetune_bert.py \ +--model_direction rtl \ +--model_name distilbert/distilbert-base-uncased \ +--train_from_scratch \ +--warmup_steps 500 \ +--learning_rate 5e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/distilbert_base_rtl_scratch/ \ +--eval_steps 899 \ +--block_size 128 \ +--num_train_epochs 4 \ +--weight_decay 1e-4 + +accelerate launch --mixed_precision bf16 finetune_bert.py \ +--model_direction ltr \ +--model_name distilbert/distilbert-base-uncased \ +--train_from_scratch \ +--warmup_steps 500 \ +--learning_rate 5e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/distilbert_base_ltr_scratch/ \ +--eval_steps 899 \ +--block_size 128 \ +--num_train_epochs 4 \ +--weight_decay 1e-4 + +# DistilBERT base +accelerate launch --mixed_precision bf16 finetune_bert.py \ +--model_direction rtl \ +--model_name distilbert/distilbert-base-uncased \ +--warmup_steps 500 \ +--learning_rate 5e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/distilbert_base_rtl/ \ +--eval_steps 899 \ +--block_size 128 \ +--num_train_epochs 4 \ +--weight_decay 1e-4 + + +accelerate launch --mixed_precision bf16 finetune_bert.py \ +--model_direction ltr \ +--model_name distilbert/distilbert-base-uncased \ +--warmup_steps 500 \ +--learning_rate 5e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/distilbert_base_ltr/ \ +--eval_steps 899 \ +--block_size 128 \ +--num_train_epochs 4 \ +--weight_decay 1e-4 + +# BERT large +accelerate launch --mixed_precision bf16 finetune_bert.py \ +--model_direction rtl \ +--model_name bert-large-uncased \ +--warmup_steps 500 \ +--learning_rate 5e-5 \ +--per_device_train_batch_size 64 \ +--gradient_accumulation_steps 2 \ +--per_device_eval_batch_size 64 \ +--output_dir checkpoints/bert_large_rtl/ \ +--eval_steps 899 \ +--block_size 128 \ +--num_train_epochs 4 \ +--weight_decay 1e-4 + + +accelerate launch --mixed_precision bf16 finetune_bert.py \ +--model_direction ltr \ +--model_name bert-large-uncased \ +--warmup_steps 500 \ +--learning_rate 5e-5 \ +--per_device_train_batch_size 64 \ +--gradient_accumulation_steps 2 \ +--per_device_eval_batch_size 64 \ +--output_dir checkpoints/bert_large_ltr/ \ +--eval_steps 899 \ +--block_size 128 \ +--num_train_epochs 4 \ +--weight_decay 1e-4 """ import argparse @@ -49,8 +140,9 @@ def parse_args(): # Model parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"], help="Whether to train a left-to-right or right-to-left LM.") - parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased", - help="Checkpoint to initialize weights from.") # TODO: option for training from scratch w/ conf + parser.add_argument("--model_name", type=str, + help="Name of tokenizer to load. If not training from scratch, " + "will also load model weights.") # Data parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext", @@ -67,6 +159,7 @@ def parse_args(): ) # Training + parser.add_argument("--train_from_scratch", action="store_true") parser.add_argument("--output_dir", type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.") parser.add_argument("--per_device_train_batch_size", type=int, default=8) @@ -81,7 +174,9 @@ def parse_args(): parser.add_argument("--eval_steps", type=int, default=20000, help="Number of update steps between two logs.") parser.add_argument("--dataloader_num_workers", type=int, default=8) - return parser.parse_args() + + args = parser.parse_args() + return args def main(): @@ -89,8 +184,14 @@ def main(): accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir) # Will `add_attn_hooks` to `model` later - model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name_or_path, attn_implementation="sdpa") - tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path) + + # Load model weights in both cases, but re-initialize if training from scratch + model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa") + if args.train_from_scratch: + model.apply(model._initialize_weights) + model.tie_weights() # probably not applicable + + tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name) # Data raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) @@ -119,13 +220,17 @@ def main(): model.train() optimizer.zero_grad() - accelerator.init_trackers(project_name="NLP-Class-Project", config=vars(args), - init_kwargs={"wandb": {"entity": "frostbyte"}}) + wandb.require("core") + accelerator.init_trackers( + project_name="NLP-Class-Project", + config=vars(args) | {"model_parameters": sum(p.numel() for p in model.parameters())}, + init_kwargs={"wandb": {"entity": "frostbyte"}} + ) global_step = 0 # unaccumulated steps past_losses = [] for epoch in tqdm(range(args.num_train_epochs), position=0, leave=True, desc="Epoch"): - for batch in tqdm(train_loader, position=1, leave=False, desc="Train Iteration"): + for step, batch in enumerate(tqdm(train_loader, position=1, leave=False, desc="Train Iteration")): with accelerator.accumulate(model): labels = batch.pop("labels") outputs = model(**batch) @@ -137,7 +242,7 @@ def main(): optimizer.zero_grad() past_losses.append(loss.item()) - if (global_step + 1) % args.logging_steps == 1: + if (global_step + 1) % args.logging_steps == 0: avg_train_loss = torch.tensor(past_losses).mean().item() # Assuming 1 GPU accelerator.log({ "train_loss": avg_train_loss, @@ -163,7 +268,8 @@ def main(): log_kwargs={"wandb": {"commit": False}}) model.train() - global_step += 1 + if ((step + 1) % args.gradient_accumulation_steps == 0) or step == (len(train_loader) - 1): + global_step += 1 model.save_pretrained(os.path.join(args.output_dir, f"epoch_{epoch}_checkpt")) @@ -4,7 +4,13 @@ import torch import torch.nn as nn import transformers from datasets import DatasetDict -from transformers import PreTrainedTokenizer +from transformers.models.bert.modeling_bert import BERT_SELF_ATTENTION_CLASSES +from transformers.models.distilbert.modeling_distilbert import DISTILBERT_ATTENTION_CLASSES + + +BERT_ATTENTIONS = tuple(BERT_SELF_ATTENTION_CLASSES.values()) +DISTILBERT_ATTENTIONS = tuple(DISTILBERT_ATTENTION_CLASSES.values()) +IMPLEMENTED_ATTENTIONS = tuple(BERT_ATTENTIONS + DISTILBERT_ATTENTIONS) def ltr_mask(seq_len: int) -> torch.Tensor: @@ -16,7 +22,7 @@ def rtl_mask(seq_len: int) -> torch.Tensor: return ltr_mask(seq_len).T -def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None: +def add_attn_hooks(model: transformers.PreTrainedModel, model_direction: str) -> None: """ Forces bidirectional `model` into a unidirectional one based on `model_direction`. Adds hooks to `model`'s self-attention blocks, in-place. @@ -29,29 +35,48 @@ def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None: mask_func = ltr_mask if model_direction.lower() == "ltr" else rtl_mask model.register_buffer("attention_mask", mask_func(model.config.max_position_embeddings).to(model.device)) - def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict): + def get_attention_mask(seq_len: int) -> torch.Tensor: """ - Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515 - so no `kwargs` and `attention_mask` is second positional arg. - - Uses nonlocal `model.attention_mask` to save memory. + Returns `model.attention_mask` if `seq_len` is the max length, generate new attention mask otherwise. """ - assert not kwargs - - args = list(args) - seq_len = args[0].size(1) # During training, we should always be padding to max length, so we can always use `model.attention_mask`. if seq_len != model.config.max_position_embeddings: assert not torch.is_grad_enabled() - attention_mask = ltr_mask(seq_len).to(model.device) + return ltr_mask(seq_len).to(model.device) # TODO: should this be mask_func? + # TODO: should we just have a different function to "prepare" model for inference? + else: + return model.attention_mask + + def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict): + """ + Uses nonlocal `model.attention_mask` to save memory. + """ + if isinstance(attn_module, BERT_ATTENTIONS): + """ + Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515 + so no `kwargs` and `attention_mask` is second positional arg. + """ + assert not kwargs + + args = list(args) + seq_len = args[0].size(1) + args[1] = get_attention_mask(seq_len) + args = tuple(args) + elif isinstance(attn_module, DISTILBERT_ATTENTIONS): + """ + Assuming https://github.com/huggingface/transformers/blob/33eef992503689ba1af98090e26d3e98865b2a9b/src/transformers/models/distilbert/modeling_distilbert.py#L481 + so "mask" in `kwargs`. + """ + assert not args and "mask" in kwargs and "query" in kwargs, f"{args=} {kwargs=}" + seq_len = kwargs["query"].size(1) + kwargs["mask"] = get_attention_mask(seq_len) else: - attention_mask = model.attention_mask + raise NotImplementedError(f"{attn_module=}") - args[1] = attention_mask - return tuple(args), kwargs + return args, kwargs for name, module in model.named_modules(): - if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention): + if isinstance(module, IMPLEMENTED_ATTENTIONS): module._forward_pre_hooks.clear() # in case we run multiple times module.register_forward_pre_hook(attn_hook, with_kwargs=True) @@ -75,7 +100,11 @@ def causal_loss_wrapper(model_direction: str): return loss_fn -def preprocess_datasets(raw_datasets: DatasetDict, tokenizer: PreTrainedTokenizer, block_size: int) -> DatasetDict: +def preprocess_datasets( + raw_datasets: DatasetDict, + tokenizer: transformers.PreTrainedTokenizer, + block_size: int +) -> DatasetDict: """ Preprocess datasets. Closely follows https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L449 |