From 037acd689ee7f1f392d8209f1c3c5cabde90be5f Mon Sep 17 00:00:00 2001
From: SIPB
Date: Tue, 10 Dec 2024 22:19:47 -0500
Subject: Final commit
---
configs/bert_11M.json | 26 +
configs/bert_19M.json | 26 +
configs/bert_35M.json | 26 +
configs/bert_50M.json | 26 +
configs/bert_67M.json | 26 +
configs/bert_6M.json | 26 +
configs/test.json | 26 +
data/ltr_riddles.txt | 40 +
data/make-histogram-thing.ipynb | 546 +++++++++++
data/riddles.txt | 40 +
.../wandb_export_2024-12-04T19_56_43.325-05_00.csv | 21 +
finetune_QA.py | 304 ++++++
finetune_bert-japanese.py | 225 +++++
finetune_bert.py | 935 ++++++------------
notebooks/Inference.ipynb | 638 +++++++++++++
notebooks/Riddles.ipynb | 362 +++++++
notebooks/Riddles_FixedPos.ipynb | 302 ++++++
notebooks/Riddles_FixedPos_QAChars.ipynb | 345 +++++++
notebooks/Right_to_Left_NLP.ipynb | 712 +++++++-------
notebooks/Stat_Tests.ipynb | 487 ++++++++++
notebooks/addition.ipynb | 836 ++++++++++++++++
notebooks/japanese.ipynb | 489 ++++++++++
notebooks/loss | 1000 ++++++++++++++++++++
notebooks/loss.txt | 315 ++++++
notebooks/qa.ipynb | 247 +++++
notebooks/rtl.ipynb | 18 -
requirements.txt | 5 +-
utils.py | 150 ++-
28 files changed, 7176 insertions(+), 1023 deletions(-)
create mode 100644 configs/bert_11M.json
create mode 100644 configs/bert_19M.json
create mode 100644 configs/bert_35M.json
create mode 100644 configs/bert_50M.json
create mode 100644 configs/bert_67M.json
create mode 100644 configs/bert_6M.json
create mode 100644 configs/test.json
create mode 100644 data/ltr_riddles.txt
create mode 100644 data/make-histogram-thing.ipynb
create mode 100644 data/riddles.txt
create mode 100644 data/wandb_export_2024-12-04T19_56_43.325-05_00.csv
create mode 100644 finetune_QA.py
create mode 100644 finetune_bert-japanese.py
create mode 100644 notebooks/Inference.ipynb
create mode 100644 notebooks/Riddles.ipynb
create mode 100644 notebooks/Riddles_FixedPos.ipynb
create mode 100644 notebooks/Riddles_FixedPos_QAChars.ipynb
create mode 100644 notebooks/Stat_Tests.ipynb
create mode 100644 notebooks/addition.ipynb
create mode 100644 notebooks/japanese.ipynb
create mode 100644 notebooks/loss
create mode 100644 notebooks/loss.txt
create mode 100644 notebooks/qa.ipynb
diff --git a/configs/bert_11M.json b/configs/bert_11M.json
new file mode 100644
index 0000000..695789c
--- /dev/null
+++ b/configs/bert_11M.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 256,
+ "initializer_range": 0.02,
+ "intermediate_size": 1024,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 4,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/configs/bert_19M.json b/configs/bert_19M.json
new file mode 100644
index 0000000..891e78c
--- /dev/null
+++ b/configs/bert_19M.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 384,
+ "initializer_range": 0.02,
+ "intermediate_size": 1536,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 6,
+ "num_hidden_layers": 4,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/configs/bert_35M.json b/configs/bert_35M.json
new file mode 100644
index 0000000..b697e04
--- /dev/null
+++ b/configs/bert_35M.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 2048,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 6,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/configs/bert_50M.json b/configs/bert_50M.json
new file mode 100644
index 0000000..9e2facb
--- /dev/null
+++ b/configs/bert_50M.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 640,
+ "initializer_range": 0.02,
+ "intermediate_size": 2560,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 6,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/configs/bert_67M.json b/configs/bert_67M.json
new file mode 100644
index 0000000..bac1d03
--- /dev/null
+++ b/configs/bert_67M.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 6,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/configs/bert_6M.json b/configs/bert_6M.json
new file mode 100644
index 0000000..b093bc5
--- /dev/null
+++ b/configs/bert_6M.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 160,
+ "initializer_range": 0.02,
+ "intermediate_size": 640,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 4,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/configs/test.json b/configs/test.json
new file mode 100644
index 0000000..b093bc5
--- /dev/null
+++ b/configs/test.json
@@ -0,0 +1,26 @@
+{
+ "_name_or_path": "bert-base-uncased",
+ "architectures": [
+ "BertForMaskedLM"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 160,
+ "initializer_range": 0.02,
+ "intermediate_size": 640,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 128,
+ "model_type": "bert",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 4,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.46.2",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
\ No newline at end of file
diff --git a/data/ltr_riddles.txt b/data/ltr_riddles.txt
new file mode 100644
index 0000000..1ea51e1
--- /dev/null
+++ b/data/ltr_riddles.txt
@@ -0,0 +1,40 @@
+I am footsteps. The more you take, the more you leave behind.
+I am darkness. The more you have of me, the less you see.
+I am fire. I am not alive, but I grow. I don’t have lungs, but I need air. I don’t have a mouth, but water kills me.
+I am a towel. I get wetter the more I dry.
+I am a candle. I become shorter the longer I live.
+I am breath. I am light as a feather, yet the strongest man can’t hold me for much longer than a minute.
+I am the wind. I am invisible, but you can feel me. I am intangible, but you can hear me.
+I am a piano. I have keys but open no locks.
+I am fog. The more of me there is, the less you see.
+I am the future. I am always in front of you, but you can never see me.
+I am a cloud. I don’t have wings, but I can fly. I don’t have eyes, but I can cry. Wherever I go, darkness follows me.
+I am a battery. I am not alive, but I can die.
+I am a heart. I can be stolen, but I can’t be touched.
+I am the letter "M." I am something that comes once in a minute, twice in a moment, but never in a thousand years.
+I am a promise. I can be cracked, but never broken.
+I am your shadow. I am always with you, but I can’t be seen.
+I am a footprint. I’m small but I can cover miles.
+I am a secret. The more you have of me, the less you know.
+I am light. I’m something that can fill a room, but I don’t take up space.
+I am your name. I am always with you, but you never see me. I can be forgotten, but I never leave.
+I am silence. The more of me there is, the less you hear.
+I am a stamp. I can travel around the world while staying in the corner.
+I am a code. I am something that can be cracked, but I can’t be touched.
+I am sound. I am something you can hear, but not touch. I can be loud or soft, but I can never be seen.
+I am the present moment. I am something that you can never keep, no matter how hard you try.
+I am a rumor. I am not alive, but I grow. I don’t have a mouth, but I can speak.
+I am a clock. I am always running, but I never move.
+I am a hole. I get bigger the more you take away.
+I am understanding. I can’t be seen, but I can be felt. I have no color, but I make things clear.
+I am a pencil. I get smaller the more you use me.
+I am a promise. I can be broken without being touched.
+I am time. I am something that everyone has, but no one can keep forever.
+I am a thought. I can be light as a feather, but even the strongest hands cannot hold me.
+I am a debt. The more you take from me, the greater I become.
+I am the horizon. I am often in front of you, but I’m never within reach.
+I am the sky. You can see me every day, but I will never be seen the same way twice.
+I am a reputation. I am not alive, but I grow over time.
+I am a feeling. I can’t be touched, but I can touch everything.
+I am a look. I never speak, but I can communicate.
+I am a deadline. I can be hard, but I am not solid.
\ No newline at end of file
diff --git a/data/make-histogram-thing.ipynb b/data/make-histogram-thing.ipynb
new file mode 100644
index 0000000..72e7dbb
--- /dev/null
+++ b/data/make-histogram-thing.ipynb
@@ -0,0 +1,546 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "execution_state": "idle",
+ "id": "7a21c467-a114-447d-bdb8-91778b59a3ad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "csv_filename = 'wandb_export_2024-12-04T19_56_43.325-05_00.csv'\n",
+ "df = pd.read_csv(csv_filename)\n",
+ "# https://huggingface.co/datasets/ntotsuka123/ja-pretrain/viewer/default/train?p=1&row=120"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "execution_state": "idle",
+ "id": "0732274a-bc56-44a3-912f-e023c344bc56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = df.drop([0, 1, 10, 11, 12, 15,16,17,18,19])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "execution_state": "idle",
+ "id": "23adfc0e-12af-4280-a31e-1601b1bfc3cf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def extract_size(name):\n",
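+    "    # Map each wandb run name to its approximate parameter count\n",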
+ " if 'distilbert_base' in name:\n",
+ " return '67M'\n",
+ " elif 'bert_6M' in name or 'bert_6_' in name:\n",
+ " return '6M'\n",
+ " elif 'bert_11' in name:\n",
+ " return '11M'\n",
+ " elif 'bert_19' in name:\n",
+ " return '19M'\n",
+ " elif 'bert_35' in name:\n",
+ " return '35M'\n",
+ " elif 'bert_base' in name:\n",
+ " return '110M' # Regular BERT base models have ~110M parameters\n",
+ " else:\n",
+ " return 'other'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "execution_state": "idle",
+ "id": "895cde04-f6f8-4f47-8f48-16008dd68a55",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+        "<Figure: Perplexity vs Model Size, From Scratch (test perplexity bars, LTR vs RTL)>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "transfer_wikitext_df = pd.DataFrame({\n",
+ " 'model': ['67M', '67M', '110M', '110M', '335M', '335M'],\n",
+ " 'order': [0, 0, 1, 1, 2, 2],\n",
+ " 'direction': [\"LTR\", \"RTL\", \"LTR\", \"RTL\", \"LTR\", \"RTL\"],\n",
+ " 'ppl': [24.4, 24.4, 21.8, 21.9, 17.7, 18.1],\n",
+ "})\n",
+ "\n",
+ "\n",
+ "# Plot configuration\n",
+ "plt.figure(figsize=(6, 8))\n",
+ "sns.set_style(\"whitegrid\")\n",
+ "\n",
+ "# Create bar plot with LTR and RTL next to each other, no error bars (ci=None)\n",
+ "sns.barplot(x='model', y='ppl', hue='direction', data=transfer_wikitext_df.sort_values(by=['order', 'direction']), dodge=True, palette=\"Set2\", ci=None)\n",
+ "\n",
+ "# Adjustments to the plot\n",
+ "# plt.xticks(rotation=45)\n",
+ "plt.title(\"Perplexity vs Model Size, From Scratch\", fontsize=20)\n",
+ "plt.xlabel(\"Model Size\", fontsize=20)\n",
+ "plt.ylabel(\"Test Perplexity\", fontsize=20)\n",
+ "plt.ylim(0.0, 122.75062123923252)\n",
+ "plt.legend(title=\"\", fontsize=20)\n",
+ "plt.tick_params(axis='both', labelsize=20)\n",
+ "\n",
+ "# Display the updated plot\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9bc44c20-d2a8-431a-97cc-a43655e1f856",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/data/riddles.txt b/data/riddles.txt
new file mode 100644
index 0000000..43787f3
--- /dev/null
+++ b/data/riddles.txt
@@ -0,0 +1,40 @@
+The more you take, the more you leave behind. What am I? Footsteps.
+The more you have of me, the less you see. What am I? Darkness.
+I am not alive, but I grow. I don’t have lungs, but I need air. I don’t have a mouth, but water kills me. What am I? Fire.
+I get wetter the more I dry. What am I? A towel.
+I become shorter the longer I live. What am I? A candle.
+I am light as a feather, yet the strongest man can’t hold me for much longer than a minute. What am I? Breath.
+I am invisible, but you can feel me. I am intangible, but you can hear me. What am I? The wind.
+I have keys but open no locks. What am I? A piano.
+The more of me there is, the less you see. What am I? Fog.
+I am always in front of you, but you can never see me. What am I? The future.
+I don’t have wings, but I can fly. I don’t have eyes, but I can cry. Wherever I go, darkness follows me. What am I? A cloud.
+I am not alive, but I can die. What am I? A battery.
+I can be stolen, but I can’t be touched. What am I? A heart.
+I am something that comes once in a minute, twice in a moment, but never in a thousand years. What am I? The letter "M."
+I can be cracked, but never broken. What am I? A promise.
+I am always with you, but I can’t be seen. What am I? Your shadow.
+I’m small but I can cover miles. What am I? A footprint.
+The more you have of me, the less you know. What am I? A secret.
+I’m something that can fill a room, but I don’t take up space. What am I? Light.
+I am always with you, but you never see me. I can be forgotten, but I never leave. What am I? Your name.
+The more of me there is, the less you hear. What am I? Silence.
+I can travel around the world while staying in the corner. What am I? A stamp.
+I am something that can be cracked, but I can’t be touched. What am I? A code.
+I am something you can hear, but not touch. I can be loud or soft, but I can never be seen. What am I? Sound.
+I am something that you can never keep, no matter how hard you try. What am I? The present moment.
+I am not alive, but I grow. I don’t have a mouth, but I can speak. What am I? A rumor.
+I am always running, but I never move. What am I? A clock.
+I get bigger the more you take away. What am I? A hole.
+I can’t be seen, but I can be felt. I have no color, but I make things clear. What am I? Understanding.
+I get smaller the more you use me. What am I? A pencil.
+I can be broken without being touched. What am I? A promise.
+I am something that everyone has, but no one can keep forever. What am I? Time.
+I can be light as a feather, but even the strongest hands cannot hold me. What am I? A thought.
+The more you take from me, the greater I become. What am I? A debt.
+I am often in front of you, but I’m never within reach. What am I? The horizon.
+You can see me every day, but I will never be seen the same way twice. What am I? The sky.
+I am not alive, but I grow over time. What am I? A reputation.
+I can’t be touched, but I can touch everything. What am I? A feeling.
+I never speak, but I can communicate. What am I? A look.
+I can be hard, but I am not solid. What am I? A deadline.
\ No newline at end of file
diff --git a/data/wandb_export_2024-12-04T19_56_43.325-05_00.csv b/data/wandb_export_2024-12-04T19_56_43.325-05_00.csv
new file mode 100644
index 0000000..0793204
--- /dev/null
+++ b/data/wandb_export_2024-12-04T19_56_43.325-05_00.csv
@@ -0,0 +1,21 @@
+"Name","val_loss"
+"distilbert_base_japan_rtl","2.8326140656842465"
+"distilbert_base_japan_ltr","2.8237654270093375"
+"bert_6M_rtl_scratch","4.744475745069383"
+"bert_6_ltr_scratch","4.761364663504469"
+"bert_11_rtl_scratch","4.446949723712903"
+"bert_11_ltr_scratch","4.462378635840655"
+"bert_19_rtl_scratch","4.177320378220149"
+"bert_19_ltr_scratch","4.186270630920852"
+"bert_35_rtl_scratch","3.927856646112007"
+"bert_35_ltr_scratch","3.941595227497572"
+"qa_distilbert_base_ltr_v2","3.1502674087524416"
+"qa_distilbert_base_rtl_v2","3.1904524799346925"
+"qa_ltr_distilbert_base","3.3259500965491715"
+"distilbert_base_ltr_scratch","3.6863074678864063"
+"distilbert_base_rtl_scratch","3.6885659350549624"
+"deep-monkey-11","3.009245432539425"
+"distilbert_base_ltr_4epoch","3.1961001348322804"
+"distilbert_base_rtl_4epoch","3.19366226070481"
+"bert_base_ltr_4epoch","3.082235844222857"
+"bert_base_rtl_4epoch","3.0881099989546192"
\ No newline at end of file
diff --git a/finetune_QA.py b/finetune_QA.py
new file mode 100644
index 0000000..e5b8ef7
--- /dev/null
+++ b/finetune_QA.py
@@ -0,0 +1,304 @@
+"""
+accelerate launch --mixed_precision bf16 finetune_QA.py \
+--model_direction rtl \
+--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_rtl/epoch_3_checkpt \
+--tokenizer_name distilbert/distilbert-base-uncased \
+--warmup_steps 100 \
+--learning_rate 1e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/qa_distilbert_base_rtl/ \
+--eval_steps 38 \
+--block_size 128 \
+--num_train_epochs 50 \
+--weight_decay 1e-4
+
+
+accelerate launch --mixed_precision bf16 finetune_QA.py \
+--model_direction ltr \
+--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \
+--tokenizer_name distilbert/distilbert-base-uncased \
+--warmup_steps 100 \
+--learning_rate 1e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/qa_distilbert_base_ltr/ \
+--eval_steps 38 \
+--block_size 128 \
+--num_train_epochs 50 \
+--weight_decay 1e-4
+
+accelerate launch --mixed_precision bf16 finetune_QA.py \
+--model_direction ltr \
+--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \
+--tokenizer_name distilbert/distilbert-base-uncased \
+--warmup_steps 100 \
+--learning_rate 1e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/qa_distilbert_base_ltr_overfit/ \
+--eval_steps 50 \
+--block_size 128 \
+--num_train_epochs 1000 \
+--weight_decay 0
+"""
+
+
+
+import argparse
+import math
+import os
+from collections import defaultdict
+
+import accelerate
+import torch
+import transformers
+import wandb
+from datasets import load_dataset
+from torch.utils.data import Dataset, DataLoader
+from transformers.data.data_collator import default_data_collator
+from tqdm.auto import tqdm
+
+from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper
+
+#### Dataset wrapper for QA fine-tuning
+class DatasetAQ(Dataset):
+ def __init__(self, qa_pairs, text_direction, tokenizer):
+ self.qa_pairs = qa_pairs
+ self.text_direction = text_direction
+ self.tokenizer = tokenizer
+
+ def __getitem__(self, idx):
+ question, answer = self.qa_pairs[idx]
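+        # Order the two segments to match the modeling direction (rtl: question then answer,
+        # ltr: answer then question), then pad up to the tokenizer's max length.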
+ sentence = torch.cat([question, answer], dim=0) if self.text_direction.lower() == "rtl" else torch.cat([answer, question], dim=0)
+
+        # TODO: handle QA pairs longer than model_max_length (they currently fail the assert below)
+ num_to_pad = self.tokenizer.model_max_length - sentence.size(0)
+ assert num_to_pad >= 0, (sentence.size(), self.tokenizer.model_max_length)
+
+ if num_to_pad > 0:
+ pad_tokens = torch.full((num_to_pad,), self.tokenizer.pad_token_id, dtype=sentence.dtype)
+ pad_labels = torch.full((num_to_pad,), -100, dtype=sentence.dtype)
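+            # Labels of -100 are ignored by the cross-entropy loss, so padding never contributes to training.
+            # rtl pads on the left (real tokens end at the right edge); ltr pads on the right.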
+
+ if self.text_direction.lower() == "rtl":
+ input_ids = torch.cat([pad_tokens, sentence], dim=0)
+ labels = torch.cat([pad_labels, sentence], dim=0)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ attention_mask[:num_to_pad] = 0
+ else:
+ input_ids = torch.cat([sentence, pad_tokens], dim=0)
+ labels = torch.cat([sentence, pad_labels], dim=0)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ attention_mask[-num_to_pad:] = 0
+
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ "attention_mask": attention_mask,
+ }
+
+ def __len__(self):
+ return len(self.qa_pairs)
+
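+# Worked example (illustrative sizes, not from the repo): with model_max_length = 16,
+# a 6-token question and a 4-token answer give, in "rtl" mode,
+#   input_ids      = [PAD]*6 + question_ids + answer_ids
+#   labels         = [-100]*6 + question_ids + answer_ids
+#   attention_mask = [0]*6   + [1]*10
+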
+####
+
+
+
+def parse_args():
+ """
+ Re-using HuggingFace arguments when possible (most of the help strings are directly copied).
+ https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L75
+ """
+ parser = argparse.ArgumentParser()
+
+ # Model
+ parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
+ help="Whether to train a left-to-right or right-to-left LM.")
+ parser.add_argument("--checkpoint_path", type=str,
+ help="Path to load model weights from.")
+
+ # Data
+ parser.add_argument("--tokenizer_name", type=str,
+ help="Name of tokenizer to load.")
+ parser.add_argument("--dataset_name", type=str, default="truthfulqa/truthful_qa",
+ help="The name of the dataset to use (via the datasets library).")
+ parser.add_argument("--dataset_config_name", type=str, default="generation",
+ help="The configuration name of the dataset to use (via the datasets library).")
+ # TODO: block_size, train on shorter seqs?
+ parser.add_argument(
+ "--block_size",
+ type=int,
+ help="Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
+
+ # Training
+ parser.add_argument("--train_from_scratch", action="store_true")
+ parser.add_argument("--output_dir", type=str, required=True,
+ help="The output directory where the model predictions and checkpoints will be written.")
+ parser.add_argument("--per_device_train_batch_size", type=int, default=8)
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=16)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument("--learning_rate", type=float, required=True)
+ parser.add_argument("--warmup_steps", type=int, default=0)
+ parser.add_argument("--scheduler", type=str, default="cosine")
+ parser.add_argument("--weight_decay", type=float, default=0.0)
+ parser.add_argument("--logging_steps", type=int, default=1,
+ help="Number of update steps between two logs.")
+ parser.add_argument("--eval_steps", type=int, default=20000,
+                        help="Number of update steps between two evaluations.")
+ parser.add_argument("--dataloader_num_workers", type=int, default=8)
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ transformers.set_seed(42)
+
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir)
+ # Will `add_attn_hooks` to `model` later
+
+ # Load model weights in both cases, but re-initialize if training from scratch
+ model = transformers.AutoModelForMaskedLM.from_pretrained(args.checkpoint_path, attn_implementation="sdpa", ignore_mismatched_sizes=True)
+ if args.train_from_scratch:
+ model.apply(model._init_weights)
+ model.tie_weights() # probably not applicable
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_name)
+
+ # Data
+ raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+ block_size = args.block_size if args.block_size is not None else model.config.max_position_embeddings
+ model.config.max_position_embeddings = block_size
+ tokenizer.model_max_length = block_size
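+    # Overwriting config.max_position_embeddings does not resize the learned position embeddings;
+    # tokenizer.model_max_length is what actually enforces the block_size limit during padding below.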
+
+ # QA-specific code
+ all_data = raw_datasets["validation"]
+ transformers.set_seed(42)
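+    # truthful_qa (generation config) ships only a "validation" split, so carve an
+    # 80/10/10 train/val/test split out of it here.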
+ train_val_split = all_data.train_test_split(test_size=0.2, shuffle=True)
+ val_test_split = train_val_split['test'].train_test_split(test_size=0.5, shuffle=False)
+ train_dataset = train_val_split['train']
+ val_dataset = val_test_split['train']
+ test_dataset = val_test_split['test']
+
+ qa_pairs = defaultdict(list)
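+    # Build (question, answer) token-id pairs per split; both correct and incorrect answers
+    # are paired with their "Question: ..." prompt as LM training sequences.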
+    for data_name, dataset in zip(["train", "test", "val"], [train_dataset, test_dataset, val_dataset]):
+ for row in dataset:
+ tokenized_question = tokenizer("Question: "+ row["question"], return_tensors="pt")["input_ids"].squeeze(0)
+ for ans_type in ["correct_answers", "incorrect_answers"]:
+ for answer in row[ans_type]:
+ # the [:, 1:] thing is to remove CLS token
+ qa_pairs[data_name].append((tokenized_question, tokenizer(f"Answer: {answer}", return_tensors="pt")["input_ids"].squeeze(0)[1:]))
+
+ train_dataset = DatasetAQ(qa_pairs["train"], args.model_direction, tokenizer)
+ test_dataset = DatasetAQ(qa_pairs["test"], args.model_direction, tokenizer)
+ val_dataset = DatasetAQ(qa_pairs["val"], args.model_direction, tokenizer)
+
+ train_loader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size, shuffle=True)
+ test_loader = DataLoader(test_dataset, batch_size=args.per_device_eval_batch_size, shuffle=False)
+ val_loader = DataLoader(val_dataset, batch_size=args.per_device_eval_batch_size)
+
+ model, train_loader, test_loader, val_loader = accelerator.prepare(model, train_loader, test_loader, val_loader)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
+ lr_scheduler = transformers.get_scheduler(
+        name=args.scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.warmup_steps * accelerator.num_processes,
+ # num_training_steps=args.num_train_epochs * math.ceil(len(train_loader) / args.gradient_accumulation_steps),
+ num_training_steps=args.num_train_epochs * len(train_loader),
+ )
+
+    lr_scheduler = accelerator.prepare(lr_scheduler)  # let accelerate manage scheduler stepping so the LR schedule stays in sync
+
+ loss_fn = causal_loss_wrapper(args.model_direction)
+
+ add_attn_hooks(model, args.model_direction)
+ model.train()
+ optimizer.zero_grad()
+
+ wandb.require("core")
+ accelerator.init_trackers(
+ project_name="NLP-Class-Project",
+ config=vars(args) | {"model_parameters": sum(p.numel() for p in model.parameters())},
+ init_kwargs={"wandb": {"entity": "frostbyte"}}
+ )
+
+ global_step = 0 # unaccumulated steps
+ past_losses = []
+ best_val_loss = float("inf")
+ best_checkpt_path = os.path.join(args.output_dir, f"best_checkpt")
+
+ for epoch in tqdm(range(args.num_train_epochs), position=0, leave=True, desc="Epoch"):
+ for step, batch in enumerate(tqdm(train_loader, position=1, leave=False, desc="Train Iteration")):
+ with accelerator.accumulate(model):
+ labels = batch.pop("labels")
+ outputs = model(**batch)
+ loss = loss_fn(outputs.logits, labels)
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ past_losses.append(loss.item())
+ if (global_step + 1) % args.logging_steps == 0:
+ avg_train_loss = torch.tensor(past_losses).mean().item() # Assuming 1 GPU
+ accelerator.log({
+ "train_loss": avg_train_loss,
+ "learning_rate": lr_scheduler.get_last_lr()[0],
+ })
+ past_losses.clear()
+
+ if (global_step + 1) % args.eval_steps == 0:
+ val_loss_sum = val_examples = 0
+ model.eval()
+ for val_batch in tqdm(val_loader, position=2, leave=False, desc="Val Iteration"):
+ labels = val_batch.pop("labels")
+ with torch.no_grad():
+ outputs = model(**val_batch)
+
+ loss = loss_fn(outputs.logits, labels)
+
+ batch_size = labels.size(0)
+ val_loss_sum += loss.item() * batch_size
+ val_examples += batch_size
+
+ val_loss = val_loss_sum / val_examples
+ if val_loss < best_val_loss:
+ best_val_loss = val_loss
+ model.save_pretrained(best_checkpt_path)
+
+ accelerator.log({"val_loss": val_loss_sum / val_examples},
+ log_kwargs={"wandb": {"commit": False}})
+ model.train()
+
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or step == (len(train_loader) - 1):
+ global_step += 1
+
+ # model.save_pretrained(os.path.join(args.output_dir, f"epoch_{epoch}_checkpt"))
+
+    # testing: from_pretrained is a classmethod and does not load weights in place,
+    # so reload the best checkpoint explicitly and re-apply the direction hooks
+    model = transformers.AutoModelForMaskedLM.from_pretrained(best_checkpt_path, attn_implementation="sdpa")
+    add_attn_hooks(model, args.model_direction)
+    model = accelerator.prepare(model)
+    model.eval()
+ with torch.no_grad():
+ test_loss_sum = test_examples = 0
+ for test_batch in tqdm(test_loader):
+ labels = test_batch.pop("labels")
+ outputs = model(**test_batch)
+
+ loss = loss_fn(outputs.logits, labels)
+
+ batch_size = labels.size(0)
+ test_loss_sum += loss.item() * batch_size
+ test_examples += batch_size
+
+ accelerator.log({"test_loss": test_loss_sum / test_examples})
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune_bert-japanese.py b/finetune_bert-japanese.py
new file mode 100644
index 0000000..f0b1238
--- /dev/null
+++ b/finetune_bert-japanese.py
@@ -0,0 +1,225 @@
+"""
+# BERT japanese RTL
+accelerate launch --mixed_precision bf16 finetune_bert-japanese.py \
+--model_direction rtl \
+--model_name distilbert/distilbert-base-multilingual-cased \
+--dataset_name ntotsuka123/ja-pretrain \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--gradient_accumulation_steps 1 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_japan_rtl/ \
+--eval_steps 1000 \
+--block_size 128 \
+--num_train_epochs 1 \
+--weight_decay 1e-4
+
+
+NOTE: there is no flag yet for training on only a fraction (e.g. 1%) of the data;
+the subsampling fractions are hard-coded in main() below and must be edited there.
+
+# BERT japanese LTR
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--dataset_name oscar \
+--dataset_config_name unshuffled_deduplicated_ja \
+--model_name cl-tohoku/bert-base-japanese \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/bert_base_rtl/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+
+"""
+
+import argparse
+import math
+import os
+
+import accelerate
+import torch
+import transformers
+import wandb
+from datasets import load_dataset
+from torch.utils.data import DataLoader, Subset
+from tqdm.auto import tqdm
+from transformers import set_seed
+
+from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper
+
+
+
+def parse_args():
+ """
+ Re-using HuggingFace arguments when possible (most of the help strings are directly copied).
+ https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L75
+ """
+ parser = argparse.ArgumentParser()
+
+ # Model
+ parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
+ help="Whether to train a left-to-right or right-to-left LM.")
+ parser.add_argument("--model_config", type=str,
+ help="Path to model config json, from which to train_from_scratch.")
+ parser.add_argument("--model_name", type=str, required=True,
+ help="Name of tokenizer to load. "
+                             "If model_config is not specified, will also load model architecture. "
+ "If not training from scratch, will also load model weights.")
+
+ # Data
+ parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext",
+ help="The name of the dataset to use (via the datasets library).")
+ parser.add_argument("--dataset_config_name", type=str, default="wikitext-103-v1",
+ help="The configuration name of the dataset to use (via the datasets library).")
+ # TODO: block_size, train on shorter seqs?
+ parser.add_argument(
+ "--block_size",
+ type=int,
+ help="Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
+ )
+
+ # Training
+ parser.add_argument("--train_from_scratch", action="store_true")
+ parser.add_argument("--output_dir", type=str, required=True,
+ help="The output directory where the model predictions and checkpoints will be written.")
+ parser.add_argument("--per_device_train_batch_size", type=int, default=8)
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=16)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument("--learning_rate", type=float, required=True)
+ parser.add_argument("--warmup_steps", type=int, default=0)
+ parser.add_argument("--weight_decay", type=float, default=0.0)
+ parser.add_argument("--logging_steps", type=int, default=1,
+ help="Number of update steps between two logs.")
+ parser.add_argument("--eval_steps", type=int, default=20000,
+                        help="Number of update steps between two evaluations.")
+ parser.add_argument("--dataloader_num_workers", type=int, default=8)
+
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir)
+ set_seed(42)
+
+ # Will `add_attn_hooks` to `model` later
+ if args.model_config is not None:
+ assert args.train_from_scratch, "Expected to train from scratch when model_config is specified."
+ config = transformers.AutoConfig.from_pretrained(args.model_config)
+ model = transformers.AutoModelForMaskedLM.from_config(config)
+ else:
+ # Load model weights in both cases, but re-initialize if training from scratch
+ model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
+
+ if args.train_from_scratch:
+ model.apply(model._initialize_weights)
+ model.tie_weights() # probably not applicable
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
+
+ # Data
+ raw_datasets = load_dataset(args.dataset_name)
+ block_size = args.block_size if args.block_size is not None else model.config.max_position_embeddings
+ model.config.max_position_embeddings = block_size
+
+ processed_datasets = preprocess_datasets(raw_datasets, tokenizer, block_size)
+ for split, hf_dataset in processed_datasets.items():
+ processed_datasets[split] = convert_to_torch_dataset(hf_dataset)
+
+ train_val_split = processed_datasets["train"].train_test_split(test_size=0.2, shuffle=True)
+ train_indices = torch.randperm(len(train_val_split["train"]))[:int(0.4 * len(train_val_split["train"]))]
+ train_subset = Subset(train_val_split["train"], train_indices)
+ val_indices = torch.randperm(len(train_val_split["test"]))[:int(0.01 * len(train_val_split["test"]))]
+ val_subset = Subset(train_val_split["test"], val_indices)
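+    # Subsample to keep the Japanese run tractable: 40% of the training split, 1% of the held-out split.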
+ train_loader = DataLoader(train_subset, batch_size=args.per_device_train_batch_size, shuffle=True)
+ val_loader = DataLoader(val_subset, batch_size=args.per_device_eval_batch_size)
+
+ # train_val_split = processed_datasets["train"].train_test_split(test_size=0.2, shuffle=True)
+ # train_loader = DataLoader(train_val_split["train"], batch_size=args.per_device_train_batch_size, shuffle=True)
+ # val_loader = DataLoader(train_val_split["test"], batch_size=args.per_device_eval_batch_size)
+ # test_loader = DataLoader(processed_datasets["test"], batch_size=args.per_device_eval_batch_size)
+
+ model, train_loader, val_loader = accelerator.prepare(model, train_loader, val_loader)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
+ lr_scheduler = transformers.get_scheduler(
+ name=transformers.SchedulerType.CONSTANT,
+ optimizer=optimizer,
+ num_warmup_steps=args.warmup_steps * accelerator.num_processes,
+ num_training_steps=args.num_train_epochs * math.ceil(len(train_loader) / args.gradient_accumulation_steps),
+ )
+ loss_fn = causal_loss_wrapper(args.model_direction)
+
+ add_attn_hooks(model, args.model_direction)
+ model.train()
+ optimizer.zero_grad()
+
+ wandb.require("core")
+ accelerator.init_trackers(
+ project_name="NLP-Class-Project",
+ config=vars(args) | {"model_parameters": sum(p.numel() for p in model.parameters())},
+ init_kwargs={"wandb": {"entity": "frostbyte"}}
+ )
+
+ global_step = 0 # unaccumulated steps
+ past_losses = []
+ for epoch in tqdm(range(args.num_train_epochs), position=0, leave=True, desc="Epoch"):
+ for step, batch in enumerate(tqdm(train_loader, position=1, leave=False, desc="Train Iteration")):
+ with accelerator.accumulate(model):
+ labels = batch.pop("labels")
+ outputs = model(**batch)
+ loss = loss_fn(outputs.logits, labels)
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ past_losses.append(loss.item())
+ if (global_step + 1) % args.logging_steps == 0:
+ avg_train_loss = torch.tensor(past_losses).mean().item() # Assuming 1 GPU
+ accelerator.log({
+ "train_loss": avg_train_loss,
+ "learning_rate": lr_scheduler.get_last_lr()[0],
+ })
+ past_losses.clear()
+
+ if (global_step + 1) % args.eval_steps == 0:
+ val_loss_sum = val_examples = 0
+ model.eval()
+ for val_batch in tqdm(val_loader, position=2, leave=False, desc="Val Iteration"):
+ labels = val_batch.pop("labels")
+ with torch.no_grad():
+ outputs = model(**val_batch)
+
+ loss = loss_fn(outputs.logits, labels)
+
+ batch_size = labels.size(0)
+ val_loss_sum += loss.item() * batch_size
+ val_examples += batch_size
+
+ accelerator.log({"val_loss": val_loss_sum / val_examples},
+ log_kwargs={"wandb": {"commit": False}})
+ model.train()
+
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or step == (len(train_loader) - 1):
+ global_step += 1
+
+ model.save_pretrained(os.path.join(args.output_dir, f"epoch_{epoch}_checkpt"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune_bert.py b/finetune_bert.py
index 59c8090..da29af9 100644
--- a/finetune_bert.py
+++ b/finetune_bert.py
@@ -1,663 +1,330 @@
-#!/usr/bin/env python
-# coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
"""
-Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
-Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
-https://huggingface.co/models?filter=text-generation
+# BERT base
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name bert-base-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/bert_base_rtl/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name bert-base-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/bert_base_ltr/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+# DistilBERT scratch
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name distilbert/distilbert-base-uncased \
+--train_from_scratch \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_rtl_scratch/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name distilbert/distilbert-base-uncased \
+--train_from_scratch \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_ltr_scratch/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+# DistilBERT base
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name distilbert/distilbert-base-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_rtl/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name distilbert/distilbert-base-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 128 \
+--per_device_eval_batch_size 128 \
+--output_dir checkpoints/distilbert_base_ltr/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+# BERT large
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction rtl \
+--model_name bert-large-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 64 \
+--gradient_accumulation_steps 2 \
+--per_device_eval_batch_size 64 \
+--output_dir checkpoints/bert_large_rtl/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+
+accelerate launch --mixed_precision bf16 finetune_bert.py \
+--model_direction ltr \
+--model_name bert-large-uncased \
+--warmup_steps 500 \
+--learning_rate 5e-5 \
+--per_device_train_batch_size 64 \
+--gradient_accumulation_steps 2 \
+--per_device_eval_batch_size 64 \
+--output_dir checkpoints/bert_large_ltr/ \
+--eval_steps 899 \
+--block_size 128 \
+--num_train_epochs 4 \
+--weight_decay 1e-4
+
+for size in 35 19 11 6; do
+ for dir in ltr rtl; do
+ accelerate launch --mixed_precision bf16 finetune_bert.py \
+ --model_direction $dir \
+ --model_name bert-base-uncased \
+ --model_config "configs/bert_${size}M.json" \
+ --train_from_scratch \
+ --warmup_steps 500 \
+ --learning_rate 5e-5 \
+ --per_device_train_batch_size 128 \
+ --per_device_eval_batch_size 128 \
+ --output_dir "checkpoints/bert_${size}_${dir}_scratch/" \
+ --eval_steps 899 \
+ --block_size 128 \
+ --num_train_epochs 4 \
+ --weight_decay 1e-4
+ done
+done
+
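+# NOTE: this seed sweep reuses ${size} from the loop above; set size explicitly (e.g. size=6)
+# before running it on its own.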
+for seed in 0 1 2 3 4; do
+ for dir in ltr rtl; do
+ accelerate launch --mixed_precision bf16 finetune_bert.py \
+ --model_direction $dir \
+ --model_name bert-base-uncased \
+ --model_config "configs/bert_${size}M.json" \
+ --train_from_scratch \
+ --warmup_steps 500 \
+ --learning_rate 5e-5 \
+ --per_device_train_batch_size 128 \
+ --per_device_eval_batch_size 128 \
+ --output_dir "checkpoints/overwritable_temp/" \
+ --eval_steps 899 \
+ --block_size 128 \
+ --num_train_epochs 1 \
+ --weight_decay 1e-4 \
+ --seed $seed
+ done
+done
"""
-# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
-"""
-From https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
-"""
-
-import logging
+import argparse
import math
import os
-import sys
-from dataclasses import dataclass, field
-from itertools import chain
-from typing import Optional
-import datasets
-import evaluate
+import accelerate
import torch
-from datasets import load_dataset
-
import transformers
-from transformers import (
- CONFIG_MAPPING,
- MODEL_FOR_CAUSAL_LM_MAPPING,
- AutoConfig,
- AutoModelForCausalLM,
- AutoTokenizer,
- HfArgumentParser,
- Trainer,
- TrainingArguments,
- default_data_collator,
- is_torch_xla_available,
- set_seed,
-)
-from transformers.testing_utils import CaptureLogger
-from transformers.trainer_utils import get_last_checkpoint
-from transformers.utils import check_min_version, send_example_telemetry
-from transformers.utils.versions import require_version
-
-
-# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.47.0.dev0")
-
-require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
-
-logger = logging.getLogger(__name__)
-
-
-MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
-MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-
-
-@dataclass
-class ModelArguments:
- """
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
- """
- # text_direction: str = field(
- #
- # )
- model_name_or_path: Optional[str] = field(
- default=None,
- metadata={
- "help": (
- "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
- )
- },
- )
- # model_type: Optional[str] = field(
- # default=None,
- # metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
- # )
- config_overrides: Optional[str] = field(
- default=None,
- metadata={
- "help": (
- "Override some existing default config settings when a model is trained from scratch. Example: "
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
- )
- },
- )
- config_name: Optional[str] = field(
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
- )
- tokenizer_name: Optional[str] = field(
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
- )
- cache_dir: Optional[str] = field(
- default=None,
- metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
- )
- use_fast_tokenizer: bool = field(
- default=True,
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
- )
- model_revision: str = field(
- default="main",
- metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
- )
- token: str = field(
- default=None,
- metadata={
- "help": (
- "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
- "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
- )
- },
- )
- trust_remote_code: bool = field(
- default=False,
- metadata={
- "help": (
- "Whether to trust the execution of code from datasets/models defined on the Hub."
- " This option should only be set to `True` for repositories you trust and in which you have read the"
- " code, as it will execute code present on the Hub on your local machine."
- )
- },
- )
- torch_dtype: Optional[str] = field(
- default=None,
- metadata={
- "help": (
- "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
- "dtype will be automatically derived from the model's weights."
- ),
- "choices": ["auto", "bfloat16", "float16", "float32"],
- },
- )
- low_cpu_mem_usage: bool = field(
- default=False,
- metadata={
- "help": (
- "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
- "set True will benefit LLM loading time and RAM consumption."
- )
- },
- )
+import wandb
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import set_seed
- def __post_init__(self):
- if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
- raise ValueError(
- "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
- )
+from utils import preprocess_datasets, convert_to_torch_dataset, add_attn_hooks, causal_loss_wrapper
-@dataclass
-class DataTrainingArguments:
+def parse_args():
"""
- Arguments pertaining to what data we are going to input our model for training and eval.
+ Re-using HuggingFace arguments when possible (most of the help strings are directly copied).
+ https://github.com/huggingface/transformers/blob/7bbc62474391aff64f63fcc064c975752d1fa4de/examples/pytorch/language-modeling/run_clm.py#L75
"""
-
- dataset_name: Optional[str] = field(
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
- )
- dataset_config_name: Optional[str] = field(
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
- )
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
- validation_file: Optional[str] = field(
- default=None,
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
- )
- max_train_samples: Optional[int] = field(
- default=None,
- metadata={
- "help": (
- "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
- )
- },
- )
- max_eval_samples: Optional[int] = field(
- default=None,
- metadata={
- "help": (
- "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
- "value if set."
- )
- },
- )
- streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
- block_size: Optional[int] = field(
- default=None,
- metadata={
- "help": (
- "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size for training. "
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
- )
- },
- )
- overwrite_cache: bool = field(
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+ parser = argparse.ArgumentParser()
+
+ # Model
+ parser.add_argument("--model_direction", type=str, required=True, choices=["ltr", "rtl"],
+ help="Whether to train a left-to-right or right-to-left LM.")
+ parser.add_argument("--model_config", type=str,
+ help="Path to model config json, from which to train_from_scratch.")
+ parser.add_argument("--model_name", type=str, required=True,
+ help="Name of tokenizer to load. "
+                             "If model_config is not specified, will also load model architecture. "
+ "If not training from scratch, will also load model weights.")
+
+ # Data
+ parser.add_argument("--dataset_name", type=str, default="Salesforce/wikitext",
+ help="The name of the dataset to use (via the datasets library).")
+ parser.add_argument("--dataset_config_name", type=str, default="wikitext-103-v1",
+ help="The configuration name of the dataset to use (via the datasets library).")
+ # TODO: block_size, train on shorter seqs?
+ parser.add_argument(
+ "--block_size",
+ type=int,
+ help="Optional input sequence length after tokenization. "
+ "The training dataset will be truncated in block of this size for training. "
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
)
- validation_split_percentage: Optional[int] = field(
- default=5,
- metadata={
- "help": "The percentage of the train set used as validation set in case there's no validation split"
- },
- )
- preprocessing_num_workers: Optional[int] = field(
- default=None,
- metadata={"help": "The number of processes to use for the preprocessing."},
- )
- keep_linebreaks: bool = field(
- default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
- )
-
- def __post_init__(self):
- if self.streaming:
- require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
- raise ValueError("Need either a dataset name or a training/validation file.")
- else:
- if self.train_file is not None:
- extension = self.train_file.split(".")[-1]
- assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
- if self.validation_file is not None:
- extension = self.validation_file.split(".")[-1]
- assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+ # Training
+ parser.add_argument("--train_from_scratch", action="store_true")
+ parser.add_argument("--output_dir", type=str, required=True,
+ help="The output directory where the model predictions and checkpoints will be written.")
+ parser.add_argument("--per_device_train_batch_size", type=int, default=8)
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=16)
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument("--learning_rate", type=float, required=True)
+ parser.add_argument("--warmup_steps", type=int, default=0)
+ parser.add_argument("--weight_decay", type=float, default=0.0)
+ parser.add_argument("--logging_steps", type=int, default=1,
+ help="Number of update steps between two logs.")
+ parser.add_argument("--eval_steps", type=int, default=20000,
+                        help="Number of update steps between two evaluations.")
+ parser.add_argument("--dataloader_num_workers", type=int, default=8)
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+
+ args = parser.parse_args()
+
+ return args
def main():
- # See all possible arguments in src/transformers/training_args.py
- # or by passing the --help flag to this script.
- # We now keep distinct sets of args, for a cleaner separation of concerns.
-
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
- if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
- # If we pass only one argument to the script and it's the path to a json file,
- # let's parse it to get our arguments.
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
- else:
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ args = parse_args()
- # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
- # information sent is the one passed as arguments along with your Python/PyTorch versions.
- send_example_telemetry("run_clm", model_args, data_args)
+ accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with="wandb", project_dir=args.output_dir)
+ set_seed(args.seed)
- # Setup logging
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- handlers=[logging.StreamHandler(sys.stdout)],
- )
-
- if training_args.should_log:
- # The default of training_args.log_level is passive, so we set log level at info here to have that default.
- transformers.utils.logging.set_verbosity_info()
-
- log_level = training_args.get_process_log_level()
- logger.setLevel(log_level)
- datasets.utils.logging.set_verbosity(log_level)
- transformers.utils.logging.set_verbosity(log_level)
- transformers.utils.logging.enable_default_handler()
- transformers.utils.logging.enable_explicit_format()
-
- # Log on each process the small summary:
- logger.warning(
- f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
- + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
- )
- logger.info(f"Training/evaluation parameters {training_args}")
-
- # Detecting last checkpoint.
- last_checkpoint = None
- if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
- last_checkpoint = get_last_checkpoint(training_args.output_dir)
- if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
- raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. "
- "Use --overwrite_output_dir to overcome."
- )
- elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
- logger.info(
- f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
- "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
- )
-
- # Set seed before initializing model.
- set_seed(training_args.seed)
-
- # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
- # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
- # (the dataset will be downloaded automatically from the datasets Hub).
- #
- # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
- # 'text' is found. You can easily tweak this behavior (see below).
- #
- # In distributed training, the load_dataset function guarantee that only one local process can concurrently
- # download the dataset.
- if data_args.dataset_name is not None:
- # Downloading and loading a dataset from the hub.
- raw_datasets = load_dataset(
- data_args.dataset_name,
- data_args.dataset_config_name,
- cache_dir=model_args.cache_dir,
- token=model_args.token,
- streaming=data_args.streaming,
- trust_remote_code=model_args.trust_remote_code,
- )
- if "validation" not in raw_datasets.keys():
- raw_datasets["validation"] = load_dataset(
- data_args.dataset_name,
- data_args.dataset_config_name,
- split=f"train[:{data_args.validation_split_percentage}%]",
- cache_dir=model_args.cache_dir,
- token=model_args.token,
- streaming=data_args.streaming,
- trust_remote_code=model_args.trust_remote_code,
- )
- raw_datasets["train"] = load_dataset(
- data_args.dataset_name,
- data_args.dataset_config_name,
- split=f"train[{data_args.validation_split_percentage}%:]",
- cache_dir=model_args.cache_dir,
- token=model_args.token,
- streaming=data_args.streaming,
- trust_remote_code=model_args.trust_remote_code,
- )
- else:
- data_files = {}
- dataset_args = {}
- if data_args.train_file is not None:
- data_files["train"] = data_args.train_file
- if data_args.validation_file is not None:
- data_files["validation"] = data_args.validation_file
- extension = (
- data_args.train_file.split(".")[-1]
- if data_args.train_file is not None
- else data_args.validation_file.split(".")[-1]
- )
- if extension == "txt":
- extension = "text"
- dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
- raw_datasets = load_dataset(
- extension,
- data_files=data_files,
- cache_dir=model_args.cache_dir,
- token=model_args.token,
- **dataset_args,
- )
- # If no validation data is there, validation_split_percentage will be used to divide the dataset.
- if "validation" not in raw_datasets.keys():
- raw_datasets["validation"] = load_dataset(
- extension,
- data_files=data_files,
- split=f"train[:{data_args.validation_split_percentage}%]",
- cache_dir=model_args.cache_dir,
- token=model_args.token,
- **dataset_args,
- )
- raw_datasets["train"] = load_dataset(
- extension,
- data_files=data_files,
- split=f"train[{data_args.validation_split_percentage}%:]",
- cache_dir=model_args.cache_dir,
- token=model_args.token,
- **dataset_args,
- )
-
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
- # https://huggingface.co/docs/datasets/loading_datasets.
-
- # Load pretrained model and tokenizer
- #
- # Distributed training:
- # The .from_pretrained methods guarantee that only one local process can concurrently
- # download model & vocab.
-
- config_kwargs = {
- "cache_dir": model_args.cache_dir,
- "revision": model_args.model_revision,
- "token": model_args.token,
- "trust_remote_code": model_args.trust_remote_code,
- }
- if model_args.config_name:
- config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
- elif model_args.model_name_or_path:
- config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
- else:
- config = CONFIG_MAPPING[model_args.model_type]()
- logger.warning("You are instantiating a new config instance from scratch.")
- if model_args.config_overrides is not None:
- logger.info(f"Overriding config: {model_args.config_overrides}")
- config.update_from_string(model_args.config_overrides)
- logger.info(f"New config: {config}")
-
- tokenizer_kwargs = {
- "cache_dir": model_args.cache_dir,
- "use_fast": model_args.use_fast_tokenizer,
- "revision": model_args.model_revision,
- "token": model_args.token,
- "trust_remote_code": model_args.trust_remote_code,
- }
- if model_args.tokenizer_name:
- tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
- elif model_args.model_name_or_path:
- tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
- else:
- raise ValueError(
- "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
- "You can do it from another script, save it, and load it from here, using --tokenizer_name."
- )
-
- if model_args.model_name_or_path:
- torch_dtype = (
- model_args.torch_dtype
- if model_args.torch_dtype in ["auto", None]
- else getattr(torch, model_args.torch_dtype)
- )
- model = AutoModelForCausalLM.from_pretrained(
- model_args.model_name_or_path,
- from_tf=bool(".ckpt" in model_args.model_name_or_path),
- config=config,
- cache_dir=model_args.cache_dir,
- revision=model_args.model_revision,
- token=model_args.token,
- trust_remote_code=model_args.trust_remote_code,
- torch_dtype=torch_dtype,
- low_cpu_mem_usage=model_args.low_cpu_mem_usage,
- )
- else:
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
- n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
- logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
-
- # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
- # on a small vocab and want a smaller embedding size, remove this test.
- embedding_size = model.get_input_embeddings().weight.shape[0]
- if len(tokenizer) > embedding_size:
- model.resize_token_embeddings(len(tokenizer))
-
- # Preprocessing the datasets.
- # First we tokenize all the texts.
- if training_args.do_train:
- column_names = list(raw_datasets["train"].features)
- else:
- column_names = list(raw_datasets["validation"].features)
- text_column_name = "text" if "text" in column_names else column_names[0]
-
- # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
- tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
-
- def tokenize_function(examples):
- with CaptureLogger(tok_logger) as cl:
- output = tokenizer(examples[text_column_name])
- # clm input could be much much longer than block_size
- if "Token indices sequence length is longer than the" in cl.out:
- tok_logger.warning(
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
- " before being passed to the model."
- )
- return output
-
- with training_args.main_process_first(desc="dataset map tokenization"):
- if not data_args.streaming:
- tokenized_datasets = raw_datasets.map(
- tokenize_function,
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- remove_columns=column_names,
- load_from_cache_file=not data_args.overwrite_cache,
- desc="Running tokenizer on dataset",
- )
- else:
- tokenized_datasets = raw_datasets.map(
- tokenize_function,
- batched=True,
- remove_columns=column_names,
- )
- if hasattr(config, "max_position_embeddings"):
- max_pos_embeddings = config.max_position_embeddings
+ # Will `add_attn_hooks` to `model` later
+ if args.model_config is not None:
+ assert args.train_from_scratch, "Expected to train from scratch when model_config is specified."
+ config = transformers.AutoConfig.from_pretrained(args.model_config)
+ model = transformers.AutoModelForMaskedLM.from_config(config)
else:
- # Define a default value if the attribute is missing in the config.
- max_pos_embeddings = 1024
-
- if data_args.block_size is None:
- block_size = tokenizer.model_max_length
- if block_size > max_pos_embeddings:
- logger.warning(
- f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
- f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx."
- )
- if max_pos_embeddings > 0:
- block_size = min(1024, max_pos_embeddings)
- else:
- block_size = 1024
- else:
- if data_args.block_size > tokenizer.model_max_length:
- logger.warning(
- f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
- f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
- )
- block_size = min(data_args.block_size, tokenizer.model_max_length)
-
- # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
- def group_texts(examples):
- # Concatenate all texts.
- concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
- total_length = len(concatenated_examples[list(examples.keys())[0]])
- # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
- # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
- total_length = (total_length // block_size) * block_size
- # Split by chunks of max_len.
- result = {
- k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
- for k, t in concatenated_examples.items()
- }
- result["labels"] = result["input_ids"].copy()
- return result
-
- # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
- # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
- # to preprocess.
- #
- # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
- # https://huggingface.co/docs/datasets/process#map
-
- with training_args.main_process_first(desc="grouping texts together"):
- if not data_args.streaming:
- lm_datasets = tokenized_datasets.map(
- group_texts,
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=not data_args.overwrite_cache,
- desc=f"Grouping texts in chunks of {block_size}",
- )
- else:
- lm_datasets = tokenized_datasets.map(
- group_texts,
- batched=True,
- )
-
- if training_args.do_train:
- if "train" not in tokenized_datasets:
- raise ValueError("--do_train requires a train dataset")
- train_dataset = lm_datasets["train"]
- if data_args.max_train_samples is not None:
- max_train_samples = min(len(train_dataset), data_args.max_train_samples)
- train_dataset = train_dataset.select(range(max_train_samples))
-
- if training_args.do_eval:
- if "validation" not in tokenized_datasets:
- raise ValueError("--do_eval requires a validation dataset")
- eval_dataset = lm_datasets["validation"]
- if data_args.max_eval_samples is not None:
- max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
- eval_dataset = eval_dataset.select(range(max_eval_samples))
-
- def preprocess_logits_for_metrics(logits, labels):
- if isinstance(logits, tuple):
- # Depending on the model and config, logits may contain extra tensors,
- # like past_key_values, but logits always come first
- logits = logits[0]
- return logits.argmax(dim=-1)
-
- metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
-
- def compute_metrics(eval_preds):
- preds, labels = eval_preds
- # preds have the same shape as the labels, after the argmax(-1) has been calculated
- # by preprocess_logits_for_metrics but we need to shift the labels
- labels = labels[:, 1:].reshape(-1)
- preds = preds[:, :-1].reshape(-1)
- return metric.compute(predictions=preds, references=labels)
-
- # Initialize our Trainer
- trainer = Trainer(
- model=model,
- args=training_args,
- train_dataset=train_dataset if training_args.do_train else None,
- eval_dataset=eval_dataset if training_args.do_eval else None,
- processing_class=tokenizer,
- # Data collator will default to DataCollatorWithPadding, so we change it.
- data_collator=default_data_collator,
- compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics
- if training_args.do_eval and not is_torch_xla_available()
- else None,
+ # No model_config given: load pretrained weights, then re-initialize them below if training from scratch
+ model = transformers.AutoModelForMaskedLM.from_pretrained(args.model_name, attn_implementation="sdpa")
+
+ if args.train_from_scratch:
+ model.apply(model._initialize_weights)
+ model.tie_weights() # probably not applicable
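+ # Assumption: `PreTrainedModel._initialize_weights` re-runs the model's standard
+ # `_init_weights` on each module, so this keeps the pretrained architecture and
+ # tokenizer vocabulary while discarding the pretrained weights.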
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
+
+ # Data
+ raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+ block_size = args.block_size if args.block_size is not None else model.config.max_position_embeddings
+ model.config.max_position_embeddings = block_size
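+ # Note: this only rewrites the config value; the position-embedding matrix itself is not
+ # resized here, which is why the notebooks later reload these checkpoints with
+ # `ignore_mismatched_sizes=True` and see shape warnings for position_embeddings.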
+
+ processed_datasets = preprocess_datasets(raw_datasets, tokenizer, block_size)
+ for split, hf_dataset in processed_datasets.items():
+ processed_datasets[split] = convert_to_torch_dataset(hf_dataset)
+
+ train_loader = DataLoader(processed_datasets["train"], batch_size=args.per_device_train_batch_size, shuffle=True)
+ val_loader = DataLoader(processed_datasets["validation"], batch_size=args.per_device_eval_batch_size)
+ # test_loader = DataLoader(processed_datasets["test"], batch_size=args.per_device_eval_batch_size)
+ model, train_loader, val_loader = accelerator.prepare(model, train_loader, val_loader)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
+ lr_scheduler = transformers.get_scheduler(
+ name=transformers.SchedulerType.COSINE,
+ optimizer=optimizer,
+ num_warmup_steps=args.warmup_steps * accelerator.num_processes,
+ num_training_steps=args.num_train_epochs * math.ceil(len(train_loader) / args.gradient_accumulation_steps),
)
+ loss_fn = causal_loss_wrapper(args.model_direction)
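+ # Hedged sketch: `causal_loss_wrapper` is defined in utils.py (not shown in this hunk) and is
+ # assumed to build a next-token cross-entropy whose shift follows `model_direction`, roughly:
+ #   def loss_fn(logits, labels):              # logits: (B, T, V), labels: (B, T)
+ #       if args.model_direction == "ltr":     # position t predicts token t+1
+ #           shift_logits, shift_labels = logits[:, :-1], labels[:, 1:]
+ #       else:                                 # "rtl": position t predicts token t-1
+ #           shift_logits, shift_labels = logits[:, 1:], labels[:, :-1]
+ #       return torch.nn.functional.cross_entropy(
+ #           shift_logits.reshape(-1, shift_logits.size(-1)),
+ #           shift_labels.reshape(-1), ignore_index=-100)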
- # Training
- if training_args.do_train:
- checkpoint = None
- if training_args.resume_from_checkpoint is not None:
- checkpoint = training_args.resume_from_checkpoint
- elif last_checkpoint is not None:
- checkpoint = last_checkpoint
- train_result = trainer.train(resume_from_checkpoint=checkpoint)
- trainer.save_model() # Saves the tokenizer too for easy upload
-
- metrics = train_result.metrics
-
- max_train_samples = (
- data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
- )
- metrics["train_samples"] = min(max_train_samples, len(train_dataset))
-
- trainer.log_metrics("train", metrics)
- trainer.save_metrics("train", metrics)
- trainer.save_state()
-
- # Evaluation
- if training_args.do_eval:
- logger.info("*** Evaluate ***")
-
- metrics = trainer.evaluate()
-
- max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
- metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
- try:
- perplexity = math.exp(metrics["eval_loss"])
- except OverflowError:
- perplexity = float("inf")
- metrics["perplexity"] = perplexity
-
- trainer.log_metrics("eval", metrics)
- trainer.save_metrics("eval", metrics)
-
- kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
- if data_args.dataset_name is not None:
- kwargs["dataset_tags"] = data_args.dataset_name
- if data_args.dataset_config_name is not None:
- kwargs["dataset_args"] = data_args.dataset_config_name
- kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
- else:
- kwargs["dataset"] = data_args.dataset_name
-
- if training_args.push_to_hub:
- trainer.push_to_hub(**kwargs)
- else:
- trainer.create_model_card(**kwargs)
+ add_attn_hooks(model, args.model_direction)
+ model.train()
+ optimizer.zero_grad()
+ wandb.require("core")
+ accelerator.init_trackers(
+ project_name="NLP-Class-Project",
+ config=vars(args) | {"model_parameters": sum(p.numel() for p in model.parameters())},
+ init_kwargs={"wandb": {"entity": "frostbyte"}}
+ )
-def _mp_fn(index):
- # For xla_spawn (TPUs)
- main()
+ global_step = 0  # optimizer-level steps (incremented once per gradient-accumulation cycle)
+ past_losses = []
+ for epoch in tqdm(range(args.num_train_epochs), position=0, leave=True, desc="Epoch"):
+ for step, batch in enumerate(tqdm(train_loader, position=1, leave=False, desc="Train Iteration")):
+ with accelerator.accumulate(model):
+ labels = batch.pop("labels")
+ outputs = model(**batch)
+ loss = loss_fn(outputs.logits, labels)
+ accelerator.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ past_losses.append(loss.item())
+ if (global_step + 1) % args.logging_steps == 0:
+ avg_train_loss = torch.tensor(past_losses).mean().item() # Assuming 1 GPU
+ accelerator.log({
+ "train_loss": avg_train_loss,
+ "learning_rate": lr_scheduler.get_last_lr()[0],
+ })
+ past_losses.clear()
+
+ if (global_step + 1) % args.eval_steps == 0:
+ val_loss_sum = val_examples = 0
+ model.eval()
+ for val_batch in tqdm(val_loader, position=2, leave=False, desc="Val Iteration"):
+ labels = val_batch.pop("labels")
+ with torch.no_grad():
+ outputs = model(**val_batch)
+
+ loss = loss_fn(outputs.logits, labels)
+
+ batch_size = labels.size(0)
+ val_loss_sum += loss.item() * batch_size
+ val_examples += batch_size
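+ # Note: assuming `loss_fn` returns a per-token mean, weighting by example count gives an
+ # approximate rather than exactly token-weighted validation average.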
+
+ accelerator.log({"val_loss": val_loss_sum / val_examples},
+ log_kwargs={"wandb": {"commit": False}})
+ model.train()
+
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or step == (len(train_loader) - 1):
+ global_step += 1
+
+ model.save_pretrained(os.path.join(args.output_dir, f"epoch_{epoch}_checkpt"))
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/notebooks/Inference.ipynb b/notebooks/Inference.ipynb
new file mode 100644
index 0000000..670a127
--- /dev/null
+++ b/notebooks/Inference.ipynb
@@ -0,0 +1,638 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "34c536f2-3ccb-4df0-bd47-913d6ef040a2",
+ "metadata": {},
+ "source": [
+ "# Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "execution_state": "idle",
+ "id": "3c6381c4-2a02-415f-a5f1-450fe42b30d3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "\n",
+ "from utils import add_attn_hooks\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt\", ignore_mismatched_sizes=True)\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr/epoch_49_checkpt\", ignore_mismatched_sizes=True)\n",
+ "model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_v2/best_checkpt\", ignore_mismatched_sizes=True)\n",
+ "\n",
+ "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/bert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "execution_state": "idle",
+ "id": "41edf867-7f6a-4d44-871c-8af0b7af7543",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from utils import add_attn_hooks\n",
+ "add_attn_hooks(model, \"ltr\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "execution_state": "idle",
+ "id": "d79a3d16-b982-42ed-9b6a-fad8328e177e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.eval();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "execution_state": "idle",
+ "id": "f06d4bc2-9df8-42c2-9397-3bfb3728da0b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "from typing import override\n",
+ "class DecoderMLM(transformers.AutoModelForMaskedLM, transformers.GenerationMixin):\n",
+ " @override\n",
+ " # @classmethod\n",
+ " def can_generate(cls):\n",
+ " return True\n",
+ "\n",
+ "model2 = DecoderMLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/qa_distilbert_base_ltr_overfit/epoch_999_checkpt\", ignore_mismatched_sizes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "execution_state": "idle",
+ "id": "6feb4fdd-ae43-466d-8dce-a4f9a632a5e6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# model2.can_generate = (lambda s: True)\n",
+ "model2.can_generate = (lambda: True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "execution_state": "idle",
+ "id": "c178761c-7124-42ed-9bfc-7ab0f782aad7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model2."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "execution_state": "idle",
+ "id": "3ab53852-f333-47ea-9e96-55266cda84a6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model2.can_generate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "execution_state": "idle",
+ "id": "a30f6240-f982-45b0-b75f-3be5bbb43049",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "execution_state": "idle",
+ "id": "2cefa784-a1cc-445a-8ffa-066e7cfccaf0",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "The current model class (DistilBertForMaskedLM) is not compatible with `.generate()`, as it doesn't have a language model head. Classes that support generation often end in one of these names: ['ForCausalLM', 'ForConditionalGeneration', 'ForSpeechSeq2Seq', 'ForVision2Seq'].",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[25], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m tokenized_question \u001b[38;5;241m=\u001b[39m tokenizer(question, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39minference_mode():\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokenized_question\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_beams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdo_sample\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1.\u001b[39;49m\u001b[43m)\u001b[49m))\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/transformers/generation/utils.py:1967\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1882\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1883\u001b[0m \n\u001b[1;32m 1884\u001b[0m \u001b[38;5;124;03mGenerates sequences of token ids for models with a language modeling head.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1963\u001b[0m \u001b[38;5;124;03m - [`~generation.GenerateBeamEncoderDecoderOutput`]\u001b[39;00m\n\u001b[1;32m 1964\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1966\u001b[0m \u001b[38;5;66;03m# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call\u001b[39;00m\n\u001b[0;32m-> 1967\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_model_class\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1968\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1969\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/transformers/generation/utils.py:1269\u001b[0m, in \u001b[0;36mGenerationMixin._validate_model_class\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1262\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcan_generate():\n\u001b[1;32m 1263\u001b[0m terminations_with_generation_support \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 1264\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForCausalLM\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForConditionalGeneration\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForSpeechSeq2Seq\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mForVision2Seq\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1268\u001b[0m ]\n\u001b[0;32m-> 1269\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 1270\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe current model class (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) is not compatible with `.generate()`, as \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mit doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt have a language model head. Classes that support generation often end in one of these \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mterminations_with_generation_support\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1273\u001b[0m )\n",
+ "\u001b[0;31mTypeError\u001b[0m: The current model class (DistilBertForMaskedLM) is not compatible with `.generate()`, as it doesn't have a language model head. Classes that support generation often end in one of these names: ['ForCausalLM', 'ForConditionalGeneration', 'ForSpeechSeq2Seq', 'ForVision2Seq']."
+ ]
+ }
+ ],
+ "source": [
+ "question = \"Answer: Grapes are toxic to foxes in large quantities\"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "with torch.inference_mode():\n",
+ " print(tokenizer.batch_decode(model.generate(tokenized_question[\"input_ids\"][:, :-1], num_beams=5, do_sample=True, temperature=1.)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c475851e-7c86-46fd-b4ad-cd51caa6e7b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "question = \"Apples are red and \"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "with torch.inference_mode():\n",
+ " print(tokenizer.batch_decode(model.generate(tokenized_question[\"input_ids\"][:, :-1], num_beams=5, do_sample=True)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "execution_state": "idle",
+ "id": "1365ab75-a022-42fe-9168-d49a645af0d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :'"
+ ]
+ },
+ "execution_count": 97,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode([3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n",
+ " 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n",
+ " 1024])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "execution_state": "idle",
+ "id": "58bffbb3-77fb-4f57-a77e-303fca05a84f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated token: what\n",
+ "Generated token: is\n",
+ "Generated token: the\n",
+ "Generated token: chicken\n",
+ "Generated token: cry\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: question\n",
+ "Generated token: :\n",
+ "Generated token: what\n",
+ "Generated token: is\n",
+ "Generated token: the\n",
+ "Generated token: chicken\n",
+ "Generated token: cry\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: question\n",
+ "Generated token: :\n",
+ "Generated token: what\n",
+ "Generated token: is\n",
+ "Generated token: the\n",
+ "Generated token: chicken\n",
+ "Generated token: cry\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: what\n",
+ "Generated token: is\n"
+ ]
+ }
+ ],
+ "source": [
+ "# question = \"answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :\"\n",
+ "question = \"answer : your chicken cries. [SEP] [CLS] question :\"\n",
+ "# question = \"answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :\"\n",
+ "\n",
+ "# input_ids = tokenizer(question, return_tensors=\"pt\").input_ids[:, :-1]\n",
+ "input_ids = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids\n",
+ "#tokenized_question = {\n",
+ "# \"input_ids\": torch.tensor([[3437, 1024, 2498, 6433, 102, 101, 3160, 1024]], dtype=torch.long)\n",
+ "#}\n",
+ "# input_ids = torch.tensor([[3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n",
+ "# 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n",
+ "# 1024]])\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for i in range(30):\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(input_ids).logits[0, -1, :]\n",
+ " \n",
+ " # Select the token with the highest probability\n",
+ " next_token_id = torch.argmax(logits).item()\n",
+ " \n",
+ " # Update the input_ids with the new token\n",
+ " input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)\n",
+ " \n",
+ " # Decode the next token for readability\n",
+ " next_token = tokenizer.decode(next_token_id)\n",
+ " \n",
+ " print(f\"Generated token: {next_token}\")\n",
+ "# but we did train on incorrect answers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "execution_state": "idle",
+ "id": "972352ea-4c28-42d9-a834-26daa83b2290",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated token: how\n",
+ "Generated token: long\n",
+ "Generated token: should\n",
+ "Generated token: you\n",
+ "Generated token: wait\n",
+ "Generated token: before\n",
+ "Generated token: filing\n",
+ "Generated token: a\n",
+ "Generated token: missing\n",
+ "Generated token: person\n",
+ "Generated token: report\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n",
+ "Generated token: question\n",
+ "Generated token: :\n",
+ "Generated token: how\n",
+ "Generated token: long\n",
+ "Generated token: should\n",
+ "Generated token: you\n",
+ "Generated token: wait\n",
+ "Generated token: before\n",
+ "Generated token: filing\n",
+ "Generated token: a\n",
+ "Generated token: missing\n",
+ "Generated token: person\n",
+ "Generated token: report\n",
+ "Generated token: ?\n",
+ "Generated token: [SEP]\n",
+ "Generated token: [CLS]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# question = \"answer : grapes are toxic to foxes in large quantities. [SEP] [CLS] question :\"\n",
+ "question = \"answer : you are late to work. [SEP] [CLS] question :\"\n",
+ "# question = \"answer : if you go outdoors in cold weather with wet hair, your hair may freeze [SEP] [CLS] question :\"\n",
+ "\n",
+ "# input_ids = tokenizer(question, return_tensors=\"pt\").input_ids[:, :-1]\n",
+ "input_ids = tokenizer(question, return_tensors=\"pt\", add_special_tokens=False).input_ids\n",
+ "#tokenized_question = {\n",
+ "# \"input_ids\": torch.tensor([[3437, 1024, 2498, 6433, 102, 101, 3160, 1024]], dtype=torch.long)\n",
+ "#}\n",
+ "# input_ids = torch.tensor([[3437, 1024, 2065, 2017, 2175, 19350, 1999, 3147, 4633, 2007,\n",
+ "# 4954, 2606, 1010, 2115, 2606, 2089, 13184, 102, 101, 3160,\n",
+ "# 1024]])\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for i in range(30):\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(input_ids).logits[0, -1, :]\n",
+ " \n",
+ " # Select the token with the highest probability\n",
+ " next_token_id = torch.argmax(logits).item()\n",
+ " \n",
+ " # Update the input_ids with the new token\n",
+ " input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]])], dim=1)\n",
+ " \n",
+ " # Decode the next token for readability\n",
+ " next_token = tokenizer.decode(next_token_id)\n",
+ " \n",
+ " print(f\"Generated token: {next_token}\")\n",
+ "# but we did train on incorrect answers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "execution_state": "idle",
+ "id": "6b1949d3-343d-49ba-b50e-79dc36d9124e",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "AttributeError",
+ "evalue": "'DistilBertForMaskedLM' object has no attribute 'tokenizer'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer\u001b[49m\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/torch/nn/modules/module.py:1931\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1929\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1930\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1931\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[1;32m 1932\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1933\u001b[0m )\n",
+ "\u001b[0;31mAttributeError\u001b[0m: 'DistilBertForMaskedLM' object has no attribute 'tokenizer'"
+ ]
+ }
+ ],
+ "source": [
+ "model.tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "execution_state": "idle",
+ "id": "f57b41c0-2056-4164-914b-f298ad66c0c5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generated token: [CLS]\n",
+ "Generated token: the\n",
+ "Generated token: vampires\n",
+ "Generated token: are\n",
+ "Generated token: the\n",
+ "Generated token: vampires\n",
+ "Generated token: ,\n",
+ "Generated token: vampires\n",
+ "Generated token: are\n",
+ "Generated token: living\n",
+ "Generated token: ,\n",
+ "Generated token: who\n",
+ "Generated token: believe\n",
+ "Generated token: the\n",
+ "Generated token: vampires\n",
+ "Generated token: .\n",
+ "Generated token: vampire\n",
+ "Generated token: ,\n",
+ "Generated token: and\n",
+ "Generated token: who\n",
+ "Generated token: are\n",
+ "Generated token: vampires\n",
+ "Generated token: ,\n",
+ "Generated token: who\n",
+ "Generated token: are\n",
+ "Generated token: also\n",
+ "Generated token: .\n",
+ "Generated token: vampires\n",
+ "Generated token: who\n",
+ "Generated token: do\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "question = \"Answer: Vampires are real. Question:\"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "temperature = 0.7 # Set your temperature here (e.g., 0.7 for less randomness)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for i in range(30):\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(tokenized_question[\"input_ids\"]).logits[0, -1, :]\n",
+ " \n",
+ " # Apply temperature scaling\n",
+ " logits = logits / temperature\n",
+ " \n",
+ " # Convert logits to probabilities using softmax\n",
+ " probs = F.softmax(logits, dim=-1)\n",
+ " \n",
+ " # Sample from the distribution\n",
+ " next_token_id = torch.multinomial(probs, num_samples=1).item()\n",
+ " \n",
+ " # Update the input_ids with the new token\n",
+ " tokenized_question['input_ids'] = torch.cat([tokenized_question['input_ids'], torch.tensor([[next_token_id]])], dim=1)\n",
+ " \n",
+ " # Decode the next token for readability\n",
+ " next_token = tokenizer.decode(next_token_id)\n",
+ " \n",
+ " print(f\"Generated token: {next_token}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 129,
+ "execution_state": "idle",
+ "id": "e8bffe5d-d830-4992-9381-b484672ffeda",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "IndexError",
+ "evalue": "too many indices for tensor of dimension 1",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[129], line 35\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;66;03m# Create candidates\u001b[39;00m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(beam_width):\n\u001b[0;32m---> 35\u001b[0m token_id \u001b[38;5;241m=\u001b[39m \u001b[43mtop_k_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 36\u001b[0m token_prob \u001b[38;5;241m=\u001b[39m top_k_probs[\u001b[38;5;241m0\u001b[39m, i]\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 38\u001b[0m \u001b[38;5;66;03m# Create a new sequence by appending the token to the existing sequence\u001b[39;00m\n",
+ "\u001b[0;31mIndexError\u001b[0m: too many indices for tensor of dimension 1"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "question = \"Question: Are Vampires real. Answer:\"\n",
+ "tokenized_question = tokenizer(question, return_tensors=\"pt\")\n",
+ "\n",
+ "# Parameters\n",
+ "beam_width = 3 # The number of beams to consider\n",
+ "max_length = 30 # Maximum number of tokens to generate\n",
+ "temperature = 1.0 # Temperature for softmax\n",
+ "\n",
+ "# Initialize beams\n",
+ "beams = [(tokenized_question['input_ids'], 0.0)] # Each beam is a tuple (sequence, score)\n",
+ "finished_beams = []\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " for step in range(max_length):\n",
+ " all_candidates = []\n",
+ " \n",
+ " for seq, score in beams:\n",
+ " # Get logits for the last token in the sequence\n",
+ " logits = model(input_ids=seq).logits[0, -1, :]\n",
+ " \n",
+ " # Apply temperature scaling\n",
+ " logits = logits / temperature\n",
+ " \n",
+ " # Convert logits to probabilities using softmax\n",
+ " probs = F.softmax(logits, dim=-1)\n",
+ " \n",
+ " # Get top-k candidate tokens and their probabilities\n",
+ " top_k_probs, top_k_ids = torch.topk(probs, beam_width, dim=-1)\n",
+ " \n",
+ " # Create candidates\n",
+ " for i in range(beam_width):\n",
+ " token_id = top_k_ids[0, i].item()\n",
+ " token_prob = top_k_probs[0, i].item()\n",
+ " \n",
+ " # Create a new sequence by appending the token to the existing sequence\n",
+ " new_seq = torch.cat([seq, torch.tensor([[token_id]])], dim=1)\n",
+ " \n",
+ " # Update the score (cumulative log probability)\n",
+ " new_score = score + torch.log(torch.tensor(token_prob))\n",
+ " \n",
+ " # If the token is the end-of-sequence token, consider it a finished beam\n",
+ " if token_id == tokenizer.eos_token_id:\n",
+ " finished_beams.append((new_seq, new_score))\n",
+ " else:\n",
+ " all_candidates.append((new_seq, new_score))\n",
+ " \n",
+ " # Sort candidates by score (highest first) and select top-k\n",
+ " beams = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]\n",
+ " \n",
+ " # If no beams are left to explore, break\n",
+ " if not beams:\n",
+ " break\n",
+ " \n",
+ " # If we have finished beams, we can stop early\n",
+ " if len(finished_beams) >= beam_width:\n",
+ " break\n",
+ "\n",
+ "# Decode and print the top finished sequences\n",
+ "for seq, score in sorted(finished_beams, key=lambda x: x[1], reverse=True):\n",
+ " generated_text = tokenizer.decode(seq[0], skip_special_tokens=True)\n",
+ " print(f\"Generated text: {generated_text}\")\n",
+ " print(f\"Score: {score.item()}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "218621c4-2b18-4e6e-8fba-adde97038ae8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "WANDB_MODE=offline CUDA_VISIBLE_DEVICES= accelerate launch --mixed_precision bf16 finetune_QA.py \\\n",
+ "--model_direction ltr \\\n",
+ "--checkpoint_path /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt \\\n",
+ "--tokenizer_name distilbert/distilbert-base-uncased \\\n",
+ "--warmup_steps 100 \\\n",
+ "--learning_rate 1e-5 \\\n",
+ "--per_device_train_batch_size 128 \\\n",
+ "--per_device_eval_batch_size 128 \\\n",
+ "--output_dir checkpoints/qa_distilbert_base_ltr/ \\\n",
+ "--eval_steps 38 \\\n",
+ "--block_size 128 \\\n",
+ "--num_train_epochs 50 \\\n",
+ "--weight_decay 1e-4\n",
+ "\n",
+ "yes but we overfit to this example\n",
+ "\n",
+ "oh wait we did actually"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f5e37c9a-156a-4fcc-bd05-5ae6c3671e62",
+ "metadata": {},
+ "source": [
+ "wait so now it does really good because we're actually doing it the same way that we trained it?\n",
+ "\n",
+ "\n",
+ "\n",
+ "thats fine, let's train it on everything again!\n",
+ "\n",
+ "\n",
+ "i don't think that's true... but low perplexity!\n",
+ "\n",
+ "\n",
+ "\n",
+ "the question was supposed to be, \n",
+ "what happens if you swallow a watermelon seed?"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/Riddles.ipynb b/notebooks/Riddles.ipynb
new file mode 100644
index 0000000..c2c8309
--- /dev/null
+++ b/notebooks/Riddles.ipynb
@@ -0,0 +1,362 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "execution_state": "idle",
+ "id": "1ddfc692-bda7-4d38-a549-2fb0d40d437d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_base_rtl/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of BertForMaskedLM were not initialized from the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_base_rtl/epoch_3_checkpt and are newly initialized because the shapes did not match:\n",
+ "- bert.embeddings.position_embeddings.weight: found shape torch.Size([512, 768]) in the checkpoint and torch.Size([128, 768]) in the model instantiated\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "import sys\n",
+ "\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "\n",
+ "from utils import add_attn_hooks\n",
+ "\n",
+ "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(\"/home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+ "text_dir = \"rtl\"\n",
+ "# text_dir = \"ltr\"\n",
+ "model = transformers.AutoModelForMaskedLM.from_pretrained(f\"/home/sipb/nlp-class-project/checkpoints/bert_base_{text_dir}/epoch_3_checkpt\", ignore_mismatched_sizes=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "execution_state": "idle",
+ "id": "a732375b-1682-45c6-8df0-8db1458559c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "add_attn_hooks(model, text_dir)\n",
+ "model.eval();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "execution_state": "idle",
+ "id": "041d1702-5aaf-45f0-9413-4014b315d1ed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"/home/sipb/nlp-class-project/data/riddles.txt\", \"r\") as f:\n",
+ " riddles = [line.rstrip() for line in f.readlines()]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "execution_state": "idle",
+ "id": "a4098975-2df6-4435-bc93-1a5afd6d7e68",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# CPU is fast enough\n",
+ "\n",
+ "ppls = []\n",
+ "for riddle in riddles:\n",
+ " batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n",
+ " batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
+ " batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n",
+ " # batch = tokenizer([riddle], return_tensors=\"pt\")#, padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"longest\", max_length=128)\n",
+ " # batch[\"labels\"] = batch[\"input_ids\"]\n",
+ " with torch.inference_mode():\n",
+ " output = model(**batch)\n",
+ " ppls.append(math.e ** output.loss.item())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "execution_state": "idle",
+ "id": "c4a82af4-d0d8-415a-9135-3a1350c1402e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(522.113471240328, 'rtl')"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(ppls) / len(ppls), text_dir"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "execution_state": "idle",
+ "id": "84a95c66-6dd3-4ccb-96a2-96f38008f70e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1007.5656859988405, 'ltr')"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(ppls) / len(ppls), text_dir"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "execution_state": "idle",
+ "id": "51ed80f1-a935-42bc-8194-832f91222c45",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(1007.5656309474507, 'ltr')"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(ppls) / len(ppls), text_dir"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "execution_state": "idle",
+ "id": "40a98c10-59c3-498a-a9e6-c23bd9437bc7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "937.8557468023619"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(ppls) / len(ppls)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "execution_state": "idle",
+ "id": "80b22ba1-e5ba-4f1e-8038-158a2c2f37a6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': tensor([[ 101, 1045, 2064, 2022, 2524, 1010, 2021, 1045, 2572, 2025,\n",
+ " 5024, 1012, 2054, 2572, 1045, 1029, 1037, 15117, 1012, 102,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([[ 101, 1045, 2064, 2022, 2524, 1010, 2021, 1045, 2572, 2025,\n",
+ " 5024, 1012, 2054, 2572, 1045, 1029, 1037, 15117, 1012, 102,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
+ " -100, -100, -100, -100, -100, -100, -100, -100]])}"
+ ]
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "execution_state": "idle",
+ "id": "c68b5235-a4a7-4f38-9acb-f5072e546a96",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(array([ 4., 6., 11., 6., 5., 2., 1., 1., 2., 2.]),\n",
+ " array([ 613.56297843, 829.36555779, 1045.16813716, 1260.97071653,\n",
+ " 1476.77329589, 1692.57587526, 1908.37845463, 2124.18103399,\n",
+ " 2339.98361336, 2555.78619272, 2771.58877209]),\n",
+        " <BarContainer object of 10 artists>)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+       "image/png": "<base64-encoded PNG omitted: matplotlib histogram of the per-riddle perplexities>",
+ "text/plain": [
+        "<Figure size 640x480 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "plt.hist(ppls)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "execution_state": "idle",
+ "id": "8acad3ce-905d-455e-af5d-9770495f374a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414,\n",
+ " 956.7294281325414]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ppls"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86daa05b-5784-457b-b65e-8b8395128d6f",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
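
The notebook above scores each riddle by exponentiating the masked-LM loss, with padded positions excluded from the loss by setting their labels to -100. Below is a minimal standalone sketch of that computation; it uses the stock distilbert/distilbert-base-uncased weights as a stand-in for the fine-tuned checkpoints the notebook actually loads, and it omits the add_attn_hooks directional masking applied in the notebook.

    # Minimal sketch of the per-riddle perplexity used above: exp of the
    # masked-LM loss, with pad positions dropped from the loss via label -100.
    # The stock distilbert checkpoint is a stand-in for the notebook's
    # fine-tuned checkpoints, and the add_attn_hooks step is omitted here.
    import math

    import torch
    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
    model = transformers.AutoModelForMaskedLM.from_pretrained("distilbert/distilbert-base-uncased")
    model.eval()


    def riddle_perplexity(text: str, text_dir: str = "ltr", max_length: int = 128) -> float:
        # Pad on the left for right-to-left models and on the right otherwise,
        # mirroring the notebook's tokenizer call.
        batch = tokenizer(
            [text],
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            padding_side="left" if text_dir == "rtl" else "right",
        )
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100  # ignore padding in the loss
        with torch.inference_mode():
            loss = model(**batch, labels=labels).loss
        return math.e ** loss.item()


    # Illustrative riddle; the notebook iterates over data/riddles.txt instead.
    print(riddle_perplexity("The more you take, the more you leave behind. What am I? Footsteps."))

The padding side follows text_dir exactly as in the notebook, so left padding is used for the right-to-left models and right padding otherwise.
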
diff --git a/notebooks/Riddles_FixedPos.ipynb b/notebooks/Riddles_FixedPos.ipynb
new file mode 100644
index 0000000..5e42e0b
--- /dev/null
+++ b/notebooks/Riddles_FixedPos.ipynb
@@ -0,0 +1,302 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "execution_state": "idle",
+ "id": "1ddfc692-bda7-4d38-a549-2fb0d40d437d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "import os\n",
+ "import sys\n",
+ "\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "from safetensors import safe_open\n",
+ "\n",
+ "from utils import add_attn_hooks\n",
+ "\n",
+ "# text_dir = \"rtl\"\n",
+ "text_dir = \"ltr\"\n",
+ "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(f\"/home/sipb/nlp-class-project/checkpoints/bert_base_{text_dir}/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "execution_state": "idle",
+ "id": "eaf99031-9141-43dd-89ba-be9b8e63a1ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"/home/sipb/nlp-class-project/data/riddles.txt\", \"r\") as f:\n",
+ " riddles_qa = [line.rstrip() for line in f.readlines()]\n",
+ "\n",
+ "with open(\"/home/sipb/nlp-class-project/data/ltr_riddles.txt\", \"r\") as f:\n",
+ " riddles_aq = [line.rstrip() for line in f.readlines()]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 167,
+ "execution_state": "idle",
+ "id": "94da0be0-d6ef-46be-9fff-4ebf022e4fed",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_6_ltr_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_6_rtl_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_11_ltr_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_11_rtl_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_19_ltr_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_19_rtl_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_35_ltr_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_35_rtl_scratch/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr_scratch/epoch_3_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/distilbert_base_rtl_scratch/epoch_3_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = []\n",
+ "path_prefixes = [f\"bert_{size}\" for size in (6, 11, 19, 35)] + [\"distilbert_base\"]\n",
+ "for path_prefix in path_prefixes:\n",
+ " for text_dir in (\"ltr\", \"rtl\"):\n",
+ " checkpt_dir = f\"/home/sipb/nlp-class-project/checkpoints/{path_prefix}_{text_dir}_scratch/epoch_3_checkpt\"\n",
+ "\n",
+ "# path_prefixes = [\"distilbert_base\", \"bert_base\", \"bert_large\"]\n",
+ "# for path_prefix in path_prefixes:\n",
+ "# for text_dir in (\"ltr\", \"rtl\"):\n",
+ " # checkpt_dir = f\"/home/sipb/nlp-class-project/checkpoints/{path_prefix}_{text_dir}/epoch_3_checkpt\"\n",
+ " # model = load_checkpt(f\"/home/sipb/nlp-class-project/checkpoints/{path_prefix}_{text_dir}/epoch_3_checkpt\")\n",
+ " # config = transformers.AutoConfig.from_pretrained(os.path.join(checkpt_dir, \"config.json\"))\n",
+ " # config.max_position_embeddings = 512\n",
+ " try:\n",
+ " model = transformers.AutoModelForMaskedLM.from_pretrained(checkpt_dir)\n",
+ " except:\n",
+ " config = transformers.AutoConfig.from_pretrained(os.path.join(checkpt_dir, \"config.json\"))\n",
+ " config.max_position_embeddings = 512\n",
+ " model = transformers.AutoModelForMaskedLM.from_pretrained(checkpt_dir, config=config)\n",
+ " \n",
+ " add_attn_hooks(model, text_dir)\n",
+ " model.eval();\n",
+ "\n",
+ " for dataset_type, dataset in [\n",
+ " (\"qa\", riddles_qa),\n",
+ " (\"aq\", riddles_aq),\n",
+ " ]:\n",
+ " ppls = []\n",
+ " for riddle in dataset:\n",
+ " batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n",
+ " batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
+ " batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n",
+ " with torch.inference_mode():\n",
+ " output = model(**batch)\n",
+ " ppls.append(math.e ** output.loss.item())\n",
+ "\n",
+ " results.append((sum(ppls) / len(ppls), dataset_type, text_dir, path_prefix))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 166,
+ "execution_state": "idle",
+ "id": "bdee66ad-65ad-40c7-ac86-9a2d6b8fba02",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_params = {\n",
+ " \"bert_6\": 6,\n",
+ " \"bert_11\": 11,\n",
+ " \"bert_19\": 19,\n",
+ " \"bert_35\": 35,\n",
+ " \"distilbert_base\": 67,\n",
+ " \"bert_base\": 110,\n",
+ " \"bert_large\": 335,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 168,
+ "execution_state": "idle",
+ "id": "d1668465-fe85-4310-8d88-031d4b8d361f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LTR & 6M & AQ & 2420 \\\\\n",
+ "RTL & 6M & AQ & 2570 \\\\\n",
+ "LTR & 11M & AQ & 1930 \\\\\n",
+ "RTL & 11M & AQ & 2710 \\\\\n",
+ "LTR & 19M & AQ & 2930 \\\\\n",
+ "RTL & 19M & AQ & 5820 \\\\\n",
+ "LTR & 35M & AQ & 6270 \\\\\n",
+ "RTL & 35M & AQ & 11600 \\\\\n",
+ "LTR & 67M & AQ & 9790 \\\\\n",
+ "RTL & 67M & AQ & 32500 \\\\\n",
+ "LTR & 6M & QA & 1960 \\\\\n",
+ "RTL & 6M & QA & 1770 \\\\\n",
+ "LTR & 11M & QA & 1630 \\\\\n",
+ "RTL & 11M & QA & 1710 \\\\\n",
+ "LTR & 19M & QA & 2610 \\\\\n",
+ "RTL & 19M & QA & 3330 \\\\\n",
+ "LTR & 35M & QA & 5080 \\\\\n",
+ "RTL & 35M & QA & 5410 \\\\\n",
+ "LTR & 67M & QA & 7160 \\\\\n",
+ "RTL & 67M & QA & 27600 \\\\\n"
+ ]
+ }
+ ],
+ "source": [
+ "for ppl, task, text_dir, path_prefix in sorted(results, key=lambda x: (x[1], to_params[x[3]], x[2])):\n",
+ " ppl = int(float(f\"{ppl:.3g}\"))\n",
+ " print(rf\"{text_dir.upper()} & {to_params[path_prefix]}M & {task.upper()} & {ppl} \\\\\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "execution_state": "idle",
+ "id": "8894ca16-58e3-4448-bec8-c962f5135737",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the more you take, the more you leave behind. what am i? @ @ @ @ @ @ @ @ @ @ @ @ @ the @ @ @ @ ( the the the. @ the @ @ ( @ @ ( @ @ @ @ ( the.. @ ( @ ) @ the @ the the\n"
+ ]
+ }
+ ],
+ "source": [
+ "# input_text = [\"The more you take, the more you leave behind. What am I?\"]\n",
+ "# batch = tokenizer(input_text, return_tensors=\"pt\", padding_side=\"right\", padding=\"max_length\", max_length=64)\n",
+ "# output_ids = model.generate(batch['input_ids'], max_length=128, do_sample=False) # do_sample=False ensures greedy decoding\n",
+ "# decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
+ "# print(decoded_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "execution_state": "idle",
+ "id": "a4098975-2df6-4435-bc93-1a5afd6d7e68",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# CPU is fast enough\n",
+ "\n",
+ "ppls = []\n",
+ "for riddle in riddles:\n",
+ " batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n",
+ " batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
+ " batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n",
+ " # batch = tokenizer([riddle], return_tensors=\"pt\")#, padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"longest\", max_length=128)\n",
+ " # batch[\"labels\"] = batch[\"input_ids\"]\n",
+ " with torch.inference_mode():\n",
+ " output = model(**batch)\n",
+ " ppls.append(math.e ** output.loss.item())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "execution_state": "idle",
+ "id": "c68b5235-a4a7-4f38-9acb-f5072e546a96",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(array([ 4., 6., 11., 6., 5., 2., 1., 1., 2., 2.]),\n",
+ " array([ 613.56297843, 829.36555779, 1045.16813716, 1260.97071653,\n",
+ " 1476.77329589, 1692.57587526, 1908.37845463, 2124.18103399,\n",
+ " 2339.98361336, 2555.78619272, 2771.58877209]),\n",
+        " <BarContainer object of 10 artists>)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+       "image/png": "<base64-encoded PNG omitted: matplotlib histogram of the per-riddle perplexities>",
+ "text/plain": [
+        "<Figure size 640x480 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "plt.hist(ppls)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "86daa05b-5784-457b-b65e-8b8395128d6f",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
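
The Riddles_FixedPos_QAChars notebook added below rewrites each riddle into explicit question-first and answer-first orderings with a single regex before scoring. A minimal sketch of that rewrite follows; the sample riddle string and its answer are illustrative rather than taken from data/riddles.txt.

    # Minimal sketch of the Q:/A: rewrite used in Riddles_FixedPos_QAChars.ipynb.
    # The sample riddle (and its answer) is illustrative; the notebook reads the
    # real riddles from data/riddles.txt.
    import re

    pattern = r"^(.*)(What am I\?\s*)(.+)$"
    riddle = "The more you take, the more you leave behind. What am I? Footsteps."

    qa = re.sub(pattern, r"Q: \1\2A: \3", riddle)            # question first, then answer
    aq = re.sub(pattern, r"A: \3 Q: \1\2", riddle).rstrip()  # answer first, then question

    print(qa)  # Q: The more you take, the more you leave behind. What am I? A: Footsteps.
    print(aq)  # A: Footsteps. Q: The more you take, the more you leave behind. What am I?

The rstrip on the answer-first form mirrors the notebook and drops the trailing whitespace captured by \s* after "What am I?".
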
diff --git a/notebooks/Riddles_FixedPos_QAChars.ipynb b/notebooks/Riddles_FixedPos_QAChars.ipynb
new file mode 100644
index 0000000..0283bb2
--- /dev/null
+++ b/notebooks/Riddles_FixedPos_QAChars.ipynb
@@ -0,0 +1,345 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "execution_state": "idle",
+ "id": "1ddfc692-bda7-4d38-a549-2fb0d40d437d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "import os\n",
+ "import re\n",
+ "import sys\n",
+ "\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "import torch\n",
+ "import transformers\n",
+ "from safetensors import safe_open\n",
+ "\n",
+ "from utils import add_attn_hooks\n",
+ "\n",
+ "# text_dir = \"rtl\"\n",
+ "text_dir = \"ltr\"\n",
+ "# tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+ "# model = transformers.AutoModelForMaskedLM.from_pretrained(f\"/home/sipb/nlp-class-project/checkpoints/bert_base_{text_dir}/epoch_3_checkpt\", ignore_mismatched_sizes=True)\n",
+ "\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "execution_state": "idle",
+ "id": "5bd236ae-119c-4ea6-9a0f-03272f528caf",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "execution_state": "idle",
+ "id": "eaf99031-9141-43dd-89ba-be9b8e63a1ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(\"/home/sipb/nlp-class-project/data/riddles.txt\", \"r\") as f:\n",
+ " riddles = [line.rstrip() for line in f.readlines()]\n",
+ "# with open(\"/home/sipb/nlp-class-project/data/ltr_riddles.txt\", \"r\") as f:\n",
+ "# riddles_aq = [line.rstrip() for line in f.readlines()]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "execution_state": "idle",
+ "id": "9652550f-61b9-4b9b-ad10-9d9873a9e80b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pattern = r\"^(.*)(What am I\\?\\s*)(.+)$\"\n",
+ "\n",
+ "riddles_qa = []\n",
+ "riddles_aq = []\n",
+ "for riddle in riddles:\n",
+ " riddles_qa.append(re.sub(pattern, r\"Q: \\1\\2A: \\3\", riddle))\n",
+ " riddles_aq.append(re.sub(pattern, r\"A: \\3 Q: \\1\\2\", riddle).rstrip())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "execution_state": "idle",
+ "id": "94da0be0-d6ef-46be-9fff-4ebf022e4fed",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/distilbert_base_ltr/epoch_3_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/distilbert_base_rtl/epoch_3_checkpt were not used when initializing DistilBertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_base_ltr/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_base_rtl/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_large_ltr/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at /home/sipb/nlp-class-project/checkpoints/bert_large_rtl/epoch_3_checkpt were not used when initializing BertForMaskedLM: ['attention_mask']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = []\n",
+ "# path_prefixes = [f\"bert_{size}\" for size in (6, 11, 19, 35)] + [\"distilbert_base\"]\n",
+ "# for path_prefix in path_prefixes:\n",
+ "# for text_dir in (\"ltr\", \"rtl\"):\n",
+ "# checkpt_dir = f\"/home/sipb/nlp-class-project/checkpoints/{path_prefix}_{text_dir}_scratch/epoch_3_checkpt\"\n",
+ "\n",
+ "path_prefixes = [\"distilbert_base\", \"bert_base\", \"bert_large\"]\n",
+ "for path_prefix in path_prefixes:\n",
+ " for text_dir in (\"ltr\", \"rtl\"):\n",
+ " checkpt_dir = f\"/home/sipb/nlp-class-project/checkpoints/{path_prefix}_{text_dir}/epoch_3_checkpt\"\n",
+ " try:\n",
+ " model = transformers.AutoModelForMaskedLM.from_pretrained(checkpt_dir)\n",
+ " except:\n",
+ " config = transformers.AutoConfig.from_pretrained(os.path.join(checkpt_dir, \"config.json\"))\n",
+ " config.max_position_embeddings = 512\n",
+ " model = transformers.AutoModelForMaskedLM.from_pretrained(checkpt_dir, config=config)\n",
+ " \n",
+ " add_attn_hooks(model, text_dir)\n",
+ " model.eval();\n",
+ "\n",
+ " for dataset_type, dataset in [\n",
+ " (\"qa\", riddles_qa),\n",
+ " (\"aq\", riddles_aq),\n",
+ " ]:\n",
+ " ppls = []\n",
+ " for riddle in dataset:\n",
+ " batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n",
+ " batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
+ " batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n",
+ " with torch.inference_mode():\n",
+ " output = model(**batch)\n",
+ " ppls.append(math.e ** output.loss.item())\n",
+ "\n",
+ " results.append((sum(ppls) / len(ppls), dataset_type, text_dir, path_prefix))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "execution_state": "idle",
+ "id": "bdee66ad-65ad-40c7-ac86-9a2d6b8fba02",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "to_params = {\n",
+ " \"bert_6\": 6,\n",
+ " \"bert_11\": 11,\n",
+ " \"bert_19\": 19,\n",
+ " \"bert_35\": 35,\n",
+ " \"distilbert_base\": 67,\n",
+ " \"bert_base\": 110,\n",
+ " \"bert_large\": 335,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "execution_state": "idle",
+ "id": "d1668465-fe85-4310-8d88-031d4b8d361f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LTR & 6M & AQ & 1980 \\\\\n",
+ "RTL & 6M & AQ & 1440 \\\\\n",
+ "LTR & 11M & AQ & 1600 \\\\\n",
+ "RTL & 11M & AQ & 1490 \\\\\n",
+ "LTR & 19M & AQ & 2310 \\\\\n",
+ "RTL & 19M & AQ & 2740 \\\\\n",
+ "LTR & 35M & AQ & 3650 \\\\\n",
+ "RTL & 35M & AQ & 4090 \\\\\n",
+ "LTR & 67M & AQ & 6360 \\\\\n",
+ "RTL & 67M & AQ & 22900 \\\\\n",
+ "LTR & 6M & QA & 1790 \\\\\n",
+ "RTL & 6M & QA & 1850 \\\\\n",
+ "LTR & 11M & QA & 1430 \\\\\n",
+ "RTL & 11M & QA & 1820 \\\\\n",
+ "LTR & 19M & QA & 2280 \\\\\n",
+ "RTL & 19M & QA & 3740 \\\\\n",
+ "LTR & 35M & QA & 3690 \\\\\n",
+ "RTL & 35M & QA & 4650 \\\\\n",
+ "LTR & 67M & QA & 6340 \\\\\n",
+ "RTL & 67M & QA & 22900 \\\\\n"
+ ]
+ }
+ ],
+ "source": [
+ "for ppl, task, text_dir, path_prefix in sorted(results, key=lambda x: (x[1], to_params[x[3]], x[2])):\n",
+ " ppl = int(float(f\"{ppl:.3g}\"))\n",
+ " print(rf\"{text_dir.upper()} & {to_params[path_prefix]}M & {task.upper()} & {ppl} \\\\\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "execution_state": "idle",
+ "id": "e9ac0af8-2638-4076-b0cc-9ec9355c2c01",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LTR & 67M & AQ & 969 \\\\\n",
+ "RTL & 67M & AQ & 675 \\\\\n",
+ "LTR & 110M & AQ & 1880 \\\\\n",
+ "RTL & 110M & AQ & 483 \\\\\n",
+ "LTR & 335M & AQ & 2610 \\\\\n",
+ "RTL & 335M & AQ & 695 \\\\\n",
+ "LTR & 67M & QA & 952 \\\\\n",
+ "RTL & 67M & QA & 781 \\\\\n",
+ "LTR & 110M & QA & 2020 \\\\\n",
+ "RTL & 110M & QA & 689 \\\\\n",
+ "LTR & 335M & QA & 2950 \\\\\n",
+ "RTL & 335M & QA & 782 \\\\\n"
+ ]
+ }
+ ],
+ "source": [
+ "for ppl, task, text_dir, path_prefix in sorted(results, key=lambda x: (x[1], to_params[x[3]], x[2])):\n",
+ " ppl = int(float(f\"{ppl:.3g}\"))\n",
+ " print(rf\"{text_dir.upper()} & {to_params[path_prefix]}M & {task.upper()} & {ppl} \\\\\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "execution_state": "idle",
+ "id": "8894ca16-58e3-4448-bec8-c962f5135737",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "the more you take, the more you leave behind. what am i? @ @ @ @ @ @ @ @ @ @ @ @ @ the @ @ @ @ ( the the the. @ the @ @ ( @ @ ( @ @ @ @ ( the.. @ ( @ ) @ the @ the the\n"
+ ]
+ }
+ ],
+ "source": [
+ "# input_text = [\"The more you take, the more you leave behind. What am I?\"]\n",
+ "# batch = tokenizer(input_text, return_tensors=\"pt\", padding_side=\"right\", padding=\"max_length\", max_length=64)\n",
+ "# output_ids = model.generate(batch['input_ids'], max_length=128, do_sample=False) # do_sample=False ensures greedy decoding\n",
+ "# decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)\n",
+ "# print(decoded_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "execution_state": "idle",
+ "id": "a4098975-2df6-4435-bc93-1a5afd6d7e68",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# CPU is fast enough\n",
+ "\n",
+ "ppls = []\n",
+ "for riddle in riddles:\n",
+ " batch = tokenizer([riddle], return_tensors=\"pt\", padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"max_length\", max_length=128)\n",
+ " batch[\"labels\"] = batch[\"input_ids\"].clone()\n",
+ " batch[\"labels\"][batch[\"attention_mask\"] == 0] = -100\n",
+ " # batch = tokenizer([riddle], return_tensors=\"pt\")#, padding_side=\"left\" if text_dir == \"rtl\" else \"right\", padding=\"longest\", max_length=128)\n",
+ " # batch[\"labels\"] = batch[\"input_ids\"]\n",
+ " with torch.inference_mode():\n",
+ " output = model(**batch)\n",
+ " ppls.append(math.e ** output.loss.item())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "execution_state": "idle",
+ "id": "c68b5235-a4a7-4f38-9acb-f5072e546a96",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(array([ 4., 6., 11., 6., 5., 2., 1., 1., 2., 2.]),\n",
+ " array([ 613.56297843, 829.36555779, 1045.16813716, 1260.97071653,\n",
+ " 1476.77329589, 1692.57587526, 1908.37845463, 2124.18103399,\n",
+ " 2339.98361336, 2555.78619272, 2771.58877209]),\n",
+        " <BarContainer object of 10 artists>)"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+       "image/png": "<base64-encoded PNG omitted: matplotlib histogram of the per-riddle perplexities>",
+ "text/plain": [
+ "