Diffstat (limited to 'notebooks/Inference.ipynb')
-rw-r--r-- | notebooks/Inference.ipynb | 638 |
1 file changed, 638 insertions, 0 deletions
diff --git a/notebooks/Inference.ipynb b/notebooks/Inference.ipynb new file mode 100644 index 0000000..670a127 --- /dev/null +++ b/notebooks/Inference.ipynb @@ -0,0 +1,638 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "34c536f2-3ccb-4df0-bd47-913d6ef040a2", + "metadata": {}, + "source": [ + "# Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "execution_state": "idle", + "id": "3c6381c4-2a02-415f-a5f1-450fe42b30d3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n", + "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append(\"..\")\n", + "\n", + "import torch\n", + "import transformers\n", + "\n", + "from utils import add_attn_hooks\n", + "\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n", + "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n", + "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt\", ignore_mismatched_sizes=True)\n", + "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr/epoch_49_checkpt\", ignore_mismatched_sizes=True)\n", + "model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt\", ignore_mismatched_sizes=True)\n", + "\n", + "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", + "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/bert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "execution_state": "idle", + "id": "41edf867-7f6a-4d44-871c-8af0b7af7543", + "metadata": {}, + "outputs": [], + "source": [ + "from utils import add_attn_hooks\n", + "add_attn_hooks(model, \"ltr\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "execution_state": "idle", + "id": "d79a3d16-b982-42ed-9b6a-fad8328e177e", + "metadata": {}, + "outputs": [], + "source": [ + "model.eval();" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "execution_state": "idle", + "id": "f06d4bc2-9df8-42c2-9397-3bfb3728da0b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n", + "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on 
another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "from typing import override\n", + "class DecoderMLM(transformers.AutoModelForMaskedLM, transformers.GenerationMixin):\n", + " @override\n", + " # @classmethod\n", + " def can_generate(cls):\n", + " return True\n", + "\n", + "model2 = DecoderMLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt\", ignore_mismatched_sizes=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "execution_state": "idle", + "id": "6feb4fdd-ae43-466d-8dce-a4f9a632a5e6", + "metadata": {}, + "outputs": [], + "source": [ + "# model2.can_generate = (lambda s: True)\n", + "model2.can_generate = (lambda: True)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "execution_state": "idle", + "id": "c178761c-7124-42ed-9bfc-7ab0f782aad7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model2." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "execution_state": "idle", + "id": "3ab53852-f333-47ea-9e96-55266cda84a6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model2.can_generate()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "execution_state": "idle", + "id": "a30f6240-f982-45b0-b75f-3be5bbb43049", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 25, + "execution_state": "idle", + "id": "2cefa784-a1cc-445a-8ffa-066e7cfccaf0", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "The current model class (DistilBertForMaskedLM) is not compatible with `.generate()`, as it doesn't have a language model head. 
Classes that support generation often end in one of these names: ['ForCausalLM', 'ForConditionalGeneration', 'ForSpeechSeq2Seq', 'ForVision2Seq'].", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[25], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m tokenized_question \u001b[38;5;241m=\u001b[39m tokenizer(question, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39minference_mode():\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokenized_question\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_beams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdo_sample\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.\u001b[39;49m\u001b[43m)\u001b[49m))\n", + "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/transformers/generation/utils.py:1967\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1882\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1883\u001b[0m \n\u001b[1;32m 1884\u001b[0m \u001b[38;5;124;03mGenerates sequences of token ids for models with a language modeling head.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1963\u001b[0m \u001b[38;5;124;03m - [`~generation.GenerateBeamEncoderDecoderOutput`]\u001b[39;00m\n\u001b[1;32m 1964\u001b[0m 
\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1966\u001b[0m \u001b[38;5;66;03m# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call\u001b[39;00m\n\u001b[0;32m-> 1967\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_model_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1968\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1969\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n", + "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/transformers/generation/utils.py:1269\u001b[0m, in \u001b[0;36mGenerationMixin._validate_model_class\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1262\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcan_generate():\n\u001b[1;32m 1263\u001b[0m terminations_with_generation_support \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 1264\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForCausalLM\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForConditionalGeneration\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForSpeechSeq2Seq\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForVision2Seq\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1268\u001b[0m ]\n\u001b[0;32m-> 1269\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 1270\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe current model class (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) is not compatible with `.generate()`, as \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mit doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt have a language model head. Classes that support generation often end in one of these \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mterminations_with_generation_support\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1273\u001b[0m )\n", + "\u001b[0;31mTypeError\u001b[0m: The current model class (DistilBertForMaskedLM) is not compatible with `.generate()`, as it doesn't have a language model head. 
Classes that support generation often end in one of these names: ['ForCausalLM', 'ForConditionalGeneration', 'ForSpeechSeq2Seq', 'ForVision2Seq']." + ] + } + ], + "source": [ + "question = \"Answer: Grapes are toxic to foxes in large quantities\"\n", + "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n", + "\n", + "with torch.inference_mode():\n", + " print(tokenizer.batch_decode(model.generate(tokenized_question[\"input_ids\"][:, :-1], num_beams=5, do_sample=True, temperature=1.)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c475851e-7c86-46fd-b4ad-cd51caa6e7b8", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"Apples are red and \"\n", + "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n", + "\n", + "with torch.inference_mode():\n", + " print(tokenizer.batch_decode(model.generate(tokenized_question[\"input_ids\"][:, :-1], num_beams=5, do_sample=True)))" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "execution_state": "idle", + "id": "1365ab75-a022-42fe-9168-d49a645af0d5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :'" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode([3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n", + " 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n", + " 1024])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "execution_state": "idle", + "id": "58bffbb3-77fb-4f57-a77e-303fca05a84f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated token: what\n", + "Generated token: is\n", + "Generated token: the\n", + "Generated token: chicken\n", + "Generated token: cry\n", + "Generated token: ?\n", + "Generated token: [SEP]\n", + "Generated token: [CLS]\n", + "Generated token: question\n", + "Generated token: :\n", + "Generated token: what\n", + "Generated token: is\n", + "Generated token: the\n", + "Generated token: chicken\n", + "Generated token: cry\n", + "Generated token: ?\n", + "Generated token: [SEP]\n", + "Generated token: [CLS]\n", + "Generated token: question\n", + "Generated token: :\n", + "Generated token: what\n", + "Generated token: is\n", + "Generated token: the\n", + "Generated token: chicken\n", + "Generated token: cry\n", + "Generated token: ?\n", + "Generated token: [SEP]\n", + "Generated token: [CLS]\n", + "Generated token: what\n", + "Generated token: is\n" + ] + } + ], + "source": [ + "# question = \"answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :\"\n", + "question = \"answer : your chicken cries. 
[SEP] [CLS] question :\"\n", + "# question = \"answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :\"\n", + "\n", + "# input_ids = tokenizer(question, return_tensors=\"pt\").input_ids[:, :-1]\n", + "input_ids = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids\n", + "#tokenized_question = {\n", + "# \"input_ids\": torch.tensor([[3437, 1024, 2498, 6433, 102, 101, 3160, 1024]], dtype=torch.long)\n", + "#}\n", + "# input_ids = torch.tensor([[3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n", + "# 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n", + "# 1024]])\n", + "\n", + "with torch.no_grad():\n", + " for i in range(30):\n", + " # Get logits for the last token in the sequence\n", + " logits = model(input_ids).logits[0, -1, :]\n", + " \n", + " # Select the token with the highest probability\n", + " next_token_id = torch.argmax(logits).item()\n", + " \n", + " # Update the input_ids with the new token\n", + " input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)\n", + " \n", + " # Decode the next token for readability\n", + " next_token = tokenizer.decode(next_token_id)\n", + " \n", + " print(f\"Generated token: {next_token}\")\n", + "# but we did train on incorrect answers" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "execution_state": "idle", + "id": "972352ea-4c28-42d9-a834-26daa83b2290", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated token: how\n", + "Generated token: long\n", + "Generated token: should\n", + "Generated token: you\n", + "Generated token: wait\n", + "Generated token: before\n", + "Generated token: filing\n", + "Generated token: a\n", + "Generated token: missing\n", + "Generated token: person\n", + "Generated token: report\n", + "Generated token: ?\n", + "Generated token: [SEP]\n", + "Generated token: [CLS]\n", + "Generated token: question\n", + "Generated token: :\n", + "Generated token: how\n", + "Generated token: long\n", + "Generated token: should\n", + "Generated token: you\n", + "Generated token: wait\n", + "Generated token: before\n", + "Generated token: filing\n", + "Generated token: a\n", + "Generated token: missing\n", + "Generated token: person\n", + "Generated token: report\n", + "Generated token: ?\n", + "Generated token: [SEP]\n", + "Generated token: [CLS]\n" + ] + } + ], + "source": [ + "# question = \"answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :\"\n", + "question = \"answer : you are late to work. 
[SEP] [CLS] question :\"\n", + "# question = \"answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :\"\n", + "\n", + "# input_ids = tokenizer(question, return_tensors=\"pt\").input_ids[:, :-1]\n", + "input_ids = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids\n", + "#tokenized_question = {\n", + "# \"input_ids\": torch.tensor([[3437, 1024, 2498, 6433, 102, 101, 3160, 1024]], dtype=torch.long)\n", + "#}\n", + "# input_ids = torch.tensor([[3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n", + "# 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n", + "# 1024]])\n", + "\n", + "with torch.no_grad():\n", + " for i in range(30):\n", + " # Get logits for the last token in the sequence\n", + " logits = model(input_ids).logits[0, -1, :]\n", + " \n", + " # Select the token with the highest probability\n", + " next_token_id = torch.argmax(logits).item()\n", + " \n", + " # Update the input_ids with the new token\n", + " input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)\n", + " \n", + " # Decode the next token for readability\n", + " next_token = tokenizer.decode(next_token_id)\n", + " \n", + " print(f\"Generated token: {next_token}\")\n", + "# but we did train on incorrect answers" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "execution_state": "idle", + "id": "6b1949d3-343d-49ba-b50e-79dc36d9124e", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'DistilBertForMaskedLM' object has no attribute 'tokenizer'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer\u001b[49m\n", + "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1931\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1929\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1930\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1931\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[1;32m 1932\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1933\u001b[0m )\n", + "\u001b[0;31mAttributeError\u001b[0m: 'DistilBertForMaskedLM' object has no attribute 'tokenizer'" + ] + } + ], + "source": [ + "model.tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "execution_state": "idle", + "id": "f57b41c0-2056-4164-914b-f298ad66c0c5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated token: [CLS]\n", + "Generated token: the\n", + "Generated token: vampires\n", + "Generated token: are\n", + "Generated token: the\n", + 
"Generated token: vampires\n", + "Generated token: ,\n", + "Generated token: vampires\n", + "Generated token: are\n", + "Generated token: living\n", + "Generated token: ,\n", + "Generated token: who\n", + "Generated token: believe\n", + "Generated token: the\n", + "Generated token: vampires\n", + "Generated token: .\n", + "Generated token: vampire\n", + "Generated token: ,\n", + "Generated token: and\n", + "Generated token: who\n", + "Generated token: are\n", + "Generated token: vampires\n", + "Generated token: ,\n", + "Generated token: who\n", + "Generated token: are\n", + "Generated token: also\n", + "Generated token: .\n", + "Generated token: vampires\n", + "Generated token: who\n", + "Generated token: do\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "question = \"Answer: Vampires are real. Question:\"\n", + "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n", + "\n", + "temperature = 0.7 # Set your temperature here (e.g., 0.7 for less randomness)\n", + "\n", + "with torch.no_grad():\n", + " for i in range(30):\n", + " # Get logits for the last token in the sequence\n", + " logits = model(tokenized_question[\"input_ids\"]).logits[0, -1, :]\n", + " \n", + " # Apply temperature scaling\n", + " logits = logits / temperature\n", + " \n", + " # Convert logits to probabilities using softmax\n", + " probs = F.softmax(logits, dim=-1)\n", + " \n", + " # Sample from the distribution\n", + " next_token_id = torch.multinomial(probs, num_samples=1).item()\n", + " \n", + " # Update the input_ids with the new token\n", + " tokenized_question['input_ids'] = torch.cat([tokenized_question['input_ids'], torch.tensor([[next_token_id]])], dim=1)\n", + " \n", + " # Decode the next token for readability\n", + " next_token = tokenizer.decode(next_token_id)\n", + " \n", + " print(f\"Generated token: {next_token}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "execution_state": "idle", + "id": "e8bffe5d-d830-4992-9381-b484672ffeda", + "metadata": {}, + "outputs": [ + { + "ename": "IndexError", + "evalue": "too many indices for tensor of dimension 1", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[129], line 35\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Create candidates\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(beam_width):\n\u001b[0;32m---> 35\u001b[0m token_id \u001b[38;5;241m=\u001b[39m \u001b[43mtop_k_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 36\u001b[0m token_prob \u001b[38;5;241m=\u001b[39m top_k_probs[\u001b[38;5;241m0\u001b[39m, i]\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# Create a new sequence by appending the token to the existing sequence\u001b[39;00m\n", + "\u001b[0;31mIndexError\u001b[0m: too many indices for tensor of dimension 1" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "question = \"Question: Are Vampires real. 
Answer:\"\n", + "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n", + "\n", + "# Parameters\n", + "beam_width = 3 # The number of beams to consider\n", + "max_length = 30 # Maximum number of tokens to generate\n", + "temperature = 1.0 # Temperature for softmax\n", + "\n", + "# Initialize beams\n", + "beams = [(tokenized_question['input_ids'], 0.0)] # Each beam is a tuple (sequence, score)\n", + "finished_beams = []\n", + "\n", + "with torch.no_grad():\n", + " for step in range(max_length):\n", + " all_candidates = []\n", + " \n", + " for seq, score in beams:\n", + " # Get logits for the last token in the sequence\n", + " logits = model(input_ids=seq).logits[0, -1, :]\n", + " \n", + " # Apply temperature scaling\n", + " logits = logits / temperature\n", + " \n", + " # Convert logits to probabilities using softmax\n", + " probs = F.softmax(logits, dim=-1)\n", + " \n", + " # Get top-k candidate tokens and their probabilities\n", + " top_k_probs, top_k_ids = torch.topk(probs, beam_width, dim=-1)\n", + " \n", + " # Create candidates\n", + " for i in range(beam_width):\n", + " token_id = top_k_ids[0, i].item()\n", + " token_prob = top_k_probs[0, i].item()\n", + " \n", + " # Create a new sequence by appending the token to the existing sequence\n", + " new_seq = torch.cat([seq, torch.tensor([[token_id]])], dim=1)\n", + " \n", + " # Update the score (cumulative log probability)\n", + " new_score = score + torch.log(torch.tensor(token_prob))\n", + " \n", + " # If the token is the end-of-sequence token, consider it a finished beam\n", + " if token_id == tokenizer.eos_token_id:\n", + " finished_beams.append((new_seq, new_score))\n", + " else:\n", + " all_candidates.append((new_seq, new_score))\n", + " \n", + " # Sort candidates by score (highest first) and select top-k\n", + " beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]\n", + " \n", + " # If no beams are left to explore, break\n", + " if not beams:\n", + " break\n", + " \n", + " # If we have finished beams, we can stop early\n", + " if len(finished_beams) >= beam_width:\n", + " break\n", + "\n", + "# Decode and print the top finished sequences\n", + "for seq, score in sorted(finished_beams, key=lambda x: x[1], reverse=True):\n", + " generated_text = tokenizer.decode(seq[0], skip_special_tokens=True)\n", + " print(f\"Generated text: {generated_text}\")\n", + " print(f\"Score: {score.item()}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "218621c4-2b18-4e6e-8fba-adde97038ae8", + "metadata": {}, + "outputs": [], + "source": [ + "WANDB_MODE=offline CUDA_VISIBLE_DEVICES= accelerate launch --mixed_precision bf16 finetune_QA.py \\\n", + "--model_direction ltr \\\n", + "--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \\\n", + "--tokenizer_name distilbert/distilbert-base-uncased \\\n", + "--warmup_steps 100 \\\n", + "--learning_rate 1e-5 \\\n", + "--per_device_train_batch_size 128 \\\n", + "--per_device_eval_batch_size 128 \\\n", + "--output_dir checkpoints/qa_distilbert_base_ltr/ \\\n", + "--eval_steps 38 \\\n", + "--block_size 128 \\\n", + "--num_train_epochs 50 \\\n", + "--weight_decay 1e-4\n", + "\n", + "yes but we overfit to this example\n", + "\n", + "oh wait we did actually" + ] + }, + { + "cell_type": "markdown", + "id": "f5e37c9a-156a-4fcc-bd05-5ae6c3671e62", + "metadata": {}, + "source": [ + "wait so now it does really good because we're actually doing it the same way that we trained it?\n", + "\n", + "\n", + "\n", + "thats 
fine, let's train it on everything again!\n", + "\n", + "\n", + "I don't think that's true... but low perplexity!\n", + "\n", + "\n", + "\n", + "the question was supposed to be, \n", + "what happens if you swallow a watermelon seed?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
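
Note on the manual decoding loops above: DistilBertForMaskedLM is not a generation-capable class (hence the TypeError from `.generate()` in the traceback), so the simplest workaround is a hand-written loop like the ones in the later cells. The following is a minimal sketch, not part of the diff, that factors that greedy loop into one helper. It assumes the qa_distilbert_base_ltr_v2 checkpoint path and the `utils.add_attn_hooks` left-to-right patch behave exactly as in the notebook; the optional stop at the first generated [SEP] is an added convenience, not something the notebook does.

import sys
sys.path.append("..")

import torch
import transformers

from utils import add_attn_hooks  # repo helper; assumed to patch the model for left-to-right attention

CHECKPOINT = "/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt"

tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
model = transformers.AutoModelForMaskedLM.from_pretrained(CHECKPOINT, ignore_mismatched_sizes=True)
add_attn_hooks(model, "ltr")
model.eval()

@torch.inference_mode()
def greedy_generate(prompt: str, max_new_tokens: int = 30, stop_at_sep: bool = False) -> str:
    # The prompts already spell out [SEP]/[CLS], so skip the tokenizer's own special tokens.
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits[0, -1, :]      # logits for the next position
        next_token_id = int(torch.argmax(logits))       # greedy choice
        input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)
        if stop_at_sep and next_token_id == tokenizer.sep_token_id:  # optional early stop (an addition)
            break
    return tokenizer.decode(input_ids[0])

print(greedy_generate("answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :",
                      stop_at_sep=True))

As for the hand-rolled beam search that ends in `IndexError: too many indices for tensor of dimension 1`: `probs` is one-dimensional at that point, so `torch.topk` returns one-dimensional tensors, and the candidates should be read with `top_k_ids[i]` and `top_k_probs[i]` rather than `top_k_ids[0, i]` and `top_k_probs[0, i]`.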