"""
accelerate launch --mixed_precision bf16 finetune_bert.py --model_direction ltr --learning_rate 5e-5 --output_dir checkpoints/test
"""

import argparse
import math

import accelerate
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import get_scheduler

from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper


def parse_args():
    """
    Re-using HuggingFace arguments when possible (most of the help strings are directly copied).
    https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L75
    """
    parser = argparse.ArgumentParser()

    # Model
    parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
                        help="Whether to train a left-to-right or right-to-left LM.")
    parser.add_argument("--model_name_or_path", type=str, default="bert-base-uncased",
                        help="Checkpoint to initialize weights from.")  # TODO: option for training from scratch w/ conf

    # Data
    parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext",
                        help="The name of the dataset to use (via the datasets library).")
    parser.add_argument("--dataset_config_name", type=str, default="wikitext-103-v1",
                        help="The configuration name of the dataset to use (via the datasets library).")
    # TODO: block_size, train on shorter seqs?
    parser.add_argument(
        "--block_size",
        type=int,
        help="Optional input sequence length after tokenization. "
             "The training dataset will be truncated in block of this size for training. "
             "Default to the model max input length for single sentence inputs (take into account special tokens)."
    )

    # Training
    parser.add_argument("--output_dir", type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=16)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, required=True)
    parser.add_argument("--warmup_steps", type=int, default=0)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--logging_steps", type=int, default=1,
                        help="Number of update steps between two logs.")
    parser.add_argument("--eval_steps", type=int, default=500,
                        help="Number of update steps between two logs.")
    parser.add_argument("--dataloader_num_workers", type=int, default=8)
    return parser.parse_args()


def main():
    args = parse_args()

    accelerator = accelerate.Accelerator()
    model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name_or_path, attn_implementation="sdpa")
    add_attn_hooks(model, args.model_direction)
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)

    # Data
    raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
    block_size = args.block_size if args.block_size is not None else model.config.max_position_embeddings
    processed_datasets = preprocess_datasets(raw_datasets, tokenizer, block_size)
    for split, hf_dataset in processed_datasets.items():
        processed_datasets[split] = convert_to_torch_dataset(hf_dataset)

    train_loader = DataLoader(processed_datasets["train"], batch_size=args.per_device_train_batch_size,
                              shuffle=True, num_workers=args.dataloader_num_workers)
    val_loader = DataLoader(processed_datasets["validation"], batch_size=args.per_device_eval_batch_size,
                            num_workers=args.dataloader_num_workers)
    # test_loader = DataLoader(processed_datasets["test"], batch_size=args.per_device_eval_batch_size)
    model, train_loader, val_loader = accelerator.prepare(model, train_loader, val_loader)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    lr_scheduler = get_scheduler(
        name=transformers.SchedulerType.COSINE,
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,  #* accelerator.num_processes,
        num_training_steps=args.num_train_epochs * math.ceil(len(train_loader) / args.gradient_accumulation_steps),
    )
    loss_fn = causal_loss_wrapper(args.model_direction)

    model.train()
    optimizer.zero_grad()

    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(tqdm(train_loader)):
            labels = batch.pop("labels")
            outputs = model(**batch)
            loss = loss_fn(outputs.logits, labels)
            # Scale by the accumulation factor so accumulated gradients average, and let
            # Accelerate handle the backward pass (mixed precision, distributed training).
            accelerator.backward(loss / args.gradient_accumulation_steps)

            if (step + 1) % args.logging_steps == 0:
                accelerator.print(f"{loss.item()=}")

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
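            if (step + 1) % args.eval_steps == 0:
                # Sketch (assumption): args.eval_steps is parsed but never used in the
                # original loop, so this periodic validation pass is an illustrative
                # addition, not confirmed behavior. The reported loss/perplexity is the
                # local (per-process) average, without cross-process synchronization.
                model.eval()
                val_loss, val_batches = 0.0, 0
                with torch.no_grad():
                    for val_batch in val_loader:
                        val_labels = val_batch.pop("labels")
                        val_outputs = model(**val_batch)
                        val_loss += loss_fn(val_outputs.logits, val_labels).item()
                        val_batches += 1
                accelerator.print(f"eval_loss={val_loss / val_batches:.4f}, "
                                  f"eval_ppl={math.exp(val_loss / val_batches):.2f}")
                model.train()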

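    # Sketch (assumption): --output_dir is required but never written to in the
    # original script; saving the final model and tokenizer here is illustrative.
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(args.output_dir,
                                    is_main_process=accelerator.is_main_process,
                                    save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(args.output_dir)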

if __name__ == "__main__":
    main()