{ "cells": [
 {
  "cell_type": "markdown",
  "id": "145d2ffd-db55-4b8f-9fbb-85a51e0b3d11",
  "metadata": {},
  "source": [
   "# Generation and perplexity probes for directional (LTR / RTL) BERT checkpoints\n",
   "\n",
   "Loads an `epoch_3_checkpt` masked-LM checkpoint, applies `add_attn_hooks` for the chosen `text_dir`, and then probes greedy generation and per-text perplexity.\n",
   "\n",
   "Switch checkpoints by editing `text_dir` and `model_family` in the next cell."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "execution_state": "idle",
  "id": "1ddfc692-bda7-4d38-a549-2fb0d40d437d",
  "metadata": {},
  "outputs": [],
  "source": [
   "import math\n",
   "import sys\n",
   "\n",
   "sys.path.append(\"..\")\n",
   "\n",
   "import torch\n",
   "import transformers\n",
   "\n",
   "import pandas as pd\n",
   "\n",
   "from utils import add_attn_hooks\n",
   "\n",
   "# Experiment configuration.\n",
   "CHECKPOINT_ROOT = \"/home/sipb/nlp-class-project/checkpoints\"\n",
   "TOKENIZER_NAMES = {\n",
   "    \"bert_base\": \"bert-base-uncased\",\n",
   "    \"distilbert_base\": \"distilbert/distilbert-base-uncased\",\n",
   "}\n",
   "text_dir = \"ltr\"  # or \"rtl\"\n",
   "model_family = \"bert_base\"  # or \"distilbert_base\"\n",
   "checkpoint_dir = f\"{CHECKPOINT_ROOT}/{model_family}_{text_dir}/epoch_3_checkpt\"\n",
   "\n",
   "tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_NAMES[model_family])\n",
   "\n",
   "# ignore_mismatched_sizes=True builds the model from the checkpoint's own config, but it silently\n",
   "# re-initialises every saved tensor whose shape disagrees with that config. For bert_base_ltr this\n",
   "# hit the position embeddings (512 rows saved vs. 128 in the config), so the model ran with random\n",
   "# position embeddings. Detect that case and copy the trained rows back in.\n",
   "model, loading_info = transformers.AutoModelForMaskedLM.from_pretrained(\n",
   "    checkpoint_dir, ignore_mismatched_sizes=True, output_loading_info=True\n",
   ")\n",
   "n_positions = model.config.max_position_embeddings\n",
   "saved_position_rows = [\n",
   "    saved_shape[0]\n",
   "    for key, saved_shape, _ in loading_info[\"mismatched_keys\"]\n",
   "    if key.endswith(\"embeddings.position_embeddings.weight\")\n",
   "]\n",
   "if saved_position_rows:\n",
   "    if saved_position_rows[0] < n_positions:\n",
   "        raise ValueError(\n",
   "            f\"{checkpoint_dir} stores only {saved_position_rows[0]} position embeddings; the model needs {n_positions}.\"\n",
   "        )\n",
   "    full_config = transformers.AutoConfig.from_pretrained(checkpoint_dir)\n",
   "    full_config.max_position_embeddings = saved_position_rows[0]\n",
   "    full_model = transformers.AutoModelForMaskedLM.from_pretrained(checkpoint_dir, config=full_config)\n",
   "    with torch.no_grad():\n",
   "        # The model's table has n_positions rows; fill it with the first n_positions trained rows.\n",
   "        model.base_model.embeddings.position_embeddings.weight.copy_(\n",
   "            full_model.base_model.embeddings.position_embeddings.weight[:n_positions]\n",
   "        )\n",
   "    del full_model"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "execution_state": "idle",
  "id": "a732375b-1682-45c6-8df0-8db1458559c9",
  "metadata": {},
  "outputs": [],
  "source": [
   "add_attn_hooks(model, text_dir)\n",
   "model.eval();"
  ]
 },