diff options
Diffstat (limited to 'finetune_QA.py')
-rw-r--r-- | finetune_QA.py | 304 |
1 files changed, 304 insertions, 0 deletions
diff --git a/finetune_QA.py b/finetune_QA.py new file mode 100644 index 0000000..e5b8ef7 --- /dev/null +++ b/finetune_QA.py @@ -0,0 +1,304 @@ +""" +accelerate launch --mixed_precision bf16 finetune_QA.py \ +--model_direction rtl \ +--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_rtl/epoch_3_checkpt \ +--tokenizer_name distilbert/distilbert-base-uncased \ +--warmup_steps 100 \ +--learning_rate 1e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/qa_distilbert_base_rtl/ \ +--eval_steps 38 \ +--block_size 128 \ +--num_train_epochs 50 \ +--weight_decay 1e-4 + + +accelerate launch --mixed_precision bf16 finetune_QA.py \ +--model_direction ltr \ +--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \ +--tokenizer_name distilbert/distilbert-base-uncased \ +--warmup_steps 100 \ +--learning_rate 1e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/qa_distilbert_base_ltr/ \ +--eval_steps 38 \ +--block_size 128 \ +--num_train_epochs 50 \ +--weight_decay 1e-4 + +accelerate launch --mixed_precision bf16 finetune_QA.py \ +--model_direction ltr \ +--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \ +--tokenizer_name distilbert/distilbert-base-uncased \ +--warmup_steps 100 \ +--learning_rate 1e-5 \ +--per_device_train_batch_size 128 \ +--per_device_eval_batch_size 128 \ +--output_dir checkpoints/qa_distilbert_base_ltr_overfit/ \ +--eval_steps 50 \ +--block_size 128 \ +--num_train_epochs 1000 \ +--weight_decay 0 +""" + + + +import argparse +import math +import os +from collections import defaultdict + +import accelerate +import torch +import transformers +import wandb +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader +from transformers.data.data_collator import default_data_collator +from tqdm.auto import tqdm + +from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper + +#### HERE WE do the dataset stuff +class DatasetAQ(Dataset): + def __init__(self, qa_pairs, text_direction, tokenizer): + self.qa_pairs = qa_pairs + self.text_direction = text_direction + self.tokenizer = tokenizer + + def __getitem__(self, idx): + question, answer = self.qa_pairs[idx] + sentence = torch.cat([question, answer], dim=0) if self.text_direction.lower() == "rtl" else torch.cat([answer, question], dim=0) + + # TODO: length + num_to_pad = self.tokenizer.model_max_length - sentence.size(0) + assert num_to_pad >= 0, (sentence.size(), self.tokenizer.model_max_length) + + if num_to_pad > 0: + pad_tokens = torch.full((num_to_pad,), self.tokenizer.pad_token_id, dtype=sentence.dtype) + pad_labels = torch.full((num_to_pad,), -100, dtype=sentence.dtype) + + if self.text_direction.lower() == "rtl": + input_ids = torch.cat([pad_tokens, sentence], dim=0) + labels = torch.cat([pad_labels, sentence], dim=0) + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + attention_mask[:num_to_pad] = 0 + else: + input_ids = torch.cat([sentence, pad_tokens], dim=0) + labels = torch.cat([sentence, pad_labels], dim=0) + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + attention_mask[-num_to_pad:] = 0 + + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + } + + def __len__(self): + return len(self.qa_pairs) + +#### + + + +def parse_args(): + """ + Re-using HuggingFace arguments when possible (most of the help strings are directly copied). + https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L75 + """ + parser = argparse.ArgumentParser() + + # Model + parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"], + help="Whether to train a left-to-right or right-to-left LM.") + parser.add_argument("--checkpoint_path", type=str, + help="Path to load model weights from.") + + # Data + parser.add_argument("--tokenizer_name", type=str, + help="Name of tokenizer to load.") + parser.add_argument("--dataset_name", type=str, default="truthfulqa/truthful_qa", + help="The name of the dataset to use (via the datasets library).") + parser.add_argument("--dataset_config_name", type=str, default="generation", + help="The configuration name of the dataset to use (via the datasets library).") + # TODO: block_size, train on shorter seqs? + parser.add_argument( + "--block_size", + type=int, + help="Optional input sequence length after tokenization. " + "The training dataset will be truncated in block of this size for training. " + "Default to the model max input length for single sentence inputs (take into account special tokens)." + ) + + # Training + parser.add_argument("--train_from_scratch", action="store_true") + parser.add_argument("--output_dir", type=str, required=True, + help="The output directory where the model predictions and checkpoints will be written.") + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--per_device_eval_batch_size", type=int, default=16) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument("--learning_rate", type=float, required=True) + parser.add_argument("--warmup_steps", type=int, default=0) + parser.add_argument("--scheduler", type=str, default="cosine") + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument("--logging_steps", type=int, default=1, + help="Number of update steps between two logs.") + parser.add_argument("--eval_steps", type=int, default=20000, + help="Number of update steps between two logs.") + parser.add_argument("--dataloader_num_workers", type=int, default=8) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + transformers.set_seed(42) + + accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir) + # Will `add_attn_hooks` to `model` later + + # Load model weights in both cases, but re-initialize if training from scratch + model = transformers.AutoModelForMaskedLM.from_pretrained(args.checkpoint_path, attn_implementation="sdpa", ignore_mismatched_sizes=True) + if args.train_from_scratch: + model.apply(model._init_weights) + model.tie_weights() # probably not applicable + + tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_name) + + # Data + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + block_size = args.block_size if args.block_size is not None else model.config.max_position_embeddings + model.config.max_position_embeddings = block_size + tokenizer.model_max_length = block_size + + # QA-specific code + all_data = raw_datasets["validation"] + transformers.set_seed(42) + train_val_split = all_data.train_test_split(test_size=0.2, shuffle=True) + val_test_split = train_val_split['test'].train_test_split(test_size=0.5, shuffle=False) + train_dataset = train_val_split['train'] + val_dataset = val_test_split['train'] + test_dataset = val_test_split['test'] + + qa_pairs = defaultdict(list) + for data_name, dataset in zip(["test","train","val"], [train_dataset, test_dataset, val_dataset]): + for row in dataset: + tokenized_question = tokenizer("Question: "+ row["question"], return_tensors="pt")["input_ids"].squeeze(0) + for ans_type in ["correct_answers", "incorrect_answers"]: + for answer in row[ans_type]: + # the [:, 1:] thing is to remove CLS token + qa_pairs[data_name].append((tokenized_question, tokenizer(f"Answer: {answer}", return_tensors="pt")["input_ids"].squeeze(0)[1:])) + + train_dataset = DatasetAQ(qa_pairs["train"], args.model_direction, tokenizer) + test_dataset = DatasetAQ(qa_pairs["test"], args.model_direction, tokenizer) + val_dataset = DatasetAQ(qa_pairs["val"], args.model_direction, tokenizer) + + train_loader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=args.per_device_eval_batch_size, shuffle=False) + val_loader = DataLoader(val_dataset, batch_size=args.per_device_eval_batch_size) + + model, train_loader, test_loader, val_loader = accelerator.prepare(model, train_loader, test_loader, val_loader) + + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) + lr_scheduler = transformers.get_scheduler( + name=transformers.SchedulerType.COSINE, + optimizer=optimizer, + num_warmup_steps=args.warmup_steps * accelerator.num_processes, + # num_training_steps=args.num_train_epochs * math.ceil(len(train_loader) / args.gradient_accumulation_steps), + num_training_steps=args.num_train_epochs * len(train_loader), + ) + + lr_scheduler = accelerator.prepare(lr_scheduler) # testing if this fixes learning rate + + loss_fn = causal_loss_wrapper(args.model_direction) + + add_attn_hooks(model, args.model_direction) + model.train() + optimizer.zero_grad() + + wandb.require("core") + accelerator.init_trackers( + project_name="NLP-Class-Project", + config=vars(args) | {"model_parameters": sum(p.numel() for p in model.parameters())}, + init_kwargs={"wandb": {"entity": "frostbyte"}} + ) + + global_step = 0 # unaccumulated steps + past_losses = [] + best_val_loss = float("inf") + best_checkpt_path = os.path.join(args.output_dir, f"best_checkpt") + + for epoch in tqdm(range(args.num_train_epochs), position=0, leave=True, desc="Epoch"): + for step, batch in enumerate(tqdm(train_loader, position=1, leave=False, desc="Train Iteration")): + with accelerator.accumulate(model): + labels = batch.pop("labels") + outputs = model(**batch) + loss = loss_fn(outputs.logits, labels) + accelerator.backward(loss) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + past_losses.append(loss.item()) + if (global_step + 1) % args.logging_steps == 0: + avg_train_loss = torch.tensor(past_losses).mean().item() # Assuming 1 GPU + accelerator.log({ + "train_loss": avg_train_loss, + "learning_rate": lr_scheduler.get_last_lr()[0], + }) + past_losses.clear() + + if (global_step + 1) % args.eval_steps == 0: + val_loss_sum = val_examples = 0 + model.eval() + for val_batch in tqdm(val_loader, position=2, leave=False, desc="Val Iteration"): + labels = val_batch.pop("labels") + with torch.no_grad(): + outputs = model(**val_batch) + + loss = loss_fn(outputs.logits, labels) + + batch_size = labels.size(0) + val_loss_sum += loss.item() * batch_size + val_examples += batch_size + + val_loss = val_loss_sum / val_examples + if val_loss < best_val_loss: + best_val_loss = val_loss + model.save_pretrained(best_checkpt_path) + + accelerator.log({"val_loss": val_loss_sum / val_examples}, + log_kwargs={"wandb": {"commit": False}}) + model.train() + + if ((step + 1) % args.gradient_accumulation_steps == 0) or step == (len(train_loader) - 1): + global_step += 1 + + # model.save_pretrained(os.path.join(args.output_dir, f"epoch_{epoch}_checkpt")) + + # testing + model.from_pretrained(best_checkpt_path) + model.eval() + with torch.no_grad(): + test_loss_sum = test_examples = 0 + for test_batch in tqdm(test_loader): + labels = test_batch.pop("labels") + outputs = model(**test_batch) + + loss = loss_fn(outputs.logits, labels) + + batch_size = labels.size(0) + test_loss_sum += loss.item() * batch_size + test_examples += batch_size + + accelerator.log({"test_loss": test_loss_sum / test_examples}) + + +if __name__ == "__main__": + main() |