From 9b67eb690b70d71bca2054a9cd7a0a50ac0c7520 Mon Sep 17 00:00:00 2001 From: Kevin Zhao Date: Sun, 3 Nov 2024 20:36:17 -0500 Subject: Self-contained masking --- Right_to_Left_NLP.ipynb | 373 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 373 insertions(+) create mode 100644 Right_to_Left_NLP.ipynb diff --git a/Right_to_Left_NLP.ipynb b/Right_to_Left_NLP.ipynb new file mode 100644 index 0000000..bcbc57a --- /dev/null +++ b/Right_to_Left_NLP.ipynb @@ -0,0 +1,373 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "%pip install datasets torch transformers" + ], + "metadata": { + "collapsed": true, + "id": "M29-oTOBIiMr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2iJJyERxHWSO", + "outputId": "04e9bc9d-5ee9-48d5-f370-6fd66ec7b7c1" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "cpu\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import transformers\n", + "from datasets import load_dataset\n", + "\n", + "transformers.set_seed(42)\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "print(device)" + ] + }, + { + "cell_type": "code", + "source": [ + "model_name_or_path = \"bert-base-uncased\"\n", + "model = transformers.AutoModelForMaskedLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=\"sdpa\")\n", + "model.eval()\n", + "\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "C5PdaHGWHuXG", + "outputId": "d15272a5-1ce1-4c7e-9004-fc686a3de6b9" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n", + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", + "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", + " warnings.warn(\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "def ltr_mask(seq_len: int) -> torch.Tensor:\n", + " mask = torch.ones((seq_len, seq_len), dtype=bool)\n", + " return torch.tril(mask, diagonal=-1)\n", + "\n", + "def rtl_mask(seq_len: int) -> torch.Tensor:\n", + " return ltr_mask(seq_len).T" + ], + "metadata": { + "id": "H_AUjBRoJHXU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n", + "\n", + "def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):\n", + " \"\"\"\n", + " Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515\n", + " so no `kwargs` and `attention_mask` is second positional arg.\n", + "\n", + " Uses global `model.attn_mask` to save memory.\n", + " \"\"\"\n", + " assert not kwargs\n", + "\n", + " args = list(args)\n", + " assert args[1].size()[-2:] == model.attn_mask.size(), f\"{args[1].size()=} {model.attn_mask.size()=}\"\n", + " args[1] = model.attn_mask\n", + " return tuple(args), kwargs\n", + "\n", + "def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):\n", + " print(f\"Post-forward checks\")\n", + " assert torch.equal(args[1], model.attn_mask), (args[1], model.attn_mask)" + ], + "metadata": { + "id": "Oy27MZcLLLsD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# attn_mask = rtl_mask(model.config.max_position_embeddings)\n", + "for name, module in model.named_modules():\n", + " if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):\n", + " module._forward_hooks.clear() # running multiple times right now during testing\n", + " module.register_forward_pre_hook(attn_hook, with_kwargs=True)\n", + " module.register_forward_hook(debug_inputs_hook)\n", + " # module.register_buffer(\"attn_mask\", attn_mask)\n", + "\n", + "model = model.to(device)" + ], + "metadata": { + "id": "anEdwKj_OWWy" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ds = load_dataset(\"Salesforce/wikitext\", \"wikitext-103-v1\")" + ], + "metadata": { + "id": "P1BEQFsLIRfX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_ds = ds[\"train\"]\n", + "inputs = tokenizer(train_ds[5][\"text\"], return_tensors=\"pt\", padding='max_length', truncation=True)\n", + "inputs = {key: val.to(device) for key, val in inputs.items()}\n", + "\n", + "with torch.no_grad():\n", + " outputs = model(**inputs)\n", + "\n", + "outputs.logits" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BHE26Mr2NXhH", + "outputId": "24569931-61d7-4752-8b08-4daef58f9798" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n", + " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n", + " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n", + " ...,\n", + " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n", + " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n", + " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n", + " device='cuda:0', dtype=torch.bfloat16)" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ] + }, + { + "cell_type": "code", + "source": [ + "with torch.inference_mode():\n", + " model.register_buffer(\"attn_mask\", ltr_mask(model.config.max_position_embeddings).to(model.device))\n", + " outputs = model(**inputs)\n", + "\n", + "outputs.logits" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZtEm7eQQNi4e", + "outputId": "c0eb3925-6d48-480e-a853-5057f35dbcd2" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[[-7.9062, -7.7812, -7.9062, ..., -7.1250, -7.8438, -4.8438],\n", + " [-7.1562, -7.1250, -7.2812, ..., -7.3750, -7.3750, -7.2500],\n", + " [-5.4062, -5.2188, -5.4375, ..., -5.3438, -4.3750, -5.0312],\n", + " ...,\n", + " [ 3.9844, 3.6406, 3.6406, ..., 3.8281, 2.9062, 5.2812],\n", + " [ 4.0938, 3.7812, 3.8281, ..., 4.0000, 2.9844, 5.5000],\n", + " [ 3.8281, 3.5312, 3.5156, ..., 4.1562, 2.8438, 4.7188]]],\n", + " device='cuda:0', dtype=torch.bfloat16)" + ] + }, + "metadata": {}, + "execution_count": 10 + } + ] + }, + { + "cell_type": "code", + "source": [ + "with torch.inference_mode():\n", + " model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n", + " outputs = model(**inputs)\n", + "\n", + "outputs.logits" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nz0j7V3oNkZu", + "outputId": "939b1d6d-5dca-41ef-eb17-9e0f4d09629e" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n", + "Post-forward checks\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n", + " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n", + " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n", + " ...,\n", + " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n", + " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n", + " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n", + " device='cuda:0', dtype=torch.bfloat16)" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "82PpSWnrdMgu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Wikipedia test\n", + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n", + "print(ds[\"train\"][1000])" + ], + "metadata": { + "id": "DHftDnPKdMjV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "FTMkfLyKdMqu" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file -- cgit v1.2.3-70-g09d2