# # Inference
#
# Load a fine-tuned left-to-right DistilBERT masked-LM checkpoint and prepare it for
# token-by-token decoding experiments.

import os
import sys

# The notebook lives one directory below the repo root; expose the repo-level `utils` module.
sys.path.append("..")

import torch
import transformers

from utils import add_attn_hooks

# Root directory holding all training checkpoints. Override with the CHECKPOINT_ROOT
# environment variable when running on a different machine.
CHECKPOINT_ROOT = os.environ.get("CHECKPOINT_ROOT", "/home/sipb/nlp-class-project/checkpoints")

TOKENIZER_NAME = "distilbert/distilbert-base-uncased"

# Checkpoints that have been tried in this notebook, relative to CHECKPOINT_ROOT.
CHECKPOINTS = {
    "distilbert_ltr": "distilbert_base_ltr/epoch_3_checkpt",
    "qa_overfit": "qa_distilbert_base_ltr_overfit/epoch_999_checkpt",
    "qa": "qa_distilbert_base_ltr/epoch_49_checkpt",
    "qa_v2": "qa_distilbert_base_ltr_v2/best_checkpt",
    # BERT-base variant; needs TOKENIZER_NAME = "bert-base-uncased" instead.
    "bert_ltr": "bert_base_ltr/epoch_3_checkpt",
}
ACTIVE_CHECKPOINT = "qa_v2"

tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Loading emits a warning that the checkpoint's extra `attention_mask` entry is not used
# by DistilBertForMaskedLM; the load still succeeds.
model = transformers.AutoModelForMaskedLM.from_pretrained(
    f"{CHECKPOINT_ROOT}/{CHECKPOINTS[ACTIVE_CHECKPOINT]}",
    ignore_mismatched_sizes=True,
)

# NOTE(review): presumably installs left-to-right (causal) attention masking hooks so
# inference matches how the checkpoint was trained -- confirm against utils.add_attn_hooks.
add_attn_hooks(model, "ltr")

# Inference only: switch off dropout. Assigning the result suppresses the module repr.
_ = model.eval()
from typing import override


class DecoderMLM(transformers.DistilBertForMaskedLM, transformers.GenerationMixin):
    """DistilBERT masked-LM that reports itself as usable with `.generate()`.

    The base class must be the concrete `DistilBertForMaskedLM`, not
    `AutoModelForMaskedLM`: the Auto class is a factory whose `from_pretrained`
    returns an instance of the concrete architecture class (the load warning names
    `DistilBertForMaskedLM`), so overrides declared on an Auto subclass never reach
    the loaded model.

    NOTE(review): this only gets past the `can_generate()` validation inside
    `generate()`. Whether generation then runs end to end with this architecture
    has not been verified here.
    """

    @classmethod
    @override
    def can_generate(cls) -> bool:
        """Tell `GenerationMixin` that this class may be used with `.generate()`."""
        return True


model2 = DecoderMLM.from_pretrained(
    "/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt",
    ignore_mismatched_sizes=True,
)

# The override lives on the class, so no per-instance monkeypatching is needed.
assert isinstance(model2, DecoderMLM)
model2.can_generate()
def try_generate(prompt: str, **generate_kwargs) -> list[str] | None:
    """Attempt Hugging Face `.generate()` on the notebook-level `model` for `prompt`.

    The loaded model is a `DistilBertForMaskedLM`, and `generate()` rejects it with a
    `TypeError` during model-class validation. That failure is the point of this
    demonstration (it motivates the manual decoding loops later in the notebook), so
    it is caught and reported instead of leaving a crashing cell.

    Args:
        prompt: Text to continue.
        **generate_kwargs: Forwarded verbatim to `model.generate`.

    Returns:
        The decoded sequences, or `None` when the model class cannot generate.
    """
    # Drop the last token id -- presumably the trailing [SEP] the tokenizer appends --
    # so the model continues the text instead of seeing a closed sequence.
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][:, :-1]
    try:
        with torch.inference_mode():
            generated = model.generate(prompt_ids, **generate_kwargs)
    except TypeError as err:
        print(f"generate() is unavailable for {type(model).__name__}: {err}")
        return None
    return tokenizer.batch_decode(generated)


