1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
from itertools import chain
import torch
import torch.nn as nn
import transformers
from datasets import DatasetDict
from transformers import PreTrainedTokenizer
def ltr_mask(seq_len: int) -> torch.Tensor:
    """
    Return the (seq_len, seq_len) boolean mask for a left-to-right model.

    Entry [i, j] is True iff j <= i: each position may attend to itself and to
    positions on its left (lower-triangular, diagonal included).
    """
    all_true = torch.ones(seq_len, seq_len, dtype=torch.bool)
    return all_true.tril()
def rtl_mask(seq_len: int) -> torch.Tensor:
    """
    Return the (seq_len, seq_len) boolean mask for a right-to-left model.

    This is the transpose of `ltr_mask(seq_len)`: entry [i, j] is True iff j >= i,
    so each position may attend to itself and to positions on its right.
    """
    left_to_right = ltr_mask(seq_len)
    return left_to_right.transpose(0, 1)
def add_attn_hooks(model: transformers.BertModel, model_direction: str) -> None:
    """
    Forces bidirectional `model` into a unidirectional one based on `model_direction`.

    Registers a boolean `(max_position_embeddings, max_position_embeddings)` buffer named
    `attention_mask` on `model`. It then adds a forward pre-hook, in-place, to every
    `BertSelfAttention` module in `model`. The hook replaces the attention-mask argument
    with a directional mask. All existing forward pre-hooks on those modules (including
    unrelated ones) are cleared first, so repeated calls do not stack hooks.

    Args:
        model: only implemented for BERT models right now
        model_direction: one of "ltr" or "rtl" (case-insensitive)
    """
    assert model_direction.lower() in ("ltr", "rtl")
    mask_func = ltr_mask if model_direction.lower() == "ltr" else rtl_mask
    model.register_buffer("attention_mask", mask_func(model.config.max_position_embeddings).to(model.device))

    def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):
        """
        Swap the `attention_mask` positional argument for the directional mask.

        Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515
        so no `kwargs` and `attention_mask` is second positional arg.
        Uses nonlocal `model.attention_mask` to save memory.
        """
        assert not kwargs
        args = list(args)
        # args[0] is presumably hidden_states of shape (batch, seq, hidden) -- TODO confirm.
        seq_len = args[0].size(1)
        # During training, we should always be padding to max length, so we can always use `model.attention_mask`.
        if seq_len != model.config.max_position_embeddings:
            assert not torch.is_grad_enabled()
            # Must follow the model's direction: a hard-coded `ltr_mask` here would give
            # RTL models a left-to-right mask whenever inputs are not max length.
            attention_mask = mask_func(seq_len).to(model.device)
        else:
            attention_mask = model.attention_mask
        # NOTE(review): this discards whatever mask the caller passed, including padding,
        # and substitutes a bool mask. Confirm that the installed attention implementation
        # treats a bool mask as "True = may attend", as torch's
        # scaled_dot_product_attention does.
        args[1] = attention_mask
        return tuple(args), kwargs

    for module in model.modules():
        if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
            module._forward_pre_hooks.clear()  # in case we run multiple times
            module.register_forward_pre_hook(attn_hook, with_kwargs=True)
def causal_loss_wrapper(model_direction: str):
    """
    Build a shifted cross-entropy loss for a unidirectional language model.

    For "ltr", the logits at position i are scored against the label at i + 1.
    For "rtl", the logits at position i are scored against the label at i - 1.
    The direction is matched case-insensitively, and only when the returned function
    is called. An unknown direction raises NotImplementedError at that point.

    Returns:
        loss_fn(logits, labels): mean cross-entropy over all shifted positions, with
        logits shaped (..., seq, vocab) and labels shaped (..., seq).
    """
    criterion = torch.nn.CrossEntropyLoss()

    def loss_fn(logits, labels):
        direction = model_direction.lower()
        if direction == "ltr":
            pred, target = logits[..., :-1, :], labels[..., 1:]
        elif direction == "rtl":
            pred, target = logits[..., 1:, :], labels[..., :-1]
        else:
            raise NotImplementedError(f"{model_direction=}")
        vocab_size = pred.contiguous().size(-1)
        # Collapse every leading dim so CrossEntropyLoss sees (N, vocab) vs (N,).
        flat_pred = pred.contiguous().view(-1, vocab_size)
        flat_target = target.contiguous().view(-1)
        return criterion(flat_pred, flat_target)

    return loss_fn
def preprocess_datasets(
    raw_datasets: DatasetDict,
    tokenizer: PreTrainedTokenizer,
    block_size: int,
    num_proc: int = 8,
) -> DatasetDict:
    """
    Tokenize `raw_datasets` and regroup the token stream into fixed-length blocks.

    Closely follows https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L449

    Args:
        raw_datasets: output of `load_datasets()`, expected to always have a "train" split.
            The "train" split's columns decide which column is tokenized, and all of them
            are removed after tokenization.
        tokenizer: tokenizer applied, batched, to the text column.
        block_size: number of tokens in every output example; must be positive.
        num_proc: number of worker processes passed to both `map` calls. The default of 8
            matches the previously hard-coded value.

    Returns:
        The mapped dataset. Every example holds the tokenizer's output columns (including
        `input_ids`), each cut to exactly `block_size` tokens, plus a `labels` column
        copied from `input_ids`.

    Raises:
        ValueError: if `block_size` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    column_names = list(raw_datasets["train"].features)
    # Prefer a column literally named "text"; otherwise fall back to the first column.
    text_column_name = "text" if "text" in column_names else column_names[0]
    tokenized_datasets = raw_datasets.map(
        lambda examples: tokenizer(examples[text_column_name]),
        batched=True,
        num_proc=num_proc,
        remove_columns=column_names,
        desc="Running tokenizer on dataset",
    )

    def group_texts(examples):
        """Concatenate every column of a batch, then cut it into `block_size` chunks."""
        concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(next(iter(concatenated_examples.values())))
        # Drop the remainder that does not fill a whole block. If total_length < block_size,
        # every column becomes an empty list, so this batch contributes no examples.
        # Padding could be added here instead if the model supported it.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # With `batched=True`, `map` passes batches of 1,000 texts by default, so group_texts drops
    # a remainder for each batch of 1,000 texts. A larger `batch_size` drops less data but may
    # be slower to preprocess. See https://huggingface.co/docs/datasets/process#map
    return tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=num_proc,
        desc=f"Grouping texts in chunks of {block_size}",
    )
def convert_to_torch_dataset(hf_dataset):
    """
    Return `hf_dataset` re-formatted so that its items come back as torch tensors.

    This is a thin delegate to `hf_dataset.with_format("torch")`. For a HuggingFace
    `datasets` object, the result is presumably still a HuggingFace dataset rather than a
    `torch.utils.data.Dataset` subclass; confirm against the datasets docs before relying
    on torch-specific APIs.
    """
    target_format = "torch"
    formatted = hf_dataset.with_format(target_format)
    return formatted
|