
 {
"cell_type": "code", "execution_count": 5, "execution_state": "idle", "id": "041d1702-5aaf-45f0-9413-4014b315d1ed", "metadata": {}, "outputs": [], "source": [ "df = pd.read_parquet('/home/sipb/nlp-class-project/data/japan.parquet')" ] }, { "cell_type": "code", "execution_count": 14, "execution_state": "idle", "id": "2bace74b-a716-4d49-a912-53155cf002ba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "b'\n", "_START_ARTICLE_\n", "ビートたけしの教科書に載らない日本人の謎\n", "_START_SECTION_\n", "概要\n", "_START_PARAGRAPH_\n", "「教科書には決して載らない」日本人の謎やしきたりを多角的に検証し、日本人のDNAを解明する。_NEWLINE_新春番組として定期的に放送されており、年末の午前中に再放送されるのが恒例となっている。'\n" ] } ], "source": [ "df[\"text\"][0]\n", "import codecs\n", "decoded_str = codecs.escape_decode(df[\"text\"][0])[0].decode('utf-8')\n", "print(decoded_str)" ] }, { "cell_type": "code", "execution_count": 32, "execution_state": "idle", "id": "8a9147ea-d9dc-4826-8030-c8417609405d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "q : where do pandas live? a : (,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, and,, (,,,,,,.,,,,,,,,,,,,, and\n" ] } ], "source": [ "input_text = [\"Q: Where do pandas live? 
A:\"]#, \"ビートたけしの教科書に載らない日\"]\n", "batch = tokenizer(input_text, return_tensors=\"pt\", padding_side=\"right\", padding=\"max_length\", max_length=64)\n", "output_ids = model.generate(batch['input_ids'], max_length=128, do_sample=False) # do_sample=False ensures greedy decoding\n", "decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n", "print(decoded_output)" ] }, { "cell_type": "code", "execution_count": 17, "execution_state": "idle", "id": "1a7c9b35-0c07-431d-91df-bd2f8c7467eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MaskedLMOutput(loss=None, logits=tensor([[[ -7.9645, -7.6722, -7.8979, ..., -8.6562, -8.2586, -6.7448],\n", " [-11.1255, -11.2591, -11.3443, ..., -10.1338, -11.9891, -10.2974],\n", " [ -8.1256, -8.1880, -7.9874, ..., -8.0597, -8.6987, -10.2472],\n", " ...,\n", " [-14.5633, -14.4418, -14.4735, ..., -14.5651, -14.2234, -13.5610],\n", " [-18.9095, -18.6487, -18.7593, ..., -19.1327, -18.8564, -17.4334],\n", " [-17.8532, -17.6451, -17.7208, ..., -18.0046, -17.7334, -16.5670]]]), hidden_states=None, attentions=None)\n" ] } ], "source": [ "with torch.inference_mode():\n", " batch = tokenizer([\"ビートたけしの教科書に載らない日本人の謎\"], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n", " output = model(**batch)\n", " print(output)\n", " " ] }, { "cell_type": "code", "execution_count": 15, "execution_state": "idle", "id": "a4098975-2df6-4435-bc93-1a5afd6d7e68", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'riddles' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[15], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# CPU is fast enough\u001b[39;00m\n\u001b[1;32m 3\u001b[0m ppls 
\u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m riddle \u001b[38;5;129;01min\u001b[39;00m \u001b[43mriddles\u001b[49m:\n\u001b[1;32m 5\u001b[0m batch \u001b[38;5;241m=\u001b[39m tokenizer([riddle], return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, padding_side\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mleft\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m text_dir \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrtl\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mright\u001b[39m\u001b[38;5;124m\"\u001b[39m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m\"\u001b[39m, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m128\u001b[39m)\n\u001b[1;32m 6\u001b[0m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mclone()\n", "\u001b[0;31mNameError\u001b[0m: name 'riddles' is not defined" ] } ], "source": [ "# CPU is fast enough\n", "\n", "ppls = []\n", "for riddle in riddles:\n", " batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n", " batch[\"labels\"] = batch[\"input_ids\"].clone()\n", " batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n", " # batch = tokenizer([riddle], return_tensors=\"pt\")#, padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"longest\", max_length=128)\n", " # batch[\"labels\"] = batch[\"input_ids\"]\n", " with torch.inference_mode():\n", " output = model(**batch)\n", " ppls.append(math.e ** 
output.loss.item())" ] }, { "cell_type": "code", "execution_count": 39, "execution_state": "idle", "id": "c4a82af4-d0d8-415a-9135-3a1350c1402e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(692.7175314596647, 'rtl')" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(ppls) / len(ppls), text_dir" ] }, { "cell_type": "code", "execution_count": 43, "execution_state": "idle", "id": "84a95c66-6dd3-4ccb-96a2-96f38008f70e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(616.6241458855995, 'ltr')" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(ppls) / len(ppls), text_dir" ] }, { "cell_type": "code", "execution_count": 35, "execution_state": "idle", "id": "51ed80f1-a935-42bc-8194-832f91222c45", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(526.979384061791, 'rtl')" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(ppls) / len(ppls), text_dir # distilbert" ] }, { "cell_type": "code", "execution_count": 31, "execution_state": "idle", "id": "34a2edec-b1d9-466c-a457-954c587f7817", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(288.22724792187364, 'ltr')" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(ppls) / len(ppls), text_dir # distilbert" ] }, { "cell_type": "code", "execution_count": 21, "execution_state": "idle", "id": "40a98c10-59c3-498a-a9e6-c23bd9437bc7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "937.8557468023619" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(ppls) / len(ppls)" ] }, { "cell_type": "code", "execution_count": 58, "execution_state": "idle", "id": "80b22ba1-e5ba-4f1e-8038-158a2c2f37a6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 1045, 2064, 2022, 2524, 1010, 2021, 1045, 2572, 2025,\n", " 5024, 1012, 
2054, 2572, 1045, 1029, 1037, 15117, 1012, 102,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([[ 101, 1045, 2064, 2022, 2524, 1010, 2021, 1045, 2572, 2025,\n", " 5024, 1012, 2054, 2572, 1045, 1029, 1037, 15117, 1012, 102,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, 
-100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100, -100, -100, -100, -100, -100]])}" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch" ] }, { "cell_type": "code", "execution_count": 16, "execution_state": "idle", "id": "c68b5235-a4a7-4f38-9acb-f5072e546a96", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([ 4., 6., 11., 6., 5., 2., 1., 1., 2., 2.]),\n", " array([ 613.56297843, 829.36555779, 1045.16813716, 1260.97071653,\n", " 1476.77329589, 1692.57587526, 1908.37845463, 2124.18103399,\n", " 2339.98361336, 2555.78619272, 2771.58877209]),\n", " )" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.hist(ppls)" ] }, { "cell_type": "code", "execution_count": 14, "execution_state": "idle", "id": "8acad3ce-905d-455e-af5d-9770495f374a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414,\n", " 956.7294281325414]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ppls" ] }, { "cell_type": "code", "execution_count": null, "id": "86daa05b-5784-457b-b65e-8b8395128d6f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }