diff options
Diffstat (limited to 'utils.py')
-rw-r--r-- | utils.py | 107 |
1 files changed, 96 insertions, 11 deletions
@@ -1,45 +1,130 @@ +from itertools import chain + import torch import torch.nn as nn import transformers +from datasets import DatasetDict +from transformers import PreTrainedTokenizer def ltr_mask(seq_len: int) -> torch.Tensor: mask = torch.ones((seq_len, seq_len), dtype=torch.bool) - return torch.tril(mask, diagonal=-1) + return torch.tril(mask) def rtl_mask(seq_len: int) -> torch.Tensor: return ltr_mask(seq_len).T -def add_attn_hooks(model: transformers.BertModel, text_direction: str) -> None: +def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None: """ - Forces bidirectional `model` into a unidirectional one based on `direction`. + Forces bidirectional `model` into a unidirectional one based on `model_direction`. Adds hooks to `model`'s self-attention blocks, in-place. Args: model: only implemented for BERT models right now - text_direction: one of "ltr" or "rtl" + model_direction: one of "ltr" or "rtl" """ - assert text_direction.lower() in ("ltr", "rtl") - mask_func = ltr_mask if text_direction.lower() == "ltr" else rtl_mask - model.register_buffer("attn_mask", mask_func(model.config.max_position_embeddings).to(model.device)) + assert model_direction.lower() in ("ltr", "rtl") + mask_func = ltr_mask if model_direction.lower() == "ltr" else rtl_mask + model.register_buffer("attention_mask", mask_func(model.config.max_position_embeddings).to(model.device)) def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict): """ Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515 so no `kwargs` and `attention_mask` is second positional arg. - Uses nonlocal `model.attn_mask` to save memory. + Uses nonlocal `model.attention_mask` to save memory. """ assert not kwargs args = list(args) - assert args[1].size()[-2:] == model.attn_mask.size(), f"{args[1].size()=} {model.attn_mask.size()=}" - args[1] = model.attn_mask + seq_len = args[0].size(1) + # During training, we should always be padding to max length, so we can always use `model.attention_mask`. + if seq_len != model.config.max_position_embeddings: + assert not torch.is_grad_enabled() + attention_mask = ltr_mask(seq_len).to(model.device) + else: + attention_mask = model.attention_mask + + args[1] = attention_mask return tuple(args), kwargs for name, module in model.named_modules(): if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention): - module._forward_hooks.clear() # in case we run multiple times + module._forward_pre_hooks.clear() # in case we run multiple times module.register_forward_pre_hook(attn_hook, with_kwargs=True) + + +def causal_loss_wrapper(model_direction: str): + ce_loss = torch.nn.CrossEntropyLoss() + + def loss_fn(logits, labels): + if model_direction.lower() == "ltr": + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + elif model_direction.lower() == "rtl": + shift_logits = logits[..., 1:, :].contiguous() + shift_labels = labels[..., :-1].contiguous() + else: + raise NotImplementedError(f"{model_direction=}") + + # Flatten the tokens + return ce_loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + return loss_fn + + +def preprocess_datasets(raw_datasets: DatasetDict, tokenizer: PreTrainedTokenizer, block_size: int) -> DatasetDict: + """ + Preprocess datasets. + Closely follows https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L449 + + `raw_datasets` is the output of `load_datasets()`, expected to always have a "train" split + """ + column_names = list(raw_datasets["train"].features) + text_column_name = "text" if "text" in column_names else column_names[0] + tokenized_datasets = raw_datasets.map( + lambda examples: tokenizer(examples[text_column_name]), + batched=True, + num_proc=8, + remove_columns=column_names, + desc="Running tokenizer on dataset", + ) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i: i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/process#map + + # # with training_args.main_process_first(desc="grouping texts together"): + return tokenized_datasets.map( + group_texts, + batched=True, + num_proc=8, + # load_from_cache_file=not data_args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + +def convert_to_torch_dataset(hf_dataset): + """ Convert HuggingFace Dataset into PyTorch Dataset """ + return hf_dataset.with_format("torch") |