# (prompt, generate kwargs) pairs that were tried.
GENERATE_ATTEMPTS = [
    (
        "Answer: Grapes are toxic to foxes in large quantities",
        {"num_beams": 5, "do_sample": True, "temperature": 1.0},
    ),
    ("Apples are red and ", {"num_beams": 5, "do_sample": True}),
]

for attempt_prompt, attempt_kwargs in GENERATE_ATTEMPTS:
    decoded = try_generate(attempt_prompt, **attempt_kwargs)
    if decoded is not None:
        print(decoded)

# Sanity check of the training format: decoding a tokenized training example gives
# "answer : ... [SEP] [CLS] question :", i.e. the answer segment comes first and the
# model is expected to continue with the question.
tokenizer.decode([3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,
                  4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,
                  1024])
import torch
import torch.nn.functional as F

SEED = 0  # makes the temperature-sampling run below repeatable


def generate_tokens(
    model,
    tokenizer,
    prompt: str,
    *,
    max_new_tokens: int = 30,
    temperature: float | None = None,
    add_special_tokens: bool = False,
    stop_token_id: int | None = None,
    verbose: bool = True,
) -> torch.Tensor:
    """Extend `prompt` one token at a time using the model's last-position logits.

    Args:
        model: Callable taking an `input_ids` tensor and returning an object with
            `.logits`; the code indexes it as `[0, -1, :]`.
        tokenizer: Tokenizer used both to encode `prompt` and to decode each new token.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: `None` selects greedy decoding (argmax). Otherwise logits are
            divided by this value and the next token is sampled from the softmax.
        add_special_tokens: Forwarded to the tokenizer. Keep `False` when the prompt
            already spells out its own [SEP]/[CLS] markers.
        stop_token_id: Stop right after this token id is produced. `None` always runs
            for `max_new_tokens` steps.
        verbose: Print each generated token as it is produced.

    Returns:
        The prompt ids followed by the generated ids, shape (1, prompt_len + n_generated).
    """
    input_ids = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=add_special_tokens
    ).input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Only the distribution after the last position matters for the next token.
            logits = model(input_ids).logits[0, -1, :]

            if temperature is None:
                next_token_id = torch.argmax(logits).item()
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1).item()

            next_token = torch.tensor([[next_token_id]], device=input_ids.device)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            if verbose:
                print(f"Generated token: {tokenizer.decode(next_token_id)}")

            # Without a stop condition the model keeps emitting
            # "[SEP] [CLS] question : ..." and repeats the same question.
            if stop_token_id is not None and next_token_id == stop_token_id:
                break

    return input_ids


# Greedy decoding with prompts written in the training format
# ("answer : ... [SEP] [CLS] question :"), tokenized without extra special tokens.
# Other prompts that were tried:
#   "answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :"
#   "answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :"
QA_PROMPTS = [
    "answer : your chicken cries. [SEP] [CLS] question :",
    "answer : you are late to work. [SEP] [CLS] question :",
]

for qa_prompt in QA_PROMPTS:
    qa_ids = generate_tokens(model, tokenizer, qa_prompt, stop_token_id=tokenizer.sep_token_id)
    print(tokenizer.decode(qa_ids[0]))
# but we did train on incorrect answers

# `model.tokenizer` does not exist (DistilBertForMaskedLM has no such attribute);
# the tokenizer is the separate notebook-level `tokenizer` object.

# Temperature sampling.
# NOTE(review): this prompt is not in the training format and is tokenized WITH special
# tokens, so it ends in [SEP] -- which is why the first sampled token tends to be [CLS].
torch.manual_seed(SEED)
sampled_ids = generate_tokens(
    model,
    tokenizer,
    "Answer: Vampires are real. Question:",
    temperature=0.7,  # below 1.0 gives less randomness
    add_special_tokens=True,
    stop_token_id=tokenizer.sep_token_id,
)
print(tokenizer.decode(sampled_ids[0]))
import torch
import torch.nn.functional as F


