In [None]:
# Install the notebook's third-party dependencies into the kernel's environment.
# NOTE(review): versions are unpinned, yet the attention hook defined later in this
# notebook assumes the BertSelfAttention.forward signature of one specific transformers
# commit — consider pinning `transformers` (TODO: confirm the version that was used).
%pip install datasets torch transformers

In [None]:
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset

# Seed through the transformers helper so repeated runs are comparable.
transformers.set_seed(42)

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cpu


In [None]:
model_name_or_path = "bert-base-uncased"

# Load the masked-LM checkpoint in bfloat16 with the SDPA attention implementation and
# put it in eval mode immediately (`nn.Module.eval()` returns the module itself).
model = transformers.AutoModelForMaskedLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identica

In [None]:
def ltr_mask(seq_len: int) -> torch.Tensor:
    """Build a square boolean mask that is True strictly below the main diagonal.

    Args:
        seq_len: Side length of the mask.

    Returns:
        A ``(seq_len, seq_len)`` bool tensor where entry ``[i, j]`` is True iff ``j < i``;
        the main diagonal itself is False.
    """
    # Start all-True and keep only the strictly-lower triangle; `diagonal=-1` also drops
    # the main diagonal.
    return torch.ones((seq_len, seq_len), dtype=bool).tril(diagonal=-1)

def rtl_mask(seq_len: int) -> torch.Tensor:
    """Build the mirror image of ``ltr_mask``: True strictly above the main diagonal.

    Args:
        seq_len: Side length of the mask.

    Returns:
        A ``(seq_len, seq_len)`` bool tensor obtained by transposing ``ltr_mask(seq_len)``,
        so entry ``[i, j]`` is True iff ``j > i``.
    """
    lower_triangle = ltr_mask(seq_len)
    # Swapping the two axes turns the strictly-lower triangle into the strictly-upper one.
    return lower_triangle.transpose(0, 1)

In [None]:
# Attach the right-to-left mask to the model as a buffer named "attn_mask". It is built
# for the model's maximum number of positions and placed on the model's current device.
model.register_buffer(
    name="attn_mask",
    tensor=rtl_mask(model.config.max_position_embeddings).to(model.device),
)

def _directional_mask_for(hidden_states: torch.Tensor) -> torch.Tensor:
    """Return the top-left square of ``model.attn_mask`` sized for ``hidden_states``."""
    # assumes hidden_states is (batch, seq_len, hidden) — TODO confirm against the
    # installed transformers version.
    seq_len = hidden_states.size(-2)
    max_len = model.attn_mask.size(-1)
    assert seq_len <= max_len, f"{seq_len=} exceeds {max_len=}"
    # Slicing yields a view, so the single shared buffer remains the only mask in memory.
    return model.attn_mask[:seq_len, :seq_len]


def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):
    """
    Forward pre-hook that replaces the attention mask with the global directional mask.

    Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515
    so no `kwargs`, `hidden_states` is the first positional arg and `attention_mask` is
    the second positional arg.

    Uses global `model.attn_mask` to save memory. The mask is sliced to the current
    sequence length, so inputs shorter than the buffer are accepted, and an incoming
    mask of `None` is tolerated instead of crashing on `.size()`.
    NOTE(review): transformers' SDPA path appears to pass `None` when nothing is padded —
    confirm for the installed version.

    NOTE(review): the incoming mask is discarded wholesale; if it encodes padding, padded
    positions become attendable. Confirm that this is intended.

    Args:
        attn_module: The attention module being called (unused).
        args: Positional arguments of the module's forward call.
        kwargs: Keyword arguments of the module's forward call; must be empty.

    Returns:
        ``(args, kwargs)`` with ``args[1]`` replaced by the directional mask.

    Raises:
        AssertionError: If kwargs are present, fewer than two positional args were
            passed, the sequence is longer than the buffer, or the incoming mask's last
            two dimensions do not match the replacement mask.
    """
    assert not kwargs
    assert len(args) >= 2, f"expected (hidden_states, attention_mask, ...), got {len(args)} positional args"

    replacement = _directional_mask_for(args[0])
    if args[1] is not None:
        # Sanity check: the mask being replaced must cover the same (query, key) grid.
        assert args[1].size()[-2:] == replacement.size(), f"{args[1].size()=} {replacement.size()=}"

    patched_args = list(args)
    patched_args[1] = replacement
    return tuple(patched_args), kwargs

