Diffstat (limited to 'notebooks/Inference.ipynb')
-rw-r--r--  notebooks/Inference.ipynb  638
1 file changed, 638 insertions, 0 deletions
diff --git a/notebooks/Inference.ipynb b/notebooks/Inference.ipynb
new file mode 100644
index 0000000..670a127
--- /dev/null
+++ b/notebooks/Inference.ipynb
@@ -0,0 +1,638 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "34c536f2-3ccb-4df0-bd47-913d6ef040a2",
+ "metadata": {},
+ "source": [
+ "# Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "execution_state": "idle",
+ "id": "3c6381c4-2a02-415f-a5f1-450fe42b30d3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "\n",
+ "from utils import add_attn_hooks\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt\", ignore_mismatched_sizes=True)\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr/epoch_49_checkpt\", ignore_mismatched_sizes=True)\n",
+ "model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt\", ignore_mismatched_sizes=True)\n",
+ "\n",
+ "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/bert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "execution_state": "idle",
+ "id": "41edf867-7f6a-4d44-871c-8af0b7af7543",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from utils import add_attn_hooks\n",
+ "add_attn_hooks(model, \"ltr\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "execution_state": "idle",
+ "id": "d79a3d16-b982-42ed-9b6a-fad8328e177e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.eval();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "execution_state": "idle",
+ "id": "f06d4bc2-9df8-42c2-9397-3bfb3728da0b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "from typing import override\n",
+ "class DecoderMLM(transformers.AutoModelForMaskedLM, transformers.GenerationMixin):\n",
+ " @override\n",
+ " # @classmethod\n",
+ " def can_generate(cls):\n",
+ " return True\n",
+ "\n",
+ "model2 = DecoderMLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt\", ignore_mismatched_sizes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "execution_state": "idle",
+ "id": "6feb4fdd-ae43-466d-8dce-a4f9a632a5e6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# model2.can_generate = (lambda s: True)\n",
+ "model2.can_generate = (lambda: True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "execution_state": "idle",
+ "id": "c178761c-7124-42ed-9bfc-7ab0f782aad7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "execution_state": "idle",
+ "id": "3ab53852-f333-47ea-9e96-55266cda84a6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model2.can_generate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "execution_state": "idle",
+ "id": "2cefa784-a1cc-445a-8ffa-066e7cfccaf0",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "The current model class (DistilBertForMaskedLM) is not compatible with `.generate()`, as it doesn't have a language model head. Classes that support generation often end in one of these names: ['ForCausalLM', 'ForConditionalGeneration', 'ForSpeechSeq2Seq', 'ForVision2Seq'].",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[25], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m tokenized_question \u001b[38;5;241m=\u001b[39m tokenizer(question, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39minference_mode():\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokenized_question\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_beams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdo_sample\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.\u001b[39;49m\u001b[43m)\u001b[49m))\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/transformers/generation/utils.py:1967\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1882\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1883\u001b[0m \n\u001b[1;32m 1884\u001b[0m \u001b[38;5;124;03mGenerates sequences of token ids for models with a language modeling head.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1963\u001b[0m \u001b[38;5;124;03m - [`~generation.GenerateBeamEncoderDecoderOutput`]\u001b[39;00m\n\u001b[1;32m 1964\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1966\u001b[0m \u001b[38;5;66;03m# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call\u001b[39;00m\n\u001b[0;32m-> 1967\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_model_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1968\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1969\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/transformers/generation/utils.py:1269\u001b[0m, in \u001b[0;36mGenerationMixin._validate_model_class\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1262\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcan_generate():\n\u001b[1;32m 1263\u001b[0m terminations_with_generation_support \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 1264\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForCausalLM\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForConditionalGeneration\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForSpeechSeq2Seq\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForVision2Seq\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1268\u001b[0m ]\n\u001b[0;32m-> 1269\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 1270\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe current model class (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) is not compatible with `.generate()`, as \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mit doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt have a language model head. Classes that support generation often end in one of these \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mterminations_with_generation_support\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1273\u001b[0m )\n",
+ "\u001b[0;31mTypeError\u001b[0m: The current model class (DistilBertForMaskedLM) is not compatible with `.generate()`, as it doesn't have a language model head. Classes that support generation often end in one of these names: ['ForCausalLM', 'ForConditionalGeneration', 'ForSpeechSeq2Seq', 'ForVision2Seq']."
+ ]
+ }
+ ],
+ "source": [
+ "question = \"Answer: Grapes are toxic to foxes in large quantities\"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "with torch.inference_mode():\n",
+ " print(tokenizer.batch_decode(model.generate(tokenized_question[\"input_ids\"][:, :-1], num_beams=5, do_sample=True, temperature=1.)))"
+ ]
+ },
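+ {
+ "cell_type": "markdown",
+ "id": "generate-workaround-note",
+ "metadata": {},
+ "source": [
+ "`.generate()` refuses `*ForMaskedLM` classes, so the later cells fall back to a manual decoding loop. As a sketch (assuming the left-to-right attention hooks from `utils.add_attn_hooks` are already installed on `model`), the same greedy argmax loop can be wrapped in a small helper; `greedy_decode`, `max_new_tokens`, and `stop_id` are illustrative names, not part of the project code:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "generate-workaround-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch of a reusable greedy decoder around the hooked masked LM (same logic as the inline loops below)\n",
+ "def greedy_decode(model, tokenizer, prompt, max_new_tokens=30, stop_id=None):\n",
+ "    input_ids = tokenizer(prompt, return_tensors=\"pt\", add_special_tokens=False).input_ids\n",
+ "    with torch.no_grad():\n",
+ "        for _ in range(max_new_tokens):\n",
+ "            # Take the argmax of the next-token distribution at the last position\n",
+ "            next_id = model(input_ids).logits[0, -1, :].argmax().item()\n",
+ "            input_ids = torch.cat([input_ids, torch.tensor([[next_id]])], dim=1)\n",
+ "            if stop_id is not None and next_id == stop_id:\n",
+ "                break\n",
+ "    return tokenizer.decode(input_ids[0])\n",
+ "\n",
+ "# Example, using the same prompt format as the cells below:\n",
+ "# greedy_decode(model, tokenizer, \"answer : your chicken cries. [SEP] [CLS] question :\")"
+ ]
+ },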
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c475851e-7c86-46fd-b4ad-cd51caa6e7b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "question = \"Apples are red and \"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "with torch.inference_mode():\n",
+ " print(tokenizer.batch_decode(model.generate(tokenized_question[\"input_ids\"][:, :-1], num_beams=5, do_sample=True)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "execution_state": "idle",
+ "id": "1365ab75-a022-42fe-9168-d49a645af0d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :'"
+ ]
+ },
+ "execution_count": 97,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n",
+ " 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n",
+ " 1024])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "execution_state": "idle",
+ "id": "58bffbb3-77fb-4f57-a77e-303fca05a84f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated token: what\n",
+ "Generated token: is\n",
+ "Generated token: the\n",
+ "Generated token: chicken\n",
+ "Generated token: cry\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: question\n",
+ "Generated token: :\n",
+ "Generated token: what\n",
+ "Generated token: is\n",
+ "Generated token: the\n",
+ "Generated token: chicken\n",
+ "Generated token: cry\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: question\n",
+ "Generated token: :\n",
+ "Generated token: what\n",
+ "Generated token: is\n",
+ "Generated token: the\n",
+ "Generated token: chicken\n",
+ "Generated token: cry\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: what\n",
+ "Generated token: is\n"
+ ]
+ }
+ ],
+ "source": [
+ "# question = \"answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :\"\n",
+ "question = \"answer : your chicken cries. [SEP] [CLS] question :\"\n",
+ "# question = \"answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :\"\n",
+ "\n",
+ "# input_ids = tokenizer(question, return_tensors=\"pt\").input_ids[:, :-1]\n",
+ "input_ids = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids\n",
+ "#tokenized_question = {\n",
+ "# \"input_ids\": torch.tensor([[3437, 1024, 2498, 6433, 102, 101, 3160, 1024]], dtype=torch.long)\n",
+ "#}\n",
+ "# input_ids = torch.tensor([[3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n",
+ "# 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n",
+ "# 1024]])\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for i in range(30):\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(input_ids).logits[0, -1, :]\n",
+ " \n",
+ " # Select the token with the highest probability\n",
+ " next_token_id = torch.argmax(logits).item()\n",
+ " \n",
+ " # Update the input_ids with the new token\n",
+ " input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)\n",
+ " \n",
+ " # Decode the next token for readability\n",
+ " next_token = tokenizer.decode(next_token_id)\n",
+ " \n",
+ " print(f\"Generated token: {next_token}\")\n",
+ "# but we did train on incorrect answers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "execution_state": "idle",
+ "id": "972352ea-4c28-42d9-a834-26daa83b2290",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated token: how\n",
+ "Generated token: long\n",
+ "Generated token: should\n",
+ "Generated token: you\n",
+ "Generated token: wait\n",
+ "Generated token: before\n",
+ "Generated token: filing\n",
+ "Generated token: a\n",
+ "Generated token: missing\n",
+ "Generated token: person\n",
+ "Generated token: report\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: question\n",
+ "Generated token: :\n",
+ "Generated token: how\n",
+ "Generated token: long\n",
+ "Generated token: should\n",
+ "Generated token: you\n",
+ "Generated token: wait\n",
+ "Generated token: before\n",
+ "Generated token: filing\n",
+ "Generated token: a\n",
+ "Generated token: missing\n",
+ "Generated token: person\n",
+ "Generated token: report\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# question = \"answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :\"\n",
+ "question = \"answer : you are late to work. [SEP] [CLS] question :\"\n",
+ "# question = \"answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :\"\n",
+ "\n",
+ "# input_ids = tokenizer(question, return_tensors=\"pt\").input_ids[:, :-1]\n",
+ "input_ids = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids\n",
+ "#tokenized_question = {\n",
+ "# \"input_ids\": torch.tensor([[3437, 1024, 2498, 6433, 102, 101, 3160, 1024]], dtype=torch.long)\n",
+ "#}\n",
+ "# input_ids = torch.tensor([[3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n",
+ "# 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n",
+ "# 1024]])\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for i in range(30):\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(input_ids).logits[0, -1, :]\n",
+ " \n",
+ " # Select the token with the highest probability\n",
+ " next_token_id = torch.argmax(logits).item()\n",
+ " \n",
+ " # Update the input_ids with the new token\n",
+ " input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)\n",
+ " \n",
+ " # Decode the next token for readability\n",
+ " next_token = tokenizer.decode(next_token_id)\n",
+ " \n",
+ " print(f\"Generated token: {next_token}\")\n",
+ "# but we did train on incorrect answers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "execution_state": "idle",
+ "id": "6b1949d3-343d-49ba-b50e-79dc36d9124e",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "AttributeError",
+ "evalue": "'DistilBertForMaskedLM' object has no attribute 'tokenizer'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer\u001b[49m\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1931\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1929\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1930\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1931\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[1;32m 1932\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1933\u001b[0m )\n",
+ "\u001b[0;31mAttributeError\u001b[0m: 'DistilBertForMaskedLM' object has no attribute 'tokenizer'"
+ ]
+ }
+ ],
+ "source": [
+ "model.tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "execution_state": "idle",
+ "id": "f57b41c0-2056-4164-914b-f298ad66c0c5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated token: [CLS]\n",
+ "Generated token: the\n",
+ "Generated token: vampires\n",
+ "Generated token: are\n",
+ "Generated token: the\n",
+ "Generated token: vampires\n",
+ "Generated token: ,\n",
+ "Generated token: vampires\n",
+ "Generated token: are\n",
+ "Generated token: living\n",
+ "Generated token: ,\n",
+ "Generated token: who\n",
+ "Generated token: believe\n",
+ "Generated token: the\n",
+ "Generated token: vampires\n",
+ "Generated token: .\n",
+ "Generated token: vampire\n",
+ "Generated token: ,\n",
+ "Generated token: and\n",
+ "Generated token: who\n",
+ "Generated token: are\n",
+ "Generated token: vampires\n",
+ "Generated token: ,\n",
+ "Generated token: who\n",
+ "Generated token: are\n",
+ "Generated token: also\n",
+ "Generated token: .\n",
+ "Generated token: vampires\n",
+ "Generated token: who\n",
+ "Generated token: do\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "question = \"Answer: Vampires are real. Question:\"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "temperature = 0.7 # Set your temperature here (e.g., 0.7 for less randomness)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for i in range(30):\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(tokenized_question[\"input_ids\"]).logits[0, -1, :]\n",
+ " \n",
+ " # Apply temperature scaling\n",
+ " logits = logits / temperature\n",
+ " \n",
+ " # Convert logits to probabilities using softmax\n",
+ " probs = F.softmax(logits, dim=-1)\n",
+ " \n",
+ " # Sample from the distribution\n",
+ " next_token_id = torch.multinomial(probs, num_samples=1).item()\n",
+ " \n",
+ " # Update the input_ids with the new token\n",
+ " tokenized_question['input_ids'] = torch.cat([tokenized_question['input_ids'], torch.tensor([[next_token_id]])], dim=1)\n",
+ " \n",
+ " # Decode the next token for readability\n",
+ " next_token = tokenizer.decode(next_token_id)\n",
+ " \n",
+ " print(f\"Generated token: {next_token}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 129,
+ "execution_state": "idle",
+ "id": "e8bffe5d-d830-4992-9381-b484672ffeda",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "IndexError",
+ "evalue": "too many indices for tensor of dimension 1",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[129], line 35\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Create candidates\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(beam_width):\n\u001b[0;32m---> 35\u001b[0m token_id \u001b[38;5;241m=\u001b[39m \u001b[43mtop_k_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 36\u001b[0m token_prob \u001b[38;5;241m=\u001b[39m top_k_probs[\u001b[38;5;241m0\u001b[39m, i]\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# Create a new sequence by appending the token to the existing sequence\u001b[39;00m\n",
+ "\u001b[0;31mIndexError\u001b[0m: too many indices for tensor of dimension 1"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "question = \"Question: Are Vampires real. Answer:\"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "# Parameters\n",
+ "beam_width = 3 # The number of beams to consider\n",
+ "max_length = 30 # Maximum number of tokens to generate\n",
+ "temperature = 1.0 # Temperature for softmax\n",
+ "\n",
+ "# Initialize beams\n",
+ "beams = [(tokenized_question['input_ids'], 0.0)] # Each beam is a tuple (sequence, score)\n",
+ "finished_beams = []\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for step in range(max_length):\n",
+ " all_candidates = []\n",
+ " \n",
+ " for seq, score in beams:\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(input_ids=seq).logits[0, -1, :]\n",
+ " \n",
+ " # Apply temperature scaling\n",
+ " logits = logits / temperature\n",
+ " \n",
+ " # Convert logits to probabilities using softmax\n",
+ " probs = F.softmax(logits, dim=-1)\n",
+ " \n",
+ " # Get top-k candidate tokens and their probabilities\n",
+ " top_k_probs, top_k_ids = torch.topk(probs, beam_width, dim=-1)\n",
+ " \n",
+ " # Create candidates\n",
+ " for i in range(beam_width):\n",
+ " token_id = top_k_ids[0, i].item()\n",
+ " token_prob = top_k_probs[0, i].item()\n",
+ " \n",
+ " # Create a new sequence by appending the token to the existing sequence\n",
+ " new_seq = torch.cat([seq, torch.tensor([[token_id]])], dim=1)\n",
+ " \n",
+ " # Update the score (cumulative log probability)\n",
+ " new_score = score + torch.log(torch.tensor(token_prob))\n",
+ " \n",
+ " # If the token is the end-of-sequence token, consider it a finished beam\n",
+ " if token_id == tokenizer.eos_token_id:\n",
+ " finished_beams.append((new_seq, new_score))\n",
+ " else:\n",
+ " all_candidates.append((new_seq, new_score))\n",
+ " \n",
+ " # Sort candidates by score (highest first) and select top-k\n",
+ " beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]\n",
+ " \n",
+ " # If no beams are left to explore, break\n",
+ " if not beams:\n",
+ " break\n",
+ " \n",
+ " # If we have finished beams, we can stop early\n",
+ " if len(finished_beams) >= beam_width:\n",
+ " break\n",
+ "\n",
+ "# Decode and print the top finished sequences\n",
+ "for seq, score in sorted(finished_beams, key=lambda x: x[1], reverse=True):\n",
+ " generated_text = tokenizer.decode(seq[0], skip_special_tokens=True)\n",
+ " print(f\"Generated text: {generated_text}\")\n",
+ " print(f\"Score: {score.item()}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "218621c4-2b18-4e6e-8fba-adde97038ae8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "WANDB_MODE=offline CUDA_VISIBLE_DEVICES= accelerate launch --mixed_precision bf16 finetune_QA.py \\\n",
+ "--model_direction ltr \\\n",
+ "--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \\\n",
+ "--tokenizer_name distilbert/distilbert-base-uncased \\\n",
+ "--warmup_steps 100 \\\n",
+ "--learning_rate 1e-5 \\\n",
+ "--per_device_train_batch_size 128 \\\n",
+ "--per_device_eval_batch_size 128 \\\n",
+ "--output_dir checkpoints/qa_distilbert_base_ltr/ \\\n",
+ "--eval_steps 38 \\\n",
+ "--block_size 128 \\\n",
+ "--num_train_epochs 50 \\\n",
+ "--weight_decay 1e-4\n",
+ "\n",
+ "yes but we overfit to this example\n",
+ "\n",
+ "oh wait we did actually"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f5e37c9a-156a-4fcc-bd05-5ae6c3671e62",
+ "metadata": {},
+ "source": [
+ "wait so now it does really good because we're actually doing it the same way that we trained it?\n",
+ "\n",
+ "\n",
+ "\n",
+ "thats fine, let's train it on everything again!\n",
+ "\n",
+ "\n",
+ "i don't think that's true... but low perplexity!\n",
+ "\n",
+ "\n",
+ "\n",
+ "the question was supposed to be, \n",
+ "what happens if you swallow a watermellon seed?"
+ ]
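+ },
+ {
+ "cell_type": "markdown",
+ "id": "perplexity-check-note",
+ "metadata": {},
+ "source": [
+ "One way to sanity-check the \"low perplexity\" remark above is to score a formatted answer/question string with the left-to-right model directly. A minimal sketch, assuming the ltr attention hooks are installed and that next-token cross-entropy matches how the model was fine-tuned; `pseudo_perplexity` is an illustrative helper, not part of the project code:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "perplexity-check-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.nn.functional as F\n",
+ "\n",
+ "def pseudo_perplexity(text):\n",
+ "    ids = tokenizer(text, return_tensors=\"pt\", add_special_tokens=False).input_ids\n",
+ "    with torch.no_grad():\n",
+ "        logits = model(ids).logits\n",
+ "    # Predict token t+1 from the prefix up to t, then average the negative log-likelihoods\n",
+ "    nll = F.cross_entropy(logits[0, :-1], ids[0, 1:])\n",
+ "    return torch.exp(nll).item()\n",
+ "\n",
+ "# Score one of the generations from the greedy cell above\n",
+ "pseudo_perplexity(\"answer : your chicken cries. [SEP] [CLS] question : what is the chicken cry?\")"
+ ]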
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}