Diffstat (limited to 'notebooks/Right_to_Left_NLP.ipynb')
-rw-r--r--  notebooks/Right_to_Left_NLP.ipynb  712
1 file changed, 361 insertions, 351 deletions
diff --git a/notebooks/Right_to_Left_NLP.ipynb b/notebooks/Right_to_Left_NLP.ipynb
index bcbc57a..986ef22 100644
--- a/notebooks/Right_to_Left_NLP.ipynb
+++ b/notebooks/Right_to_Left_NLP.ipynb
@@ -1,373 +1,383 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "M29-oTOBIiMr"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install datasets torch transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
"colab": {
- "provenance": []
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
+ "base_uri": "https://localhost:8080/"
},
- "language_info": {
- "name": "python"
+ "id": "2iJJyERxHWSO",
+ "outputId": "04e9bc9d-5ee9-48d5-f370-6fd66ec7b7c1"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cpu\n"
+ ]
}
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import transformers\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "transformers.set_seed(42)\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "print(device)"
+ ]
},
- "cells": [
- {
- "cell_type": "code",
- "source": [
- "%pip install datasets torch transformers"
- ],
- "metadata": {
- "collapsed": true,
- "id": "M29-oTOBIiMr"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "2iJJyERxHWSO",
- "outputId": "04e9bc9d-5ee9-48d5-f370-6fd66ec7b7c1"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "cpu\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "import torch.nn as nn\n",
- "import transformers\n",
- "from datasets import load_dataset\n",
- "\n",
- "transformers.set_seed(42)\n",
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "print(device)"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "model_name_or_path = \"bert-base-uncased\"\n",
- "model = transformers.AutoModelForMaskedLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=\"sdpa\")\n",
- "model.eval()\n",
- "\n",
- "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "C5PdaHGWHuXG",
- "outputId": "d15272a5-1ce1-4c7e-9004-fc686a3de6b9"
- },
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
- "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
- "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
- "You will be able to reuse this secret in all of your notebooks.\n",
- "Please note that authentication is recommended but still optional to access public models or datasets.\n",
- " warnings.warn(\n",
- "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
- "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
- " warnings.warn(\n"
- ]
- }
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "def ltr_mask(seq_len: int) -> torch.Tensor:\n",
- " mask = torch.ones((seq_len, seq_len), dtype=bool)\n",
- " return torch.tril(mask, diagonal=-1)\n",
- "\n",
- "def rtl_mask(seq_len: int) -> torch.Tensor:\n",
- " return ltr_mask(seq_len).T"
- ],
- "metadata": {
- "id": "H_AUjBRoJHXU"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n",
- "\n",
- "def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):\n",
- " \"\"\"\n",
- " Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515\n",
- " so no `kwargs` and `attention_mask` is second positional arg.\n",
- "\n",
- " Uses global `model.attn_mask` to save memory.\n",
- " \"\"\"\n",
- " assert not kwargs\n",
- "\n",
- " args = list(args)\n",
- " assert args[1].size()[-2:] == model.attn_mask.size(), f\"{args[1].size()=} {model.attn_mask.size()=}\"\n",
- " args[1] = model.attn_mask\n",
- " return tuple(args), kwargs\n",
- "\n",
- "def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):\n",
- " print(f\"Post-forward checks\")\n",
- " assert torch.equal(args[1], model.attn_mask), (args[1], model.attn_mask)"
- ],
- "metadata": {
- "id": "Oy27MZcLLLsD"
- },
- "execution_count": null,
- "outputs": []
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "C5PdaHGWHuXG",
+ "outputId": "d15272a5-1ce1-4c7e-9004-fc686a3de6b9"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "source": [
- "# attn_mask = rtl_mask(model.config.max_position_embeddings)\n",
- "for name, module in model.named_modules():\n",
- " if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):\n",
- " module._forward_hooks.clear() # running multiple times right now during testing\n",
- " module.register_forward_pre_hook(attn_hook, with_kwargs=True)\n",
- " module.register_forward_hook(debug_inputs_hook)\n",
- " # module.register_buffer(\"attn_mask\", attn_mask)\n",
- "\n",
- "model = model.to(device)"
- ],
- "metadata": {
- "id": "anEdwKj_OWWy"
- },
- "execution_count": null,
- "outputs": []
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+ "You will be able to reuse this secret in all of your notebooks.\n",
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+ " warnings.warn(\n",
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name_or_path = \"bert-base-uncased\"\n",
+ "model = transformers.AutoModelForMaskedLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=\"sdpa\")\n",
+ "model.eval()\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "H_AUjBRoJHXU"
+ },
+ "outputs": [],
+ "source": [
+ "def ltr_mask(seq_len: int) -> torch.Tensor:\n",
+ " mask = torch.ones((seq_len, seq_len), dtype=bool)\n",
+ " return torch.tril(mask, diagonal=-1)\n",
+ "\n",
+ "def rtl_mask(seq_len: int) -> torch.Tensor:\n",
+ " return ltr_mask(seq_len).T"
+ ]
+ },
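+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Quick sanity check on a toy sequence length: True marks an allowed key position.\n",
+ "# ltr_mask is strictly lower-triangular (each token sees only earlier tokens) and\n",
+ "# rtl_mask is its transpose (only later tokens); neither includes the diagonal, so\n",
+ "# a token never attends to itself.\n",
+ "print(ltr_mask(5).int())\n",
+ "print(rtl_mask(5).int())"
+ ]
+ },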
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Oy27MZcLLLsD"
+ },
+ "outputs": [],
+ "source": [
+ "model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n",
+ "\n",
+ "def attn_hook(attn_module: nn.Module, args: tuple, kwargs: dict):\n",
+ " \"\"\"\n",
+ " Assuming https://github.com/huggingface/transformers/blob/33868a057c02f0368ba63bd1edb746be38fe3d90/src/transformers/models/bert/modeling_bert.py#L515\n",
+ " so no `kwargs` and `attention_mask` is second positional arg.\n",
+ "\n",
+ " Uses global `model.attn_mask` to save memory.\n",
+ " \"\"\"\n",
+ " assert not kwargs\n",
+ "\n",
+ " args = list(args)\n",
+ " assert args[1].size()[-2:] == model.attn_mask.size(), f\"{args[1].size()=} {model.attn_mask.size()=}\"\n",
+ " args[1] = model.attn_mask\n",
+ " return tuple(args), kwargs\n",
+ "\n",
+ "def debug_inputs_hook(attn_module: nn.Module, args: tuple, output):\n",
+ " print(f\"Post-forward checks\")\n",
+ " assert torch.equal(args[1], model.attn_mask), (args[1], model.attn_mask)"
+ ]
+ },
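+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative sketch (toy tensors, not model data) of the boolean-mask convention\n",
+ "# the hook relies on: for scaled_dot_product_attention, True means the query may\n",
+ "# attend to that key.\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "q = k = torch.eye(4).unsqueeze(0) # (1, 4, 4): one head, four positions\n",
+ "v = torch.arange(4.0).reshape(1, 4, 1) # distinctive value per position\n",
+ "# Under rtl_mask(4), position 0 averages positions 1-3; the last row is all False\n",
+ "# (no visible keys), which SDPA surfaces as NaN or zeros depending on the backend.\n",
+ "print(F.scaled_dot_product_attention(q, k, v, attn_mask=rtl_mask(4)))"
+ ]
+ },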
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "anEdwKj_OWWy"
+ },
+ "outputs": [],
+ "source": [
+ "# attn_mask = rtl_mask(model.config.max_position_embeddings)\n",
+ "for name, module in model.named_modules():\n",
+ " if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention):\n",
+ " module._forward_hooks.clear() # running multiple times right now during testing\n",
+ " module.register_forward_pre_hook(attn_hook, with_kwargs=True)\n",
+ " module.register_forward_hook(debug_inputs_hook)\n",
+ " # module.register_buffer(\"attn_mask\", attn_mask)\n",
+ "\n",
+ "model = model.to(device)"
+ ]
+ },
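+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sanity check: bert-base-uncased has 12 encoder layers, so 12 BertSelfAttention\n",
+ "# modules should now carry the hooks - matching the twelve \"Post-forward checks\"\n",
+ "# lines printed per forward pass below.\n",
+ "hooked = [name for name, module in model.named_modules()\n",
+ " if isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention)]\n",
+ "print(len(hooked), \"self-attention modules hooked\")"
+ ]
+ },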
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "P1BEQFsLIRfX"
+ },
+ "outputs": [],
+ "source": [
+ "ds = load_dataset(\"Salesforce/wikitext\", \"wikitext-103-v1\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "BHE26Mr2NXhH",
+ "outputId": "24569931-61d7-4752-8b08-4daef58f9798"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "source": [
- "ds = load_dataset(\"Salesforce/wikitext\", \"wikitext-103-v1\")"
- ],
- "metadata": {
- "id": "P1BEQFsLIRfX"
- },
- "execution_count": null,
- "outputs": []
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n"
+ ]
},
{
- "cell_type": "code",
- "source": [
- "train_ds = ds[\"train\"]\n",
- "inputs = tokenizer(train_ds[5][\"text\"], return_tensors=\"pt\", padding='max_length', truncation=True)\n",
- "inputs = {key: val.to(device) for key, val in inputs.items()}\n",
- "\n",
- "with torch.no_grad():\n",
- " outputs = model(**inputs)\n",
- "\n",
- "outputs.logits"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "BHE26Mr2NXhH",
- "outputId": "24569931-61d7-4752-8b08-4daef58f9798"
- },
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
- " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
- " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
- " ...,\n",
- " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
- " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
- " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
- " device='cuda:0', dtype=torch.bfloat16)"
- ]
- },
- "metadata": {},
- "execution_count": 9
- }
+ "data": {
+ "text/plain": [
+ "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
+ " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
+ " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
+ " ...,\n",
+ " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
+ " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
+ " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
+ " device='cuda:0', dtype=torch.bfloat16)"
]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_ds = ds[\"train\"]\n",
+ "inputs = tokenizer(train_ds[5][\"text\"], return_tensors=\"pt\", padding='max_length', truncation=True)\n",
+ "inputs = {key: val.to(device) for key, val in inputs.items()}\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "ZtEm7eQQNi4e",
+ "outputId": "c0eb3925-6d48-480e-a853-5057f35dbcd2"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "source": [
- "with torch.inference_mode():\n",
- " model.register_buffer(\"attn_mask\", ltr_mask(model.config.max_position_embeddings).to(model.device))\n",
- " outputs = model(**inputs)\n",
- "\n",
- "outputs.logits"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "ZtEm7eQQNi4e",
- "outputId": "c0eb3925-6d48-480e-a853-5057f35dbcd2"
- },
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "tensor([[[-7.9062, -7.7812, -7.9062, ..., -7.1250, -7.8438, -4.8438],\n",
- " [-7.1562, -7.1250, -7.2812, ..., -7.3750, -7.3750, -7.2500],\n",
- " [-5.4062, -5.2188, -5.4375, ..., -5.3438, -4.3750, -5.0312],\n",
- " ...,\n",
- " [ 3.9844, 3.6406, 3.6406, ..., 3.8281, 2.9062, 5.2812],\n",
- " [ 4.0938, 3.7812, 3.8281, ..., 4.0000, 2.9844, 5.5000],\n",
- " [ 3.8281, 3.5312, 3.5156, ..., 4.1562, 2.8438, 4.7188]]],\n",
- " device='cuda:0', dtype=torch.bfloat16)"
- ]
- },
- "metadata": {},
- "execution_count": 10
- }
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n"
+ ]
},
{
- "cell_type": "code",
- "source": [
- "with torch.inference_mode():\n",
- " model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n",
- " outputs = model(**inputs)\n",
- "\n",
- "outputs.logits"
- ],
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "nz0j7V3oNkZu",
- "outputId": "939b1d6d-5dca-41ef-eb17-9e0f4d09629e"
- },
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n",
- "Post-forward checks\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
- " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
- " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
- " ...,\n",
- " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
- " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
- " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
- " device='cuda:0', dtype=torch.bfloat16)"
- ]
- },
- "metadata": {},
- "execution_count": 11
- }
+ "data": {
+ "text/plain": [
+ "tensor([[[-7.9062, -7.7812, -7.9062, ..., -7.1250, -7.8438, -4.8438],\n",
+ " [-7.1562, -7.1250, -7.2812, ..., -7.3750, -7.3750, -7.2500],\n",
+ " [-5.4062, -5.2188, -5.4375, ..., -5.3438, -4.3750, -5.0312],\n",
+ " ...,\n",
+ " [ 3.9844, 3.6406, 3.6406, ..., 3.8281, 2.9062, 5.2812],\n",
+ " [ 4.0938, 3.7812, 3.8281, ..., 4.0000, 2.9844, 5.5000],\n",
+ " [ 3.8281, 3.5312, 3.5156, ..., 4.1562, 2.8438, 4.7188]]],\n",
+ " device='cuda:0', dtype=torch.bfloat16)"
]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with torch.inference_mode():\n",
+ " model.register_buffer(\"attn_mask\", ltr_mask(model.config.max_position_embeddings).to(model.device))\n",
+ " outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "nz0j7V3oNkZu",
+ "outputId": "939b1d6d-5dca-41ef-eb17-9e0f4d09629e"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "82PpSWnrdMgu"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "# Wikipedia test\n",
- "from datasets import load_dataset\n",
- "\n",
- "ds = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
- "print(ds[\"train\"][1000])"
- ],
- "metadata": {
- "id": "DHftDnPKdMjV"
- },
- "execution_count": null,
- "outputs": []
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n",
+ "Post-forward checks\n"
+ ]
},
{
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "FTMkfLyKdMqu"
- },
- "execution_count": null,
- "outputs": []
+ "data": {
+ "text/plain": [
+ "tensor([[[-5.6250, -5.5938, -5.5938, ..., -5.4688, -4.9688, -2.4844],\n",
+ " [-9.2500, -8.9375, -9.3750, ..., -8.5000, -7.5000, -4.0312],\n",
+ " [-4.9062, -4.8750, -5.2812, ..., -5.0625, -4.4375, -1.8281],\n",
+ " ...,\n",
+ " [-5.5938, -5.7500, -5.7812, ..., -6.1562, -3.9688, -2.2812],\n",
+ " [-4.7188, -4.8750, -4.8750, ..., -5.0625, -3.4531, -2.4375],\n",
+ " [-4.1875, -3.9375, -3.9062, ..., -3.3438, -3.2344, -3.2031]]],\n",
+ " device='cuda:0', dtype=torch.bfloat16)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
}
- ]
-}
\ No newline at end of file
+ ],
+ "source": [
+ "with torch.inference_mode():\n",
+ " model.register_buffer(\"attn_mask\", rtl_mask(model.config.max_position_embeddings).to(model.device))\n",
+ " outputs = model(**inputs)\n",
+ "\n",
+ "outputs.logits"
+ ]
+ },
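+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative re-check: the hooks read `model.attn_mask` at call time, so with the\n",
+ "# RTL buffer still registered a second pass should reproduce the logits above.\n",
+ "with torch.inference_mode():\n",
+ " rerun = model(**inputs)\n",
+ "print(torch.equal(rerun.logits, outputs.logits))"
+ ]
+ },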
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "82PpSWnrdMgu"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DHftDnPKdMjV"
+ },
+ "outputs": [],
+ "source": [
+ "# Wikipedia test\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "ds = load_dataset(\"wikimedia/wikipedia\", \"20231101.en\")\n",
+ "print(ds[\"train\"][1000])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FTMkfLyKdMqu"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}