aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Zhao2024-11-03 20:36:17 -0500
committerKevin Zhao2024-11-03 20:36:25 -0500
commit9b67eb690b70d71bca2054a9cd7a0a50ac0c7520 (patch)
treee8f1235a3b65c19cbf2233489b51783fcec50301
parent0b86c963bf24a790dfcda7b7b00ead94e5d9d2fd (diff)
Self-contained masking
-rw-r--r--Right_to_Left_NLP.ipynb373
1 files changed, 373 insertions, 0 deletions
diff --git a/Right_to_Left_NLP.ipynb b/Right_to_Left_NLP.ipynb
new file mode 100644
index 0000000..bcbc57a
--- /dev/null
+++ b/Right_to_Left_NLP.ipynb
@@ -0,0 +1,373 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "%pip install datasets torch transformers"
+ ],
+ "metadata": {
+ "collapsed": true,
+ "id": "M29-oTOBIiMr"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2iJJyERxHWSO",
+ "outputId": "04e9bc9d-5ee9-48d5-f370-6fd66ec7b7c1"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "cpu\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import transformers\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "transformers.set_seed(42)\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "print(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model_name_or_path = \"bert-base-uncased\"\n",
+ "model = transformers.AutoModelForMaskedLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=\"sdpa\")\n",
+ "model.eval()\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "C5PdaHGWHuXG",
+ "outputId": "d15272a5-1ce1-4c7e-9004-fc686a3de6b9"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+ "You will be able to reuse this secret in all of your notebooks.\n",
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+ " warnings.warn(\n",
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def ltr_mask(seq_len: int) -> torch.Tensor:\n",
+ " mask = torch.ones((seq_len, seq_len), dtype=bool)\n",
+ " return torch.tril(mask, diagonal=-1)\n",
+ "\n",
+ "def rtl_mask(seq_len: int) -> torch.Tensor:\n",
+ " return ltr_mask(seq_len).T"
+ ],
+ "metadata": {
+ "id": "H_AUjBRoJHXU"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n",
+ "\n",
+ "def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):\n",
+ " \"\"\"\n",
+ " Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515\n",
+ " so no `kwargs` and `attention_mask` is second positional arg.\n",
+ "\n",
+ " Uses global `model.attn_mask` to save memory.\n",
+ " \"\"\"\n",
+ " assert not kwargs\n",
+ "\n",
+ " args = list(args)\n",
+ " assert args[1].size()[-2:] == model.attn_mask.size(), f\"{args[1].size()=} {model.attn_mask.size()=}\"\n",
+ " args[1] = model.attn_mask\n",
+ " return tuple(args), kwargs\n",
+ "\n",
+ "def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):\n",
+ " print(f\"Post-forward checks\")\n",
+ " assert torch.equal(args[1], model.attn_mask), (args[1], model.attn_mask)"
+ ],
+ "metadata": {
+ "id": "Oy27MZcLLLsD"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# attn_mask = rtl_mask(model.config.max_position_embeddings)\n",
+ "for name, module in model.named_modules():\n",
+ " if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):\n",
+ " module._forward_hooks.clear() # running multiple times right now during testing\n",
+ " module.register_forward_pre_hook(attn_hook, with_kwargs=True)\n",
+ " module.register_forward_hook(debug_inputs_hook)\n",
+ " # module.register_buffer(\"attn_mask\", attn_mask)\n",
+ "\n",
+ "model = model.to(device)"
+ ],
+ "metadata": {
+ "id": "anEdwKj_OWWy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "ds = load_dataset(\"Salesforce/wikitext\", \"wikitext-103-v1\")"
+ ],
+ "metadata": {
+ "id": "P1BEQFsLIRfX"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_ds = ds[\"train\"]\n",
+ "inputs = tokenizer(train_ds[5][\"text\"], return_tensors=\"pt\", padding='max_length', truncation=True)\n",
+ "inputs = {key: val.to(device) for key, val in inputs.items()}\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BHE26Mr2NXhH",
+ "outputId": "24569931-61d7-4752-8b08-4daef58f9798"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
+ " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
+ " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
+ " ...,\n",
+ " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
+ " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
+ " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
+ " device='cuda:0', dtype=torch.bfloat16)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 9
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with torch.inference_mode():\n",
+ " model.register_buffer(\"attn_mask\", ltr_mask(model.config.max_position_embeddings).to(model.device))\n",
+ " outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ZtEm7eQQNi4e",
+ "outputId": "c0eb3925-6d48-480e-a853-5057f35dbcd2"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[[-7.9062, -7.7812, -7.9062, ..., -7.1250, -7.8438, -4.8438],\n",
+ " [-7.1562, -7.1250, -7.2812, ..., -7.3750, -7.3750, -7.2500],\n",
+ " [-5.4062, -5.2188, -5.4375, ..., -5.3438, -4.3750, -5.0312],\n",
+ " ...,\n",
+ " [ 3.9844, 3.6406, 3.6406, ..., 3.8281, 2.9062, 5.2812],\n",
+ " [ 4.0938, 3.7812, 3.8281, ..., 4.0000, 2.9844, 5.5000],\n",
+ " [ 3.8281, 3.5312, 3.5156, ..., 4.1562, 2.8438, 4.7188]]],\n",
+ " device='cuda:0', dtype=torch.bfloat16)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with torch.inference_mode():\n",
+ " model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n",
+ " outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "nz0j7V3oNkZu",
+ "outputId": "939b1d6d-5dca-41ef-eb17-9e0f4d09629e"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
+ " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
+ " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
+ " ...,\n",
+ " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
+ " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
+ " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
+ " device='cuda:0', dtype=torch.bfloat16)"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 11
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "82PpSWnrdMgu"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Wikipedia test\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "ds = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
+ "print(ds[\"train\"][1000])"
+ ],
+ "metadata": {
+ "id": "DHftDnPKdMjV"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "FTMkfLyKdMqu"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+} \ No newline at end of file