aboutsummaryrefslogtreecommitdiff
path: root/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils.py')
-rw-r--r--utils.py107
1 files changed, 96 insertions, 11 deletions
diff --git a/utils.py b/utils.py
index 7f922fe..a306987 100644
--- a/utils.py
+++ b/utils.py
@@ -1,45 +1,130 @@
+from itertools import chain
+
import torch
import torch.nn as nn
import transformers
+from datasets import DatasetDict
+from transformers import PreTrainedTokenizer
def ltr_mask(seq_len: int) -> torch.Tensor:
mask = torch.ones((seq_len, seq_len), dtype=torch.bool)
- return torch.tril(mask, diagonal=-1)
+ return torch.tril(mask)
def rtl_mask(seq_len: int) -> torch.Tensor:
return ltr_mask(seq_len).T
-def add_attn_hooks(model: transformers.BertModel, text_direction: str) -> None:
+def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None:
"""
- Forces bidirectional `model` into a unidirectional one based on `direction`.
+ Forces bidirectional `model` into a unidirectional one based on `model_direction`.
Adds hooks to `model`'s self-attention blocks, in-place.
Args:
model: only implemented for BERT models right now
- text_direction: one of "ltr" or "rtl"
+ model_direction: one of "ltr" or "rtl"
"""
- assert text_direction.lower() in ("ltr", "rtl")
- mask_func = ltr_mask if text_direction.lower() == "ltr" else rtl_mask
- model.register_buffer("attn_mask", mask_func(model.config.max_position_embeddings).to(model.device))
+ assert model_direction.lower() in ("ltr", "rtl")
+ mask_func = ltr_mask if model_direction.lower() == "ltr" else rtl_mask
+ model.register_buffer("attention_mask", mask_func(model.config.max_position_embeddings).to(model.device))
def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):
"""
Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515
so no `kwargs` and `attention_mask` is second positional arg.
- Uses nonlocal `model.attn_mask` to save memory.
+ Uses nonlocal `model.attention_mask` to save memory.
"""
assert not kwargs
args = list(args)
- assert args[1].size()[-2:] == model.attn_mask.size(), f"{args[1].size()=} {model.attn_mask.size()=}"
- args[1] = model.attn_mask
+ seq_len = args[0].size(1)
+ # During training, we should always be padding to max length, so we can always use `model.attention_mask`.
+ if seq_len != model.config.max_position_embeddings:
+ assert not torch.is_grad_enabled()
+ attention_mask = ltr_mask(seq_len).to(model.device)
+ else:
+ attention_mask = model.attention_mask
+
+ args[1] = attention_mask
return tuple(args), kwargs
for name, module in model.named_modules():
if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
- module._forward_hooks.clear() # in case we run multiple times
+ module._forward_pre_hooks.clear() # in case we run multiple times
module.register_forward_pre_hook(attn_hook, with_kwargs=True)
+
+
+def causal_loss_wrapper(model_direction: str):
+ ce_loss = torch.nn.CrossEntropyLoss()
+
+ def loss_fn(logits, labels):
+ if model_direction.lower() == "ltr":
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ elif model_direction.lower() == "rtl":
+ shift_logits = logits[..., 1:, :].contiguous()
+ shift_labels = labels[..., :-1].contiguous()
+ else:
+ raise NotImplementedError(f"{model_direction=}")
+
+ # Flatten the tokens
+ return ce_loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+ return loss_fn
+
+
+def preprocess_datasets(raw_datasets: DatasetDict, tokenizer: PreTrainedTokenizer, block_size: int) -> DatasetDict:
+ """
+ Preprocess datasets.
+ Closely follows https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L449
+
+ `raw_datasets` is the output of `load_datasets()`, expected to always have a "train" split
+ """
+ column_names = list(raw_datasets["train"].features)
+ text_column_name = "text" if "text" in column_names else column_names[0]
+ tokenized_datasets = raw_datasets.map(
+ lambda examples: tokenizer(examples[text_column_name]),
+ batched=True,
+ num_proc=8,
+ remove_columns=column_names,
+ desc="Running tokenizer on dataset",
+ )
+
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+ def group_texts(examples):
+ # Concatenate all texts.
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
+ # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
+ # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+ total_length = (total_length // block_size) * block_size
+ # Split by chunks of max_len.
+ result = {
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ result["labels"] = result["input_ids"].copy()
+ return result
+
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+ # to preprocess.
+ #
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+ # https://huggingface.co/docs/datasets/process#map
+
+ # # with training_args.main_process_first(desc="grouping texts together"):
+ return tokenized_datasets.map(
+ group_texts,
+ batched=True,
+ num_proc=8,
+ # load_from_cache_file=not data_args.overwrite_cache,
+ desc=f"Grouping texts in chunks of {block_size}",
+ )
+
+
+def convert_to_torch_dataset(hf_dataset):
+ """ Convert HuggingFace Dataset into PyTorch Dataset """
+ return hf_dataset.with_format("torch")