Self-contained masking
+ "cells": [
+ "cell_type": "code",
+ "source": [
+ "model_name_or_path = \"bert-base-uncased\"\n",
+ "\n",
+ "# Load the masked-LM checkpoint in bfloat16 with the SDPA attention implementation,\n",
+ "# then switch the module to eval mode.\n",
+ "model = transformers.AutoModelForMaskedLM.from_pretrained(\n",
+ "    model_name_or_path,\n",
+ "    torch_dtype=torch.bfloat16,\n",
+ "    attn_implementation=\"sdpa\",\n",
+ ")\n",
+ "model.eval()\n",
+ "\n",
+ "# Tokenizer comes from the same checkpoint name as the model.\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)"
+ "cell_type": "code",
+ "source": [
+ "def ltr_mask(seq_len: int) -> torch.Tensor:\n",
+ "    \"\"\"Return a (seq_len, seq_len) bool matrix that is True strictly below the diagonal.\n",
+ "\n",
+ "    Entry [i, j] is True only when j < i; the diagonal itself is False,\n",
+ "    so row 0 contains no True entries.\n",
+ "    \"\"\"\n",
+ "    full = torch.ones((seq_len, seq_len), dtype=bool)\n",
+ "    return full.tril(diagonal=-1)\n",
+ "\n",
+ "\n",
+ "def rtl_mask(seq_len: int) -> torch.Tensor:\n",
+ "    \"\"\"Return the transpose of `ltr_mask(seq_len)`: True strictly above the diagonal.\"\"\"\n",
+ "    return ltr_mask(seq_len).t()"
+ "cell_type": "code",
+ "source": [
+ "# Non-persistent: the mask is rebuilt from the config, so keep it out of state_dict()/checkpoints.\n",
+ "model.register_buffer(\n",
+ "    \"attn_mask\",\n",
+ "    rtl_mask(model.config.max_position_embeddings).to(model.device),\n",
+ "    persistent=False,\n",
+ ")\n",
+ "\n",
+ "def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):\n",
+ "    \"\"\"\n",
+ "    Forward pre-hook (registered with kwargs) that replaces the attention mask with `model.attn_mask`.\n",
+ "\n",
+ "    Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515\n",
+ "    so no `kwargs` and `attention_mask` is second positional arg.\n",
+ "\n",
+ "    Uses global `model.attn_mask` to save memory.\n",
+ "\n",
+ "    Returns the `(args, kwargs)` pair to call the module with; only `args[1]` is changed.\n",
+ "    Asserts that no kwargs were passed and that the incoming mask's last two dims\n",
+ "    equal the shape of `model.attn_mask`.\n",
+ "    \"\"\"\n",
+ "    assert not kwargs\n",
+ "\n",
+ "    # Tuples are immutable, so edit a list copy and convert back on return.\n",
+ "    patched = list(args)\n",
+ "    assert patched[1].size()[-2:] == model.attn_mask.size(), f\"{args[1].size()=} {model.attn_mask.size()=}\"\n",
+ "    patched[1] = model.attn_mask\n",
+ "    return tuple(patched), kwargs\n",
+ "\n",
+ "def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):\n",
+ "    \"\"\"Forward hook: check that the mask the module was called with equals `model.attn_mask`.\"\"\"\n",
+ "    print(\"Post-forward checks\")\n",
+ "    assert torch.equal(args[1], model.attn_mask), (args[1], model.attn_mask)"
+ "cell_type": "code",
+ "source": [
+ "# attn_mask = rtl_mask(model.config.max_position_embeddings)\n",
+ "for name, module in model.named_modules():\n",
+ "    if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):\n",
+ "        # This cell is re-run during testing: drop hooks left over from earlier runs.\n",
+ "        # Pre-hooks are cleared too, otherwise `attn_hook` is stacked once per run.\n",
+ "        module._forward_pre_hooks.clear()\n",
+ "        module._forward_hooks.clear()\n",
+ "        module.register_forward_pre_hook(attn_hook, with_kwargs=True)\n",
+ "        module.register_forward_hook(debug_inputs_hook)\n",
+ "        # module.register_buffer(\"attn_mask\", attn_mask)\n",
+ "\n",
+ "model = model.to(device)"
+ "cell_type": "code",
+ "source": [
+ "# Load the wikitext-103-v1 configuration of Salesforce/wikitext.\n",
+ "ds = load_dataset(\n",
+ "    \"Salesforce/wikitext\",\n",
+ "    \"wikitext-103-v1\",\n",
+ ")"
+ "cell_type": "code",
+ "source": [
+ "train_ds = ds[\"train\"]\n",
+ "\n",
+ "# Tokenize one training example, padded to max length and truncated if longer.\n",
+ "inputs = tokenizer(\n",
+ "    train_ds[5][\"text\"],\n",
+ "    return_tensors=\"pt\",\n",
+ "    padding='max_length',\n",
+ "    truncation=True,\n",
+ ")\n",
+ "# Move every returned tensor to the target device.\n",
+ "inputs = {field: tensor.to(device) for field, tensor in inputs.items()}\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ "cell_type": "code",
+ "source": [
+ "with torch.inference_mode():\n",
+ "    # Swap `model.attn_mask` to `ltr_mask`; non-persistent so it stays out of state_dict().\n",
+ "    model.register_buffer(\n",
+ "        \"attn_mask\",\n",
+ "        ltr_mask(model.config.max_position_embeddings).to(model.device),\n",
+ "        persistent=False,\n",
+ "    )\n",
+ "    outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ "cell_type": "code",
+ "source": [
+ "with torch.inference_mode():\n",
+ "    # Swap `model.attn_mask` back to `rtl_mask`; non-persistent so it stays out of state_dict().\n",
+ "    model.register_buffer(\n",
+ "        \"attn_mask\",\n",
+ "        rtl_mask(model.config.max_position_embeddings).to(model.device),\n",
+ "        persistent=False,\n",
+ "    )\n",
+ "    outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ "cell_type": "code",
+ "source": [
+ "# Wikipedia test: load the 20231101.en configuration and print one training record.\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "ds = load_dataset(\n",
+ "    \"wikimedia/wikipedia\",\n",
+ "    \"20231101.en\",\n",
+ ")\n",
+ "print(ds[\"train\"][1000])"