def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):
    """
    Forward hook that verifies the module really ran with the directional mask.

    Args:
        attn_module: The attention module that was called (unused).
        args: Positional arguments the forward call received (after any pre-hooks).
        output: The module's output (unused).

    Raises:
        AssertionError: If ``args[1]`` differs from the mask expected for this input.
    """
    print(f"Post-forward checks")
    expected = _directional_mask_for(args[0])
    assert torch.equal(args[1], expected), (args[1], expected)

In [None]:
# Re-running this cell must not stack hooks. Detach everything a previous run registered
# through the returned handles (public API), which covers both the forward pre-hooks and
# the forward hooks.
for _handle in globals().get("_attn_hook_handles", []):
    _handle.remove()
_attn_hook_handles = []

for module in model.modules():
    if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):
        # The pre-hook swaps in the directional mask; the forward hook checks it was used.
        _attn_hook_handles.append(module.register_forward_pre_hook(attn_hook, with_kwargs=True))
        _attn_hook_handles.append(module.register_forward_hook(debug_inputs_hook))

# Move parameters and buffers (including the registered mask) to the chosen device.
model = model.to(device)

In [None]:
# Load the WikiText-103 configuration of the Salesforce/wikitext dataset.
ds = load_dataset(path="Salesforce/wikitext", name="wikitext-103-v1")

In [None]:
train_ds = ds["train"]

# Tokenize one training record, padded/truncated to the tokenizer's maximum length, and
# move every resulting tensor to the target device in a single step.
# NOTE(review): the pre-hook replaces the attention mask, so the padding added here is
# presumably attended to — confirm that this is intended.
inputs = {
    key: val.to(device)
    for key, val in tokenizer(
        train_ds[5]["text"],
        return_tensors="pt",
        padding='max_length',
        truncation=True,
    ).items()
}

with torch.no_grad():
    outputs = model(**inputs)

outputs.logits

Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks


tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],
 [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],
 [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],
 ...,
 [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],
 [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],
 [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],
 device='cuda:0', dtype=torch.bfloat16)

In [None]:
with torch.inference_mode():
    # Swap the shared buffer to the left-to-right mask, then rerun the same inputs.
    model.register_buffer(
        name="attn_mask",
        tensor=ltr_mask(model.config.max_position_embeddings).to(model.device),
    )
    outputs = model(**inputs)

outputs.logits

Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks


tensor([[[-7.9062, -7.7812, -7.9062, ..., -7.1250, -7.8438, -4.8438],
 [-7.1562, -7.1250, -7.2812, ..., -7.3750, -7.3750, -7.2500],
 [-5.4062, -5.2188, -5.4375, ..., -5.3438, -4.3750, -5.0312],
 ...,
 [ 3.9844, 3.6406, 3.6406, ..., 3.8281, 2.9062, 5.2812],
 [ 4.0938, 3.7812, 3.8281, ..., 4.0000, 2.9844, 5.5000],
 [ 3.8281, 3.5312, 3.5156, ..., 4.1562, 2.8438, 4.7188]]],
 device='cuda:0', dtype=torch.bfloat16)

In [None]:
with torch.inference_mode():
    # Swap the shared buffer back to the right-to-left mask, then rerun the same inputs.
    model.register_buffer(
        name="attn_mask",
        tensor=rtl_mask(model.config.max_position_embeddings).to(model.device),
    )
    outputs = model(**inputs)

outputs.logits

Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks
Post-forward checks


tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],
 [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],
 [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],
 ...,
 [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],
 [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],
 [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],
 device='cuda:0', dtype=torch.bfloat16)

In [None]:
# Wikipedia test
from datasets import load_dataset

# NOTE(review): this rebinds `ds`, replacing the wikitext dataset loaded earlier, while
# `train_ds` still refers to the wikitext train split — consider a distinct name.
ds = load_dataset(path="wikimedia/wikipedia", name="20231101.en")
print(ds["train"][1000])