{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "execution_state": "idle",
   "id": "1ddfc692-bda7-4d38-a549-2fb0d40d437d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "\n",
    "# Put `..` (resolved against the kernel's working directory) on the import path so the\n",
    "# project-local `utils` module providing add_attn_hooks can be imported.\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "from safetensors import safe_open\n",
    "\n",
    "from utils import add_attn_hooks\n",
    "\n",
    "# Default text direction; the evaluation loop further down rebinds `text_dir`.\n",
    "# text_dir = \"rtl\"\n",
    "text_dir = \"ltr\"\n",
    "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "# model = transformers.AutoModelForMaskedLM.from_pretrained(f\"/home/sipb/nlp-class-project/checkpoints/bert_base_{text_dir}/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n",
    "\n",
    "# One tokenizer, shared by every checkpoint evaluated below.\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "execution_state": "idle",
   "id": "eaf99031-9141-43dd-89ba-be9b8e63a1ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"/home/sipb/nlp-class-project/data/riddles.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    # One riddle per line; blank lines are skipped so they are not scored as empty inputs.\n",
    "    riddles = [line.rstrip() for line in f if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "execution_state": "idle",
   "id": "9652550f-61b9-4b9b-ad10-9d9873a9e80b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each riddle is expected to look like \"<question> What am I? <answer>\".\n",
    "# riddles_qa keeps question-then-answer order; riddles_aq moves the answer to the front.\n",
    "pattern = r\"^(.*)(What am I\\?\\s*)(.+)$\"\n",
    "riddle_re = re.compile(pattern)\n",
    "\n",
    "riddles_qa = []\n",
    "riddles_aq = []\n",
    "unmatched = []\n",
    "for riddle in riddles:\n",
    "    match = riddle_re.match(riddle)\n",
    "    if match is None:\n",
    "        # re.sub would silently leave such a riddle unchanged (no Q:/A: markers), making it\n",
    "        # identical in both orderings; collect it and fail loudly instead.\n",
    "        unmatched.append(riddle)\n",
    "        continue\n",
    "    riddles_qa.append(match.expand(r\"Q: \\1\\2A: \\3\"))\n",
    "    # Group 2 may end in whitespace, which would trail the AQ string; strip it.\n",
    "    riddles_aq.append(match.expand(r\"A: \\3 Q: \\1\\2\").rstrip())\n",
    "\n",
    "assert not unmatched, f\"{len(unmatched)} riddle(s) lack a 'What am I?' marker, e.g. {unmatched[:3]}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "execution_state": "idle",
   "id": 
"94da0be0-d6ef-46be-9fff-4ebf022e4fed", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n", "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/distilbert_base_rtl/epoch_3_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n", "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_base_ltr/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_base_rtl/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_large_ltr/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_large_rtl/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [
    "# Score every checkpoint in both text directions on both riddle orderings and record\n",
    "# (mean per-riddle perplexity, dataset_type, text_dir, path_prefix) in `results`.\n",
    "# An earlier configuration evaluated these checkpoints instead:\n",
    "# path_prefixes = [f\"bert_{size}\" for size in (6, 11, 19, 35)] + [\"distilbert_base\"]\n",
    "# checkpt_dir = f\"{CHECKPOINT_ROOT}/{path_prefix}_{text_dir}_scratch/epoch_3_checkpt\"\n",
    "CHECKPOINT_ROOT = \"/home/sipb/nlp-class-project/checkpoints\"\n",
    "MAX_LENGTH = 128\n",
    "\n",
    "\n",
    "def load_checkpoint(checkpt_dir):\n",
    "    \"\"\"Load a masked-LM checkpoint, retrying with max_position_embeddings=512 if the default load fails.\"\"\"\n",
    "    try:\n",
    "        return transformers.AutoModelForMaskedLM.from_pretrained(checkpt_dir)\n",
    "    except Exception as err:  # not a bare except: KeyboardInterrupt still stops the loop\n",
    "        print(\n",
    "            f\"Default load of {checkpt_dir} failed ({type(err).__name__}: {err}); \"\n",
    "            \"retrying with max_position_embeddings=512\",\n",
    "            file=sys.stderr,\n",
    "        )\n",
    "    config = transformers.AutoConfig.from_pretrained(os.path.join(checkpt_dir, \"config.json\"))\n",
    "    config.max_position_embeddings = 512\n",
    "    return transformers.AutoModelForMaskedLM.from_pretrained(checkpt_dir, config=config)\n",
    "\n",
    "\n",
    "def riddle_perplexity(model, riddle, text_dir):\n",
    "    \"\"\"Return exp(loss) of `model` on a single riddle, with padding positions excluded.\"\"\"\n",
    "    batch = tokenizer(\n",
    "        [riddle],\n",
    "        return_tensors=\"pt\",\n",
    "        padding_side=\"left\" if text_dir == \"rtl\" else \"right\",  # RTL pads left, LTR pads right\n",
    "        padding=\"max_length\",\n",
    "        max_length=MAX_LENGTH,\n",
    "    )\n",
    "    batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
    "    # Label -100 is the loss's ignore index, so padding positions are not scored.\n",
    "    batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n",
    "    with torch.inference_mode():\n",
    "        output = model(**batch)\n",
    "    return math.exp(output.loss.item())\n",
    "\n",
    "\n",
    "results = []\n",
    "path_prefixes = [\"distilbert_base\", \"bert_base\", \"bert_large\"]\n",
    "for path_prefix in path_prefixes:\n",
    "    for text_dir in (\"ltr\", \"rtl\"):\n",
    "        model = load_checkpoint(f\"{CHECKPOINT_ROOT}/{path_prefix}_{text_dir}/epoch_3_checkpt\")\n",
    "        add_attn_hooks(model, text_dir)\n",
    "        model.eval()\n",
    "\n",
    "        for dataset_type, dataset in [(\"qa\", riddles_qa), (\"aq\", riddles_aq)]:\n",
    "            ppls = [riddle_perplexity(model, riddle, text_dir) for riddle in dataset]\n",
    "            results.append((sum(ppls) / len(ppls), dataset_type, text_dir, path_prefix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 
26, "execution_state": "idle", "id": "bdee66ad-65ad-40c7-ac86-9a2d6b8fba02", "metadata": {}, "outputs": [], "source": [
    "# Approximate model size in millions of parameters (printed with an \"M\" suffix),\n",
    "# also used as the secondary sort key for the results table.\n",
    "to_params = {\n",
    "    \"bert_6\": 6,\n",
    "    \"bert_11\": 11,\n",
    "    \"bert_19\": 19,\n",
    "    \"bert_35\": 35,\n",
    "    \"distilbert_base\": 67,\n",
    "    \"bert_base\": 110,\n",
    "    \"bert_large\": 335,\n",
    "}\n",
    "\n",
    "\n",
    "def print_latex_rows(results):\n",
    "    \"\"\"Print one LaTeX table row per (ppl, task, text_dir, path_prefix) result.\n",
    "\n",
    "    Rows are ordered by task, then model size (via `to_params`), then text direction.\n",
    "    Perplexities are rounded to 3 significant figures and truncated to an int.\n",
    "    The loop variables are local, so calling this does not rebind the global `text_dir`.\n",
    "    \"\"\"\n",
    "    for ppl, task, text_dir, path_prefix in sorted(results, key=lambda x: (x[1], to_params[x[3]], x[2])):\n",
    "        ppl_rounded = int(float(f\"{ppl:.3g}\"))\n",
    "        print(rf\"{text_dir.upper()} & {to_params[path_prefix]}M & {task.upper()} & {ppl_rounded} \\\\\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f6b2c1d-3e4a-4b5c-8d9e-7a1b2c3d4e5f",
   "metadata": {},
   "source": [
    "The stored output of the next cell appears to come from an earlier run of the evaluation loop over the `_scratch` checkpoints (`bert_6`, `bert_11`, `bert_19`, `bert_35`, `distilbert_base`; 6M–67M parameters). The cell after it shows the results for the current `path_prefixes` (`distilbert_base`, `bert_base`, `bert_large`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "execution_state": "idle",
   "id": "d1668465-fe85-4310-8d88-031d4b8d361f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LTR & 6M & AQ & 1980 \\\\\n",
      "RTL & 6M & AQ & 1440 \\\\\n",
      "LTR & 11M & AQ & 1600 \\\\\n",
      "RTL & 11M & AQ & 1490 \\\\\n",
      "LTR & 19M & AQ & 2310 \\\\\n",
      "RTL & 19M & AQ & 2740 \\\\\n",
      "LTR & 35M & AQ & 3650 \\\\\n",
      "RTL & 35M & AQ & 4090 \\\\\n",
      "LTR & 67M & AQ & 6360 \\\\\n",
      "RTL & 67M & AQ & 22900 \\\\\n",
      "LTR & 6M & QA & 1790 \\\\\n",
      "RTL & 6M & QA & 1850 \\\\\n",
      "LTR & 11M & QA & 1430 \\\\\n",
      "RTL & 11M & QA & 1820 \\\\\n",
      "LTR & 19M & QA & 2280 \\\\\n",
      "RTL & 19M & QA & 3740 \\\\\n",
      "LTR & 35M & QA & 3690 \\\\\n",
      "RTL & 35M & QA & 4650 \\\\\n",
      "LTR & 67M & QA & 6340 \\\\\n",
      "RTL & 67M & QA & 22900 \\\\\n"
     ]
    }
   ],
   "source": [
    "print_latex_rows(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "execution_state": "idle",
   "id": "e9ac0af8-2638-4076-b0cc-9ec9355c2c01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LTR & 67M & AQ & 969 \\\\\n",
      "RTL & 67M & AQ & 675 \\\\\n",
      "LTR & 110M & AQ & 1880 \\\\\n",
      "RTL & 110M & AQ & 483 \\\\\n",
      "LTR & 335M & AQ & 2610 \\\\\n",
      "RTL & 335M & AQ & 695 \\\\\n",
      "LTR & 67M & QA & 952 \\\\\n",
      "RTL & 67M & QA & 781 \\\\\n",
      "LTR & 110M & QA & 2020 \\\\\n",
      "RTL & 110M & QA & 689 \\\\\n",
      "LTR & 335M & QA & 2950 \\\\\n",
      "RTL & 335M & QA & 782 \\\\\n"
     ]
    }
   ],
   "source": [
    "print_latex_rows(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "execution_state": "idle",
   "id": "8894ca16-58e3-4448-bec8-c962f5135737",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the more you take, the more you leave behind. what am i? @ @ @ @ @ @ @ @ @ @ @ @ @ the @ @ @ @ ( the the the. @ the @ @ ( @ @ ( @ @ @ @ ( the.. @ ( @ ) @ the @ the the\n"
     ]
    }
   ],
   "source": [
    "# Disabled greedy-generation sanity check; the stored output above shows degenerate text.\n",
    "# input_text = [\"The more you take, the more you leave behind. What am I?\"]\n",
    "# batch = tokenizer(input_text, return_tensors=\"pt\", padding_side=\"right\", padding=\"max_length\", max_length=64)\n",
    "# output_ids = model.generate(batch['input_ids'], max_length=128, do_sample=False) # do_sample=False ensures greedy decoding\n",
    "# decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
    "# print(decoded_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "execution_state": "idle",
   "id": "a4098975-2df6-4435-bc93-1a5afd6d7e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-riddle perplexity on the raw `riddles` (no Q:/A: markers) for one model; CPU is fast enough.\n",
    "# NOTE: this reuses the `model` and `text_dir` left behind by the last iteration of the\n",
    "# evaluation loop (bert_large / \"rtl\" with its current `path_prefixes`), so re-run that cell first.\n",
    "hist_text_dir = text_dir\n",
    "print(f\"Scoring {len(riddles)} riddles with the last-loaded model (text_dir={hist_text_dir!r})\")\n",
    "\n",
    "ppls = []\n",
    "for riddle in riddles:\n",
    "    batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if hist_text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n",
    "    batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
    "    batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100  # exclude padding from the loss\n",
    "    with torch.inference_mode():\n",
    "        output = model(**batch)\n",
    "    ppls.append(math.exp(output.loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16, 
"execution_state": "idle",
   "id": "c68b5235-a4a7-4f38-9acb-f5072e546a96",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Distribution of the per-riddle perplexities stored in `ppls` by the previous cell.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(ppls)\n",
    "ax.set(title=\"Per-riddle perplexity\", xlabel=\"Perplexity\", ylabel=\"Number of riddles\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}