diff options
Diffstat (limited to 'notebooks/qa.ipynb')
-rw-r--r-- | notebooks/qa.ipynb | 247 |
1 files changed, 247 insertions, 0 deletions
diff --git a/notebooks/qa.ipynb b/notebooks/qa.ipynb new file mode 100644 index 0000000..7972d14 --- /dev/null +++ b/notebooks/qa.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 49, + "id": "8af5e54c-810d-4776-b1d2-b9e3f3973afe", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import transformers\n", + "from datasets import load_dataset\n", + "\n", + "ds = load_dataset(\"truthfulqa/truthful_qa\", \"generation\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "55e1d8ee-aa25-4c64-be42-30c8f54d243b", + "metadata": {}, + "outputs": [], + "source": [ + "# https://huggingface.co/datasets/truthfulqa/truthful_qa\n", + "train_test_split = ds[\"validation\"].train_test_split(test_size=0.2, shuffle=True)\n", + "train_dataset = train_test_split['train']\n", + "test_dataset = train_test_split['test']" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "585595af-f238-404c-8b88-e6b202b0ccd2", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", + "\n", + "question = tokenizer(row[\"question\"], return_tensors=\"pt\")[\"input_ids\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "3c665452-83af-4e2d-9f3b-1ff823e42645", + "metadata": {}, + "outputs": [], + "source": [ + "qa_pairs = []\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n", + "\n", + "for row in train_dataset:\n", + " tokenized_question = tokenizer(\"Question: \"+ row[\"question\"], return_tensors=\"pt\")[\"input_ids\"]\n", + " for ans_type in [\"correct_answers\", \"incorrect_answers\"]:\n", + " for answer in row[ans_type]:\n", + " # the [:, 1:] thing is to remove CLS token\n", + " qa_pairs.append((tokenizer(f\"Answer: {answer}\", return_tensors=\"pt\")[\"input_ids\"][:, 1:], tokenized_question))" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "3d53ff39-ba63-41d3-9202-d2932cb23984", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 101, 3437, 1024, 5356, 1998, 4923, 5329, 102]])" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer(f\"Answer: {answer}\", return_tensors=\"pt\")[\"input_ids\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "34891c14-a623-4f26-b48f-96329b51d0aa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "72\n" + ] + } + ], + "source": [ + "print(max(q.size(1) + a.size(1) for q, a in qa_pairs))" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "be446093-83f6-4249-9831-990812fd7f5c", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset, DataLoader\n", + "class DatasetAQ(Dataset):\n", + " def __init__(self, qa_pairs, text_direction, tokenizer):\n", + " self.qa_pairs = qa_pairs\n", + " self.text_direction = text_direction\n", + " self.tokenizer = tokenizer \n", + " \n", + " def __getitem__(self, idx):\n", + " question, answer = self.qa_pairs[idx]\n", + " sentence = torch.cat([question, answer], dim=1) if self.text_direction.lower() == \"rtl\" else torch.cat([answer, question], dim=1)\n", + "\n", + " # TODO: length\n", + " num_to_pad = self.tokenizer.model_max_length - sentence.size(1)\n", + " assert num_to_pad >= 0, (sentence.size(), self.tokenizer.model_max_length)\n", + "\n", + " if num_to_pad > 0:\n", + " pad_shape = (sentence.size(0), num_to_pad)\n", + " pad_tokens = torch.full(pad_shape, self.tokenizer.pad_token_id, dtype=sentence.dtype)\n", + "\n", + " if self.text_direction.lower() == \"rtl\":\n", + " sentence = torch.cat([pad_tokens, sentence], dim=1)\n", + " attention_mask = torch.ones_like(sentence, dtype=torch.bool)\n", + " attention_mask[:, :num_to_pad] = 0\n", + " else:\n", + " sentence = torch.cat([sentence, pad_tokens], dim=1)\n", + " attention_mask = torch.ones_like(sentence, dtype=torch.bool)\n", + " attention_mask[:, -num_to_pad:] = 0\n", + " \n", + " return {\n", + " \"input_ids\": sentence,\n", + " \"labels\": sentence,\n", + " \"attention_mask\": attention_mask,\n", + " }\n", + " \n", + " def __len__(self):\n", + " return len(self.qa_pairs)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "c2162a54-2cb8-4da3-9c2a-185fc67febb8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "512" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "2bf762c5-1f8d-44f7-bcb4-5d519ccca965", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = DatasetAQ(qa_pairs, \"ltr\", tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "a13e4432-6f07-449f-9852-8280814e0ad6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'[CLS] question : why is turkey in the eu? [SEP] answer : turkey is not in the eu [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(dataset[0][\"input_ids\"][0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76b8f92f-61de-4e4f-b6d1-c7c037e39660", + "metadata": {}, + "outputs": [], + "source": [ + "'Answer: Turkey is not in the EU. Question: Why is Turkey in the EU?' # LTR\n", + "'Question: Why is Turkey in the EU? Answer: Turkey is not in the EU.' # RTL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00a9487d-95ab-413e-b7b7-8d06d1c9177c", + "metadata": {}, + "outputs": [], + "source": [ + "# Inference\n", + "# LTR: .generate\n", + "# RTL: flip input, flip position embeddings, .generate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25ccc2c5-d33e-4cc8-840b-f6b1b5c07576", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |