{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Directional attention masks for BERT\n",
    "\n",
    "Replaces the attention mask seen by every `BertSelfAttention` layer of `bert-base-uncased` with a strictly left-to-right or right-to-left mask (via forward hooks), then compares the masked-LM logits on one WikiText-103 example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "M29-oTOBIiMr"
   },
   "outputs": [],
   "source": [
    "# NOTE(review): versions are unpinned; `attn_hook` below depends on how BertSelfAttention.forward\n",
    "# is called, so pin `transformers` once a known-good version is confirmed.\n",
    "%pip install datasets torch transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2iJJyERxHWSO",
    "outputId": "04e9bc9d-5ee9-48d5-f370-6fd66ec7b7c1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import transformers\n",
    "from datasets import load_dataset\n",
    "from transformers.models.bert.modeling_bert import BertSelfAttention\n",
    "\n",
    "transformers.set_seed(42)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "C5PdaHGWHuXG",
    "outputId": "d15272a5-1ce1-4c7e-9004-fc686a3de6b9"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
      "You will be able to reuse this secret in all of your notebooks.\n",
      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
      " warnings.warn(\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      " warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "model_name_or_path = \"bert-base-uncased\"\n",
    "model = transformers.AutoModelForMaskedLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=\"sdpa\")\n",
    "model.eval()\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Directional masks and attention hooks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "H_AUjBRoJHXU"
   },
   "outputs": [],
   "source": [
    "def ltr_mask(seq_len: int) -> torch.Tensor:\n",
    "    \"\"\"Return a boolean (seq_len, seq_len) mask that is True strictly below the diagonal.\n",
    "\n",
    "    Entry [i, j] is True only when j < i. Assuming True means \"may attend\" (the convention\n",
    "    for boolean masks in torch's scaled_dot_product_attention), each query sees only the\n",
    "    tokens to its left and not itself.\n",
    "\n",
    "    NOTE(review): the strict diagonal leaves row 0 without any visible key; confirm the\n",
    "    attention backend handles a fully masked row the way you intend.\n",
    "    \"\"\"\n",
    "    all_visible = torch.ones((seq_len, seq_len), dtype=torch.bool)\n",
    "    return torch.tril(all_visible, diagonal=-1)\n",
    "\n",
    "\n",
    "def rtl_mask(seq_len: int) -> torch.Tensor:\n",
    "    \"\"\"Return the transpose of `ltr_mask`: True strictly above the diagonal (j > i).\n",
    "\n",
    "    The last row has no True entry, mirroring row 0 of `ltr_mask`.\n",
    "    \"\"\"\n",
    "    return ltr_mask(seq_len).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Oy27MZcLLLsD"
   },
   "outputs": [],
   "source": [
    "def install_attn_mask(mask_fn) -> None:\n",
    "    \"\"\"Build a full-length mask with `mask_fn` and store it as the `model.attn_mask` buffer.\n",
    "\n",
    "    The hooks below read `model.attn_mask` on every call, so installing another mask\n",
    "    switches the attention direction without re-registering any hook.\n",
    "    \"\"\"\n",
    "    full_mask = mask_fn(model.config.max_position_embeddings)\n",
    "    # persistent=False keeps this helper tensor out of `model.state_dict()`.\n",
    "    model.register_buffer(\"attn_mask\", full_mask.to(model.device), persistent=False)\n",
    "\n",
    "\n",
    "def cropped_attn_mask(attention_mask: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"Return `model.attn_mask` cropped to the trailing (query, key) dims of `attention_mask`.\n",
    "\n",
    "    For inputs padded to `max_position_embeddings` this is the whole buffer.\n",
    "    \"\"\"\n",
    "    query_len, key_len = attention_mask.size()[-2:]\n",
    "    max_len = model.attn_mask.size(-1)\n",
    "    assert max(query_len, key_len) <= max_len, f\"{attention_mask.size()=} {model.attn_mask.size()=}\"\n",
    "    return model.attn_mask[:query_len, :key_len]\n",
    "\n",
    "\n",
    "def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):\n",
    "    \"\"\"Forward pre-hook that swaps the layer's attention mask for the directional one.\n",
    "\n",
    "    Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515\n",
    "    so no `kwargs` and `attention_mask` is second positional arg.\n",
    "\n",
    "    A single mask stored on the global `model` is shared by all layers to save memory.\n",
    "\n",
    "    NOTE: the incoming mask is replaced wholesale, so any padding information it carried\n",
    "    is dropped and padded positions become attendable.\n",
    "    \"\"\"\n",
    "    assert not kwargs, f\"expected a positional-only call, got kwargs {sorted(kwargs)}\"\n",
    "\n",
    "    patched_args = list(args)\n",
    "    patched_args[1] = cropped_attn_mask(patched_args[1])\n",
    "    return tuple(patched_args), kwargs\n",
    "\n",
    "\n",
    "def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):\n",
    "    \"\"\"Forward hook that verifies the layer really ran with the directional mask.\"\"\"\n",
    "    print(\"Post-forward checks\")\n",
    "    assert torch.equal(args[1], cropped_attn_mask(args[1])), (args[1], model.attn_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "anEdwKj_OWWy"
   },
   "outputs": [],
   "source": [
    "install_attn_mask(rtl_mask)\n",
    "\n",
    "# Re-running this cell must not stack hooks: remove the handles left by the previous run\n",
    "# (pre-hooks and forward hooks alike) before registering again.\n",
    "for handle in globals().get(\"hook_handles\", []):\n",
    "    handle.remove()\n",
    "\n",
    "hook_handles = []\n",
    "for module in model.modules():\n",
    "    if isinstance(module, BertSelfAttention):\n",
    "        hook_handles.append(module.register_forward_pre_hook(attn_hook, with_kwargs=True))\n",
    "        hook_handles.append(module.register_forward_hook(debug_inputs_hook))\n",
    "\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare mask directions on one WikiText-103 example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P1BEQFsLIRfX"
   },
   "outputs": [],
   "source": [
    "wikitext_ds = load_dataset(\"Salesforce/wikitext\", \"wikitext-103-v1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BHE26Mr2NXhH",
    "outputId": "24569931-61d7-4752-8b08-4daef58f9798"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
       " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
       " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
       " ...,\n",
       " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
       " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
       " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
       " device='cuda:0', dtype=torch.bfloat16)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_ds = wikitext_ds[\"train\"]\n",
    "encoded = tokenizer(train_ds[5][\"text\"], return_tensors=\"pt\", padding=\"max_length\", truncation=True)\n",
    "inputs = {key: val.to(device) for key, val in encoded.items()}\n",
    "\n",
    "# `model.attn_mask` is still the right-to-left mask installed when the hooks were registered.\n",
    "with torch.inference_mode():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "outputs.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZtEm7eQQNi4e",
    "outputId": "c0eb3925-6d48-480e-a853-5057f35dbcd2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-7.9062, -7.7812, -7.9062, ..., -7.1250, -7.8438, -4.8438],\n",
       " [-7.1562, -7.1250, -7.2812, ..., -7.3750, -7.3750, -7.2500],\n",
       " [-5.4062, -5.2188, -5.4375, ..., -5.3438, -4.3750, -5.0312],\n",
       " ...,\n",
       " [ 3.9844, 3.6406, 3.6406, ..., 3.8281, 2.9062, 5.2812],\n",
       " [ 4.0938, 3.7812, 3.8281, ..., 4.0000, 2.9844, 5.5000],\n",
       " [ 3.8281, 3.5312, 3.5156, ..., 4.1562, 2.8438, 4.7188]]],\n",
       " device='cuda:0', dtype=torch.bfloat16)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "install_attn_mask(ltr_mask)\n",
    "with torch.inference_mode():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "outputs.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nz0j7V3oNkZu",
    "outputId": "939b1d6d-5dca-41ef-eb17-9e0f4d09629e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n",
      "Post-forward checks\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
       " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
       " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
       " ...,\n",
       " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
       " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
       " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
       " device='cuda:0', dtype=torch.bfloat16)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Switching back to right-to-left should reproduce the logits of the first run.\n",
    "install_attn_mask(rtl_mask)\n",
    "with torch.inference_mode():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "outputs.logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wikipedia test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DHftDnPKdMjV"
   },
   "outputs": [],
   "source": [
    "# Kept under its own name so the WikiText dataset loaded above is not overwritten.\n",
    "wikipedia_ds = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
    "print(wikipedia_ds[\"train\"][1000])"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}