def beam_search(
    model,
    tokenizer,
    prompt: str,
    *,
    beam_width: int = 3,
    max_new_tokens: int = 30,
    temperature: float = 1.0,
) -> list[tuple[torch.Tensor, float]]:
    """Beam-search continuation of `prompt` driven by the model's last-position logits.

    Args:
        model: Callable accepting `input_ids=` and returning an object with `.logits`;
            the code indexes it as `[0, -1, :]`.
        tokenizer: Tokenizer used to encode `prompt`; also supplies the stop token id.
        prompt: Text to continue.
        beam_width: Number of live beams kept per step, and expansions per beam.
        max_new_tokens: Maximum number of decoding steps.
        temperature: Logits are divided by this value before the log-softmax.

    Returns:
        `(sequence, cumulative_log_prob)` pairs sorted best-first. Finished beams
        (those that produced the stop token) are returned when any exist; otherwise
        the still-live beams are returned so the caller never gets an empty result.
    """
    # BERT-style tokenizers normally define no EOS token (`eos_token_id` is None), so
    # comparing against it would never finish a beam. Fall back to [SEP].
    stop_token_id = tokenizer.eos_token_id
    if stop_token_id is None:
        stop_token_id = tokenizer.sep_token_id

    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    beams: list[tuple[torch.Tensor, float]] = [(prompt_ids, 0.0)]
    finished: list[tuple[torch.Tensor, float]] = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            candidates: list[tuple[torch.Tensor, float]] = []

            for seq, score in beams:
                # 1-D tensor over the vocabulary for the position after the last token.
                logits = model(input_ids=seq).logits[0, -1, :]

                # Work in log space directly; avoids softmax followed by log.
                log_probs = F.log_softmax(logits / temperature, dim=-1)
                top_log_probs, top_ids = torch.topk(log_probs, beam_width, dim=-1)

                # `top_ids` is 1-D, so it is iterated with a single index
                # (indexing it as [0, i] raises IndexError).
                for token_log_prob, token_id in zip(top_log_probs.tolist(), top_ids.tolist()):
                    new_seq = torch.cat(
                        [seq, torch.tensor([[token_id]], device=seq.device)], dim=1
                    )
                    entry = (new_seq, score + token_log_prob)
                    if token_id == stop_token_id:
                        finished.append(entry)
                    else:
                        candidates.append(entry)

            # Keep only the best `beam_width` unfinished hypotheses.
            beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_width]

            # Stop when nothing is left to expand or enough beams have finished.
            if not beams or len(finished) >= beam_width:
                break

    results = finished if finished else beams
    return sorted(results, key=lambda beam: beam[1], reverse=True)


# NOTE(review): the prompt is tokenized with the default special tokens, so it ends in
# [SEP], and it is not written in the "answer : ... [SEP] [CLS] question :" training format.
for seq, score in beam_search(
    model,
    tokenizer,
    "Question: Are Vampires real. Answer:",
    beam_width=3,
    max_new_tokens=30,
    temperature=1.0,
):
    generated_text = tokenizer.decode(seq[0], skip_special_tokens=True)
    print(f"Generated text: {generated_text}")
    print(f"Score: {score}")

# Fine-tuning command used for the QA checkpoints. This is a shell command: run it from
# a terminal (or a `%%bash` cell), not as Python.
#
#   WANDB_MODE=offline CUDA_VISIBLE_DEVICES= accelerate launch --mixed_precision bf16 finetune_QA.py \
#     --model_direction ltr \
#     --checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \
#     --tokenizer_name distilbert/distilbert-base-uncased \
#     --warmup_steps 100 \
#     --learning_rate 1e-5 \
#     --per_device_train_batch_size 128 \
#     --per_device_eval_batch_size 128 \
#     --output_dir checkpoints/qa_distilbert_base_ltr/ \
#     --eval_steps 38 \
#     --block_size 128 \
#     --num_train_epochs 50 \
#     --weight_decay 1e-4
#
# yes but we overfit to this example
#
# oh wait we did actually

# Notes:
# wait so now it does really good because we're actually doing it the same way that we trained it?
#
# thats fine, let's train it on everything again!
#
# i don't think that's true...
but low perplexity!\n", "\n", "\n", "\n", "the question was supposed to be, \n", "what happens if you swallow a watermelon seed?" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }