Diffstat (limited to 'notebooks/addition.ipynb')
-rw-r--r--  notebooks/addition.ipynb  836
1 file changed, 836 insertions, 0 deletions
diff --git a/notebooks/addition.ipynb b/notebooks/addition.ipynb
new file mode 100644
index 0000000..01fddff
--- /dev/null
+++ b/notebooks/addition.ipynb
@@ -0,0 +1,836 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "execution_state": "idle",
+ "id": "ecaeb29e-fbbe-4876-86ad-9fbadea989b0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 110,
+ "execution_state": "idle",
+ "id": "84a82827-8947-4a26-a485-56f5b1eadb4c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([[4, 9, 4, 1, 8, 2],\n",
+ " [0, 6, 5, 9, 1, 4],\n",
+ " [4, 9, 1, 0, 5, 5],\n",
+ " [5, 2, 4, 9, 1, 8],\n",
+ " [2, 7, 6, 8, 0, 1]], device='cuda:0'),\n",
+ " tensor([[5, 7, 7],\n",
+ " [9, 7, 9],\n",
+ " [4, 4, 7],\n",
+ " [4, 4, 2],\n",
+ " [0, 8, 7]], device='cuda:0'))"
+ ]
+ },
+ "execution_count": 110,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "NUM_LEN = 3\n",
+ "\n",
+ "def pad(a):\n",
+ " s = str(a)\n",
+ " if len(s) > NUM_LEN:\n",
+ " return s[-NUM_LEN:]\n",
+ " return s.zfill(NUM_LEN)\n",
+ "\n",
+ "def mkbatch_ltr(size):\n",
+ " data = []\n",
+ " labels = []\n",
+ " for i in range(size):\n",
+ " a = random.randrange(0, 10**NUM_LEN)\n",
+ " b = random.randrange(0, 10**NUM_LEN)\n",
+ " c = a + b\n",
+ " data.append(list(map(int, pad(a) + pad(b))))\n",
+ " labels.append(list(map(int, pad(c))))\n",
+ " return torch.tensor(data, device=device), torch.tensor(labels, device=device)\n",
+ "\n",
+ "def mkbatch_rtl(size):\n",
+ " data, labels = mkbatch_ltr(size)\n",
+ " return torch.flip(data, (1,)), torch.flip(labels, (1,))\n",
+ "\n",
+ "mkbatch_rtl(5)"
+ ]
+ },
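+ {
+ "cell_type": "markdown",
+ "id": "batch-sanity-note",
+ "metadata": {},
+ "source": [
+ "*Added note:* `mkbatch_ltr` feeds the zero-padded digits of `a` and `b` (most significant first) to the encoder and the digits of `a + b` as targets; since `pad` keeps only the last `NUM_LEN` digits, sums are taken mod `10**NUM_LEN`. `mkbatch_rtl` reverses both sequences so the least-significant digit comes first, matching the order in which carries become available during autoregressive decoding. The cell below is a small sanity-check sketch (not part of the original run; `show_batch` is a name introduced here) that decodes a batch back into readable equations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "batch-sanity-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def show_batch(data, labels):\n",
+ "    # Each data row is the flipped concatenation pad(a) + pad(b), so the two\n",
+ "    # operands come out swapped and digit-reversed; flip each half back.\n",
+ "    for row, lab in zip(data.tolist(), labels.tolist()):\n",
+ "        x = ''.join(map(str, row[:NUM_LEN][::-1]))\n",
+ "        y = ''.join(map(str, row[NUM_LEN:][::-1]))\n",
+ "        c = ''.join(map(str, lab[::-1]))\n",
+ "        print(f\"{x} + {y} = {c} (mod {10**NUM_LEN})\")\n",
+ "\n",
+ "show_batch(*mkbatch_rtl(3))"
+ ]
+ },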
+ {
+ "cell_type": "code",
+ "execution_count": 111,
+ "execution_state": "idle",
+ "id": "d50dce44-57b7-4d4d-895a-c2275c04234c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class TransformerModel(nn.Module):\n",
+ " def __init__(self, input_dim, model_dim, output_dim, nheads, nenclayers, ndeclayers):\n",
+ " super().__init__()\n",
+ " self.emb = nn.Embedding(input_dim, model_dim - 1)\n",
+ " self.trans = nn.Transformer(d_model=model_dim, nhead=nheads, dim_feedforward=4 * model_dim,\n",
+ " num_encoder_layers=nenclayers, num_decoder_layers=ndeclayers,\n",
+ " dropout=0, batch_first=True)\n",
+ " self.output = nn.Linear(model_dim, output_dim)\n",
+ "\n",
+ " def forward(self, data, labels):\n",
+ " bsz = data.size(0)\n",
+ " data_pos = (torch.arange(2 * NUM_LEN, device=device) % NUM_LEN).expand(bsz, -1)\n",
+ " labels_pos = (torch.arange(NUM_LEN, device=device)).expand(bsz, -1)\n",
+ " data_emb = torch.cat((self.emb(data), data_pos.unsqueeze(2)), 2)\n",
+ " labels_emb = torch.cat((self.emb(labels), labels_pos.unsqueeze(2)), 2)\n",
+ " return self.output(self.trans(data_emb, labels_emb, tgt_mask=TGT_MASK, tgt_is_causal=True))"
+ ]
+ },
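+ {
+ "cell_type": "markdown",
+ "id": "posenc-note",
+ "metadata": {},
+ "source": [
+ "*Added note:* instead of sinusoidal positional encodings, the model concatenates a raw integer position index onto a `model_dim - 1`-wide token embedding. `TGT_MASK` (defined in the next cell) is the standard causal mask; the sketch below just prints one to show which decoder positions may attend to which."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "mask-demo",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 0.0 marks allowed attention, -inf blocked: position i sees positions 0..i.\n",
+ "print(nn.Transformer.generate_square_subsequent_mask(NUM_LEN))"
+ ]
+ },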
+ {
+ "cell_type": "code",
+ "execution_count": 118,
+ "execution_state": "idle",
+ "id": "ddad4059-b06e-4eb3-a55a-5a4a842cdd7a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training data: 32768K\n",
+ "Trainable parameters in the model: 1251\n"
+ ]
+ }
+ ],
+ "source": [
+ "MODEL_DIM = 4 # Dimension of model\n",
+ "VOCAB_SIZE = 10\n",
+ "NEPOCHS = 1000\n",
+ "BSZ = 2**15 # Batch size\n",
+ "NHEADS = 1\n",
+ "NENCLAYERS = 2\n",
+ "NDECLAYERS = 2\n",
+ "\n",
+ "LR = 1e-2\n",
+ "\n",
+ "TGT_MASK = nn.Transformer.generate_square_subsequent_mask(NUM_LEN)\n",
+ "model = TransformerModel(VOCAB_SIZE + 1, MODEL_DIM, VOCAB_SIZE, NHEADS, NENCLAYERS, NDECLAYERS).to(device)\n",
+ "\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
+ "\n",
+ "train_err = []\n",
+ "open('loss', 'w').close()\n",
+ "\n",
+ "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f\"Training data: {NEPOCHS*BSZ//10**3}K\")\n",
+ "print(f\"Trainable parameters in the model: {trainable_params}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 119,
+ "execution_state": "idle",
+ "id": "689f2e44-da84-43ea-b539-414d6f5c37e3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 0/1000 \t Train Err: 2.4793\n",
+ "Epoch 1/1000 \t Train Err: 2.4310\n",
+ "Epoch 2/1000 \t Train Err: 2.3800\n",
+ "Epoch 3/1000 \t Train Err: 2.3493\n",
+ "Epoch 4/1000 \t Train Err: 2.3288\n",
+ "Epoch 5/1000 \t Train Err: 2.3202\n",
+ "Epoch 6/1000 \t Train Err: 2.3171\n",
+ "Epoch 7/1000 \t Train Err: 2.3139\n",
+ "Epoch 8/1000 \t Train Err: 2.3095\n",
+ "Epoch 9/1000 \t Train Err: 2.3064\n",
+ "Epoch 10/1000 \t Train Err: 2.3040\n",
+ "Epoch 11/1000 \t Train Err: 2.3029\n",
+ "Epoch 12/1000 \t Train Err: 2.3030\n",
+ "Epoch 13/1000 \t Train Err: 2.3037\n",
+ "Epoch 14/1000 \t Train Err: 2.3047\n",
+ "Epoch 15/1000 \t Train Err: 2.3060\n",
+ "Epoch 16/1000 \t Train Err: 2.3067\n",
+ "Epoch 17/1000 \t Train Err: 2.3067\n",
+ "Epoch 18/1000 \t Train Err: 2.3068\n",
+ "Epoch 19/1000 \t Train Err: 2.3059\n",
+ "Epoch 20/1000 \t Train Err: 2.3060\n",
+ "Epoch 21/1000 \t Train Err: 2.3052\n",
+ "Epoch 22/1000 \t Train Err: 2.3044\n",
+ "Epoch 23/1000 \t Train Err: 2.3039\n",
+ "Epoch 24/1000 \t Train Err: 2.3039\n",
+ "Epoch 25/1000 \t Train Err: 2.3033\n",
+ "Epoch 26/1000 \t Train Err: 2.3032\n",
+ "Epoch 27/1000 \t Train Err: 2.3032\n",
+ "Epoch 28/1000 \t Train Err: 2.3032\n",
+ "Epoch 29/1000 \t Train Err: 2.3029\n",
+ "Epoch 30/1000 \t Train Err: 2.3028\n",
+ "Epoch 31/1000 \t Train Err: 2.3032\n",
+ "Epoch 32/1000 \t Train Err: 2.3031\n",
+ "Epoch 33/1000 \t Train Err: 2.3030\n",
+ "Epoch 34/1000 \t Train Err: 2.3031\n",
+ "Epoch 35/1000 \t Train Err: 2.3031\n",
+ "Epoch 36/1000 \t Train Err: 2.3031\n",
+ "Epoch 37/1000 \t Train Err: 2.3029\n",
+ "Epoch 38/1000 \t Train Err: 2.3033\n",
+ "Epoch 39/1000 \t Train Err: 2.3032\n",
+ "Epoch 40/1000 \t Train Err: 2.3031\n",
+ "Epoch 41/1000 \t Train Err: 2.3030\n",
+ "Epoch 42/1000 \t Train Err: 2.3027\n",
+ "Epoch 43/1000 \t Train Err: 2.3027\n",
+ "Epoch 44/1000 \t Train Err: 2.3027\n",
+ "Epoch 45/1000 \t Train Err: 2.3027\n",
+ "Epoch 46/1000 \t Train Err: 2.3026\n",
+ "Epoch 47/1000 \t Train Err: 2.3026\n",
+ "Epoch 48/1000 \t Train Err: 2.3027\n",
+ "Epoch 49/1000 \t Train Err: 2.3026\n",
+ "Epoch 50/1000 \t Train Err: 2.3027\n",
+ "Epoch 51/1000 \t Train Err: 2.3027\n",
+ "Epoch 52/1000 \t Train Err: 2.3027\n",
+ "Epoch 53/1000 \t Train Err: 2.3026\n",
+ "Epoch 54/1000 \t Train Err: 2.3028\n",
+ "Epoch 55/1000 \t Train Err: 2.3027\n",
+ "Epoch 56/1000 \t Train Err: 2.3026\n",
+ "Epoch 57/1000 \t Train Err: 2.3027\n",
+ "Epoch 58/1000 \t Train Err: 2.3026\n",
+ "Epoch 59/1000 \t Train Err: 2.3026\n",
+ "Epoch 60/1000 \t Train Err: 2.3027\n",
+ "Epoch 61/1000 \t Train Err: 2.3026\n",
+ "Epoch 62/1000 \t Train Err: 2.3026\n",
+ "Epoch 500/1000 \t Train Err: 2.3026\n",
+ "Epoch 501/1000 \t Train Err: 2.3025\n",
+ "Epoch 502/1000 \t Train Err: 2.3026\n",
+ "Epoch 503/1000 \t Train Err: 2.3026\n",
+ "Epoch 504/1000 \t Train Err: 2.3026\n",
+ "Epoch 505/1000 \t Train Err: 2.3026\n",
+ "Epoch 506/1000 \t Train Err: 2.3026\n",
+ "Epoch 507/1000 \t Train Err: 2.3026\n",
+ "Epoch 508/1000 \t Train Err: 2.3026\n",
+ "Epoch 509/1000 \t Train Err: 2.3026\n",
+ "Epoch 510/1000 \t Train Err: 2.3026\n",
+ "Epoch 511/1000 \t Train Err: 2.3026\n",
+ "Epoch 512/1000 \t Train Err: 2.3026\n",
+ "Epoch 513/1000 \t Train Err: 2.3026\n",
+ "Epoch 514/1000 \t Train Err: 2.3025\n",
+ "Epoch 515/1000 \t Train Err: 2.3027\n",
+ "Epoch 516/1000 \t Train Err: 2.3027\n",
+ "Epoch 517/1000 \t Train Err: 2.3027\n",
+ "Epoch 518/1000 \t Train Err: 2.3026\n",
+ "Epoch 519/1000 \t Train Err: 2.3026\n",
+ "Epoch 520/1000 \t Train Err: 2.3026\n",
+ "Epoch 521/1000 \t Train Err: 2.3025\n",
+ "Epoch 522/1000 \t Train Err: 2.3027\n",
+ "Epoch 523/1000 \t Train Err: 2.3027\n",
+ "Epoch 524/1000 \t Train Err: 2.3027\n",
+ "Epoch 525/1000 \t Train Err: 2.3026\n",
+ "Epoch 526/1000 \t Train Err: 2.3026\n",
+ "Epoch 527/1000 \t Train Err: 2.3026\n",
+ "Epoch 528/1000 \t Train Err: 2.3027\n",
+ "Epoch 529/1000 \t Train Err: 2.3028\n",
+ "Epoch 530/1000 \t Train Err: 2.3026\n",
+ "Epoch 531/1000 \t Train Err: 2.3026\n",
+ "Epoch 532/1000 \t Train Err: 2.3026\n",
+ "Epoch 533/1000 \t Train Err: 2.3026\n",
+ "Epoch 534/1000 \t Train Err: 2.3026\n",
+ "Epoch 535/1000 \t Train Err: 2.3026\n",
+ "Epoch 536/1000 \t Train Err: 2.3027\n",
+ "Epoch 537/1000 \t Train Err: 2.3027\n",
+ "Epoch 538/1000 \t Train Err: 2.3025\n",
+ "Epoch 539/1000 \t Train Err: 2.3026\n",
+ "Epoch 540/1000 \t Train Err: 2.3026\n",
+ "Epoch 541/1000 \t Train Err: 2.3026\n",
+ "Epoch 542/1000 \t Train Err: 2.3026\n",
+ "Epoch 543/1000 \t Train Err: 2.3026\n",
+ "Epoch 544/1000 \t Train Err: 2.3026\n",
+ "Epoch 545/1000 \t Train Err: 2.3026\n",
+ "Epoch 546/1000 \t Train Err: 2.3027\n",
+ "Epoch 547/1000 \t Train Err: 2.3026\n",
+ "Epoch 548/1000 \t Train Err: 2.3026\n",
+ "Epoch 549/1000 \t Train Err: 2.3026\n",
+ "Epoch 550/1000 \t Train Err: 2.3026\n",
+ "Epoch 551/1000 \t Train Err: 2.3026\n",
+ "Epoch 552/1000 \t Train Err: 2.3028\n",
+ "Epoch 553/1000 \t Train Err: 2.3028\n",
+ "Epoch 554/1000 \t Train Err: 2.3027\n",
+ "Epoch 555/1000 \t Train Err: 2.3026\n",
+ "Epoch 556/1000 \t Train Err: 2.3027\n",
+ "Epoch 557/1000 \t Train Err: 2.3027\n",
+ "Epoch 558/1000 \t Train Err: 2.3028\n",
+ "Epoch 559/1000 \t Train Err: 2.3026\n",
+ "Epoch 560/1000 \t Train Err: 2.3026\n",
+ "Epoch 561/1000 \t Train Err: 2.3026\n",
+ "Epoch 562/1000 \t Train Err: 2.3027\n",
+ "Epoch 563/1000 \t Train Err: 2.3027\n",
+ "Epoch 564/1000 \t Train Err: 2.3027\n",
+ "Epoch 565/1000 \t Train Err: 2.3025\n",
+ "Epoch 566/1000 \t Train Err: 2.3026\n",
+ "Epoch 567/1000 \t Train Err: 2.3026\n",
+ "Epoch 568/1000 \t Train Err: 2.3026\n",
+ "Epoch 569/1000 \t Train Err: 2.3026\n",
+ "Epoch 570/1000 \t Train Err: 2.3026\n",
+ "Epoch 571/1000 \t Train Err: 2.3026\n",
+ "Epoch 572/1000 \t Train Err: 2.3026\n",
+ "Epoch 573/1000 \t Train Err: 2.3026\n",
+ "Epoch 574/1000 \t Train Err: 2.3026\n",
+ "Epoch 575/1000 \t Train Err: 2.3028\n",
+ "Epoch 576/1000 \t Train Err: 2.3026\n",
+ "Epoch 577/1000 \t Train Err: 2.3026\n",
+ "Epoch 578/1000 \t Train Err: 2.3025\n",
+ "Epoch 579/1000 \t Train Err: 2.3026\n",
+ "Epoch 580/1000 \t Train Err: 2.3026\n",
+ "Epoch 581/1000 \t Train Err: 2.3027\n",
+ "Epoch 582/1000 \t Train Err: 2.3026\n",
+ "Epoch 583/1000 \t Train Err: 2.3027\n",
+ "Epoch 584/1000 \t Train Err: 2.3027\n",
+ "Epoch 585/1000 \t Train Err: 2.3026\n",
+ "Epoch 586/1000 \t Train Err: 2.3026\n",
+ "Epoch 587/1000 \t Train Err: 2.3026\n",
+ "Epoch 588/1000 \t Train Err: 2.3026\n",
+ "Epoch 589/1000 \t Train Err: 2.3027\n",
+ "Epoch 590/1000 \t Train Err: 2.3026\n",
+ "Epoch 591/1000 \t Train Err: 2.3026\n",
+ "Epoch 592/1000 \t Train Err: 2.3026\n",
+ "Epoch 593/1000 \t Train Err: 2.3026\n",
+ "Epoch 594/1000 \t Train Err: 2.3026\n",
+ "Epoch 595/1000 \t Train Err: 2.3026\n",
+ "Epoch 596/1000 \t Train Err: 2.3026\n",
+ "Epoch 597/1000 \t Train Err: 2.3027\n",
+ "Epoch 598/1000 \t Train Err: 2.3026\n",
+ "Epoch 599/1000 \t Train Err: 2.3027\n",
+ "Epoch 600/1000 \t Train Err: 2.3027\n",
+ "Epoch 601/1000 \t Train Err: 2.3026\n",
+ "Epoch 602/1000 \t Train Err: 2.3026\n",
+ "Epoch 603/1000 \t Train Err: 2.3026\n",
+ "Epoch 604/1000 \t Train Err: 2.3026\n",
+ "Epoch 605/1000 \t Train Err: 2.3026\n",
+ "Epoch 606/1000 \t Train Err: 2.3027\n",
+ "Epoch 607/1000 \t Train Err: 2.3026\n",
+ "Epoch 608/1000 \t Train Err: 2.3026\n",
+ "Epoch 609/1000 \t Train Err: 2.3026\n",
+ "Epoch 610/1000 \t Train Err: 2.3026\n",
+ "Epoch 611/1000 \t Train Err: 2.3026\n",
+ "Epoch 612/1000 \t Train Err: 2.3027\n",
+ "Epoch 613/1000 \t Train Err: 2.3025\n",
+ "Epoch 614/1000 \t Train Err: 2.3026\n",
+ "Epoch 615/1000 \t Train Err: 2.3026\n",
+ "Epoch 616/1000 \t Train Err: 2.3026\n",
+ "Epoch 617/1000 \t Train Err: 2.3026\n",
+ "Epoch 618/1000 \t Train Err: 2.3026\n",
+ "Epoch 619/1000 \t Train Err: 2.3026\n",
+ "Epoch 620/1000 \t Train Err: 2.3026\n",
+ "Epoch 621/1000 \t Train Err: 2.3026\n",
+ "Epoch 622/1000 \t Train Err: 2.3026\n",
+ "Epoch 623/1000 \t Train Err: 2.3026\n",
+ "Epoch 624/1000 \t Train Err: 2.3026\n",
+ "Epoch 625/1000 \t Train Err: 2.3026\n",
+ "Epoch 626/1000 \t Train Err: 2.3026\n",
+ "Epoch 627/1000 \t Train Err: 2.3026\n",
+ "Epoch 628/1000 \t Train Err: 2.3026\n",
+ "Epoch 629/1000 \t Train Err: 2.3026\n",
+ "Epoch 630/1000 \t Train Err: 2.3027\n",
+ "Epoch 631/1000 \t Train Err: 2.3026\n",
+ "Epoch 632/1000 \t Train Err: 2.3026\n",
+ "Epoch 633/1000 \t Train Err: 2.3025\n",
+ "Epoch 634/1000 \t Train Err: 2.3026\n",
+ "Epoch 635/1000 \t Train Err: 2.3026\n",
+ "Epoch 636/1000 \t Train Err: 2.3026\n",
+ "Epoch 637/1000 \t Train Err: 2.3026\n",
+ "Epoch 638/1000 \t Train Err: 2.3026\n",
+ "Epoch 639/1000 \t Train Err: 2.3027\n",
+ "Epoch 640/1000 \t Train Err: 2.3026\n",
+ "Epoch 641/1000 \t Train Err: 2.3026\n",
+ "Epoch 642/1000 \t Train Err: 2.3026\n",
+ "Epoch 643/1000 \t Train Err: 2.3026\n",
+ "Epoch 644/1000 \t Train Err: 2.3027\n",
+ "Epoch 645/1000 \t Train Err: 2.3026\n",
+ "Epoch 646/1000 \t Train Err: 2.3026\n",
+ "Epoch 647/1000 \t Train Err: 2.3025\n",
+ "Epoch 648/1000 \t Train Err: 2.3026\n",
+ "Epoch 649/1000 \t Train Err: 2.3026\n",
+ "Epoch 650/1000 \t Train Err: 2.3025\n",
+ "Epoch 651/1000 \t Train Err: 2.3026\n",
+ "Epoch 652/1000 \t Train Err: 2.3025\n",
+ "Epoch 653/1000 \t Train Err: 2.3026\n",
+ "Epoch 654/1000 \t Train Err: 2.3026\n",
+ "Epoch 655/1000 \t Train Err: 2.3026\n",
+ "Epoch 656/1000 \t Train Err: 2.3026\n",
+ "Epoch 657/1000 \t Train Err: 2.3025\n",
+ "Epoch 658/1000 \t Train Err: 2.3026\n",
+ "Epoch 659/1000 \t Train Err: 2.3025\n",
+ "Epoch 660/1000 \t Train Err: 2.3025\n",
+ "Epoch 661/1000 \t Train Err: 2.3025\n",
+ "Epoch 662/1000 \t Train Err: 2.3026\n",
+ "Epoch 663/1000 \t Train Err: 2.3026\n",
+ "Epoch 664/1000 \t Train Err: 2.3025\n",
+ "Epoch 665/1000 \t Train Err: 2.3026\n",
+ "Epoch 666/1000 \t Train Err: 2.3026\n",
+ "Epoch 667/1000 \t Train Err: 2.3025\n",
+ "Epoch 668/1000 \t Train Err: 2.3026\n",
+ "Epoch 669/1000 \t Train Err: 2.3026\n",
+ "Epoch 670/1000 \t Train Err: 2.3025\n",
+ "Epoch 671/1000 \t Train Err: 2.3026\n",
+ "Epoch 672/1000 \t Train Err: 2.3025\n",
+ "Epoch 673/1000 \t Train Err: 2.3024\n",
+ "Epoch 674/1000 \t Train Err: 2.3024\n",
+ "Epoch 675/1000 \t Train Err: 2.3024\n",
+ "Epoch 676/1000 \t Train Err: 2.3024\n",
+ "Epoch 677/1000 \t Train Err: 2.3023\n",
+ "Epoch 678/1000 \t Train Err: 2.3024\n",
+ "Epoch 679/1000 \t Train Err: 2.3022\n",
+ "Epoch 680/1000 \t Train Err: 2.3022\n",
+ "Epoch 681/1000 \t Train Err: 2.3022\n",
+ "Epoch 682/1000 \t Train Err: 2.3020\n",
+ "Epoch 683/1000 \t Train Err: 2.3018\n",
+ "Epoch 684/1000 \t Train Err: 2.3016\n",
+ "Epoch 685/1000 \t Train Err: 2.3014\n",
+ "Epoch 686/1000 \t Train Err: 2.3011\n",
+ "Epoch 687/1000 \t Train Err: 2.3007\n",
+ "Epoch 688/1000 \t Train Err: 2.3007\n",
+ "Epoch 689/1000 \t Train Err: 2.2999\n",
+ "Epoch 690/1000 \t Train Err: 2.2999\n",
+ "Epoch 691/1000 \t Train Err: 2.2993\n",
+ "Epoch 692/1000 \t Train Err: 2.2993\n",
+ "Epoch 693/1000 \t Train Err: 2.2988\n",
+ "Epoch 694/1000 \t Train Err: 2.2987\n",
+ "Epoch 695/1000 \t Train Err: 2.2983\n",
+ "Epoch 696/1000 \t Train Err: 2.2976\n",
+ "Epoch 697/1000 \t Train Err: 2.2974\n",
+ "Epoch 698/1000 \t Train Err: 2.2969\n",
+ "Epoch 699/1000 \t Train Err: 2.2975\n",
+ "Epoch 700/1000 \t Train Err: 2.2955\n",
+ "Epoch 701/1000 \t Train Err: 2.2967\n",
+ "Epoch 702/1000 \t Train Err: 2.2958\n",
+ "Epoch 703/1000 \t Train Err: 2.2933\n",
+ "Epoch 704/1000 \t Train Err: 2.2951\n",
+ "Epoch 705/1000 \t Train Err: 2.2939\n",
+ "Epoch 706/1000 \t Train Err: 2.2922\n",
+ "Epoch 707/1000 \t Train Err: 2.2919\n",
+ "Epoch 708/1000 \t Train Err: 2.2901\n",
+ "Epoch 709/1000 \t Train Err: 2.2897\n",
+ "Epoch 710/1000 \t Train Err: 2.2867\n",
+ "Epoch 711/1000 \t Train Err: 2.2855\n",
+ "Epoch 712/1000 \t Train Err: 2.2841\n",
+ "Epoch 713/1000 \t Train Err: 2.2844\n",
+ "Epoch 714/1000 \t Train Err: 2.2812\n",
+ "Epoch 715/1000 \t Train Err: 2.2801\n",
+ "Epoch 716/1000 \t Train Err: 2.2789\n",
+ "Epoch 717/1000 \t Train Err: 2.2761\n",
+ "Epoch 718/1000 \t Train Err: 2.2797\n",
+ "Epoch 719/1000 \t Train Err: 2.2796\n",
+ "Epoch 720/1000 \t Train Err: 2.2974\n",
+ "Epoch 721/1000 \t Train Err: 2.2786\n",
+ "Epoch 722/1000 \t Train Err: 2.2802\n",
+ "Epoch 723/1000 \t Train Err: 2.2805\n",
+ "Epoch 724/1000 \t Train Err: 2.2812\n",
+ "Epoch 725/1000 \t Train Err: 2.2812\n",
+ "Epoch 726/1000 \t Train Err: 2.2792\n",
+ "Epoch 727/1000 \t Train Err: 2.2780\n",
+ "Epoch 728/1000 \t Train Err: 2.2775\n",
+ "Epoch 729/1000 \t Train Err: 2.2750\n",
+ "Epoch 730/1000 \t Train Err: 2.2821\n",
+ "Epoch 731/1000 \t Train Err: 2.2815\n",
+ "Epoch 732/1000 \t Train Err: 2.2812\n",
+ "Epoch 733/1000 \t Train Err: 2.2779\n",
+ "Epoch 734/1000 \t Train Err: 2.2777\n",
+ "Epoch 735/1000 \t Train Err: 2.2799\n",
+ "Epoch 736/1000 \t Train Err: 2.2754\n",
+ "Epoch 737/1000 \t Train Err: 2.2742\n",
+ "Epoch 738/1000 \t Train Err: 2.2723\n",
+ "Epoch 739/1000 \t Train Err: 2.2719\n",
+ "Epoch 740/1000 \t Train Err: 2.2674\n",
+ "Epoch 741/1000 \t Train Err: 2.2694\n",
+ "Epoch 742/1000 \t Train Err: 2.2702\n",
+ "Epoch 743/1000 \t Train Err: 2.2693\n",
+ "Epoch 744/1000 \t Train Err: 2.2722\n",
+ "Epoch 745/1000 \t Train Err: 2.2704\n",
+ "Epoch 746/1000 \t Train Err: 2.2675\n",
+ "Epoch 747/1000 \t Train Err: 2.2644\n",
+ "Epoch 748/1000 \t Train Err: 2.2599\n",
+ "Epoch 749/1000 \t Train Err: 2.2583\n",
+ "Epoch 750/1000 \t Train Err: 2.2578\n",
+ "Epoch 751/1000 \t Train Err: 2.2507\n",
+ "Epoch 752/1000 \t Train Err: 2.2490\n",
+ "Epoch 753/1000 \t Train Err: 2.2501\n",
+ "Epoch 754/1000 \t Train Err: 2.2502\n",
+ "Epoch 755/1000 \t Train Err: 2.2520\n",
+ "Epoch 756/1000 \t Train Err: 2.2435\n",
+ "Epoch 757/1000 \t Train Err: 2.2432\n",
+ "Epoch 758/1000 \t Train Err: 2.2420\n",
+ "Epoch 759/1000 \t Train Err: 2.2393\n",
+ "Epoch 760/1000 \t Train Err: 2.2372\n",
+ "Epoch 761/1000 \t Train Err: 2.2302\n",
+ "Epoch 762/1000 \t Train Err: 2.2302\n",
+ "Epoch 763/1000 \t Train Err: 2.2294\n",
+ "Epoch 764/1000 \t Train Err: 2.2201\n",
+ "Epoch 765/1000 \t Train Err: 2.2195\n",
+ "Epoch 766/1000 \t Train Err: 2.2166\n",
+ "Epoch 767/1000 \t Train Err: 2.2139\n",
+ "Epoch 768/1000 \t Train Err: 2.2096\n",
+ "Epoch 769/1000 \t Train Err: 2.2100\n",
+ "Epoch 770/1000 \t Train Err: 2.2073\n",
+ "Epoch 771/1000 \t Train Err: 2.2058\n",
+ "Epoch 772/1000 \t Train Err: 2.2096\n",
+ "Epoch 773/1000 \t Train Err: 2.2055\n",
+ "Epoch 774/1000 \t Train Err: 2.2213\n",
+ "Epoch 775/1000 \t Train Err: 2.2435\n",
+ "Epoch 776/1000 \t Train Err: 2.2282\n",
+ "Epoch 777/1000 \t Train Err: 2.2328\n",
+ "Epoch 778/1000 \t Train Err: 2.2254\n",
+ "Epoch 779/1000 \t Train Err: 2.2246\n",
+ "Epoch 780/1000 \t Train Err: 2.2241\n",
+ "Epoch 781/1000 \t Train Err: 2.2217\n",
+ "Epoch 782/1000 \t Train Err: 2.2156\n",
+ "Epoch 783/1000 \t Train Err: 2.2219\n",
+ "Epoch 784/1000 \t Train Err: 2.2151\n",
+ "Epoch 785/1000 \t Train Err: 2.2259\n",
+ "Epoch 786/1000 \t Train Err: 2.2226\n",
+ "Epoch 787/1000 \t Train Err: 2.2176\n",
+ "Epoch 788/1000 \t Train Err: 2.2152\n",
+ "Epoch 789/1000 \t Train Err: 2.2099\n",
+ "Epoch 790/1000 \t Train Err: 2.2069\n",
+ "Epoch 791/1000 \t Train Err: 2.2034\n",
+ "Epoch 792/1000 \t Train Err: 2.2080\n",
+ "Epoch 793/1000 \t Train Err: 2.1999\n",
+ "Epoch 794/1000 \t Train Err: 2.1925\n",
+ "Epoch 795/1000 \t Train Err: 2.1840\n",
+ "Epoch 796/1000 \t Train Err: 2.1820\n",
+ "Epoch 797/1000 \t Train Err: 2.1907\n",
+ "Epoch 798/1000 \t Train Err: 2.1835\n",
+ "Epoch 799/1000 \t Train Err: 2.1886\n",
+ "Epoch 800/1000 \t Train Err: 2.1807\n",
+ "Epoch 801/1000 \t Train Err: 2.1841\n",
+ "Epoch 802/1000 \t Train Err: 2.1776\n",
+ "Epoch 803/1000 \t Train Err: 2.1800\n",
+ "Epoch 804/1000 \t Train Err: 2.1715\n",
+ "Epoch 805/1000 \t Train Err: 2.1717\n",
+ "Epoch 806/1000 \t Train Err: 2.1701\n",
+ "Epoch 807/1000 \t Train Err: 2.1635\n",
+ "Epoch 808/1000 \t Train Err: 2.1664\n",
+ "Epoch 809/1000 \t Train Err: 2.1603\n",
+ "Epoch 810/1000 \t Train Err: 2.1636\n",
+ "Epoch 811/1000 \t Train Err: 2.1575\n",
+ "Epoch 812/1000 \t Train Err: 2.1587\n",
+ "Epoch 813/1000 \t Train Err: 2.1559\n",
+ "Epoch 814/1000 \t Train Err: 2.1540\n",
+ "Epoch 815/1000 \t Train Err: 2.1537\n",
+ "Epoch 816/1000 \t Train Err: 2.1514\n",
+ "Epoch 817/1000 \t Train Err: 2.1500\n",
+ "Epoch 818/1000 \t Train Err: 2.1488\n",
+ "Epoch 819/1000 \t Train Err: 2.1475\n",
+ "Epoch 820/1000 \t Train Err: 2.1447\n",
+ "Epoch 821/1000 \t Train Err: 2.1434\n",
+ "Epoch 822/1000 \t Train Err: 2.1431\n",
+ "Epoch 823/1000 \t Train Err: 2.1441\n",
+ "Epoch 824/1000 \t Train Err: 2.1816\n",
+ "Epoch 825/1000 \t Train Err: 2.1863\n",
+ "Epoch 826/1000 \t Train Err: 2.1601\n",
+ "Epoch 827/1000 \t Train Err: 2.1623\n",
+ "Epoch 828/1000 \t Train Err: 2.1957\n",
+ "Epoch 829/1000 \t Train Err: 2.1775\n",
+ "Epoch 830/1000 \t Train Err: 2.1971\n",
+ "Epoch 831/1000 \t Train Err: 2.1851\n",
+ "Epoch 832/1000 \t Train Err: 2.1738\n",
+ "Epoch 833/1000 \t Train Err: 2.1654\n",
+ "Epoch 834/1000 \t Train Err: 2.1627\n",
+ "Epoch 835/1000 \t Train Err: 2.1606\n",
+ "Epoch 836/1000 \t Train Err: 2.1487\n",
+ "Epoch 837/1000 \t Train Err: 2.1494\n",
+ "Epoch 838/1000 \t Train Err: 2.1563\n",
+ "Epoch 839/1000 \t Train Err: 2.1521\n",
+ "Epoch 840/1000 \t Train Err: 2.1515\n",
+ "Epoch 841/1000 \t Train Err: 2.1484\n",
+ "Epoch 842/1000 \t Train Err: 2.1476\n",
+ "Epoch 843/1000 \t Train Err: 2.1406\n",
+ "Epoch 844/1000 \t Train Err: 2.1410\n",
+ "Epoch 845/1000 \t Train Err: 2.1359\n",
+ "Epoch 846/1000 \t Train Err: 2.1344\n",
+ "Epoch 847/1000 \t Train Err: 2.1323\n",
+ "Epoch 848/1000 \t Train Err: 2.1236\n",
+ "Epoch 849/1000 \t Train Err: 2.1241\n",
+ "Epoch 850/1000 \t Train Err: 2.1162\n",
+ "Epoch 851/1000 \t Train Err: 2.1179\n",
+ "Epoch 852/1000 \t Train Err: 2.1033\n",
+ "Epoch 853/1000 \t Train Err: 2.1022\n",
+ "Epoch 854/1000 \t Train Err: 2.1009\n",
+ "Epoch 855/1000 \t Train Err: 2.0978\n",
+ "Epoch 856/1000 \t Train Err: 2.0911\n",
+ "Epoch 857/1000 \t Train Err: 2.0932\n",
+ "Epoch 858/1000 \t Train Err: 2.0898\n",
+ "Epoch 859/1000 \t Train Err: 2.0844\n",
+ "Epoch 860/1000 \t Train Err: 2.0767\n",
+ "Epoch 861/1000 \t Train Err: 2.0732\n",
+ "Epoch 862/1000 \t Train Err: 2.0769\n",
+ "Epoch 863/1000 \t Train Err: 2.0725\n",
+ "Epoch 864/1000 \t Train Err: 2.0700\n",
+ "Epoch 865/1000 \t Train Err: 2.0612\n",
+ "Epoch 866/1000 \t Train Err: 2.0637\n",
+ "Epoch 867/1000 \t Train Err: 2.0580\n",
+ "Epoch 868/1000 \t Train Err: 2.0598\n",
+ "Epoch 869/1000 \t Train Err: 2.0535\n",
+ "Epoch 870/1000 \t Train Err: 2.0503\n",
+ "Epoch 871/1000 \t Train Err: 2.0492\n",
+ "Epoch 872/1000 \t Train Err: 2.0431\n",
+ "Epoch 873/1000 \t Train Err: 2.0423\n",
+ "Epoch 874/1000 \t Train Err: 2.0382\n",
+ "Epoch 875/1000 \t Train Err: 2.0328\n",
+ "Epoch 876/1000 \t Train Err: 2.0313\n",
+ "Epoch 877/1000 \t Train Err: 2.0280\n",
+ "Epoch 878/1000 \t Train Err: 2.0297\n",
+ "Epoch 879/1000 \t Train Err: 2.0243\n",
+ "Epoch 880/1000 \t Train Err: 2.0243\n",
+ "Epoch 881/1000 \t Train Err: 2.0222\n",
+ "Epoch 882/1000 \t Train Err: 2.0209\n",
+ "Epoch 883/1000 \t Train Err: 2.0161\n",
+ "Epoch 884/1000 \t Train Err: 2.0157\n",
+ "Epoch 885/1000 \t Train Err: 2.0253\n",
+ "Epoch 886/1000 \t Train Err: 2.0697\n",
+ "Epoch 887/1000 \t Train Err: 2.2021\n",
+ "Epoch 888/1000 \t Train Err: 2.2692\n",
+ "Epoch 889/1000 \t Train Err: 2.1106\n",
+ "Epoch 890/1000 \t Train Err: 2.1653\n",
+ "Epoch 891/1000 \t Train Err: 2.2021\n",
+ "Epoch 892/1000 \t Train Err: 2.1370\n",
+ "Epoch 893/1000 \t Train Err: 2.1576\n",
+ "Epoch 894/1000 \t Train Err: 2.1296\n",
+ "Epoch 895/1000 \t Train Err: 2.1303\n",
+ "Epoch 896/1000 \t Train Err: 2.1201\n",
+ "Epoch 897/1000 \t Train Err: 2.1001\n",
+ "Epoch 898/1000 \t Train Err: 2.1209\n",
+ "Epoch 899/1000 \t Train Err: 2.1034\n",
+ "Epoch 900/1000 \t Train Err: 2.1103\n",
+ "Epoch 901/1000 \t Train Err: 2.0983\n",
+ "Epoch 902/1000 \t Train Err: 2.0762\n",
+ "Epoch 903/1000 \t Train Err: 2.0929\n",
+ "Epoch 904/1000 \t Train Err: 2.0643\n",
+ "Epoch 905/1000 \t Train Err: 2.0555\n",
+ "Epoch 906/1000 \t Train Err: 2.0589\n",
+ "Epoch 907/1000 \t Train Err: 2.0454\n",
+ "Epoch 908/1000 \t Train Err: 2.0500\n",
+ "Epoch 909/1000 \t Train Err: 2.0418\n",
+ "Epoch 910/1000 \t Train Err: 2.0363\n",
+ "Epoch 911/1000 \t Train Err: 2.0357\n",
+ "Epoch 912/1000 \t Train Err: 2.0323\n",
+ "Epoch 913/1000 \t Train Err: 2.0282\n",
+ "Epoch 914/1000 \t Train Err: 2.0242\n",
+ "Epoch 915/1000 \t Train Err: 2.0120\n",
+ "Epoch 916/1000 \t Train Err: 2.0127\n",
+ "Epoch 917/1000 \t Train Err: 2.0133\n",
+ "Epoch 918/1000 \t Train Err: 2.0097\n",
+ "Epoch 919/1000 \t Train Err: 2.0087\n",
+ "Epoch 920/1000 \t Train Err: 2.0099\n",
+ "Epoch 921/1000 \t Train Err: 2.0076\n",
+ "Epoch 922/1000 \t Train Err: 2.0020\n",
+ "Epoch 923/1000 \t Train Err: 1.9990\n",
+ "Epoch 924/1000 \t Train Err: 1.9967\n",
+ "Epoch 925/1000 \t Train Err: 1.9966\n",
+ "Epoch 926/1000 \t Train Err: 1.9946\n",
+ "Epoch 927/1000 \t Train Err: 1.9904\n",
+ "Epoch 928/1000 \t Train Err: 1.9874\n",
+ "Epoch 929/1000 \t Train Err: 1.9974\n",
+ "Epoch 930/1000 \t Train Err: 1.9857\n",
+ "Epoch 931/1000 \t Train Err: 1.9892\n",
+ "Epoch 932/1000 \t Train Err: 1.9947\n",
+ "Epoch 933/1000 \t Train Err: 1.9974\n",
+ "Epoch 934/1000 \t Train Err: 2.0159\n",
+ "Epoch 935/1000 \t Train Err: 2.0433\n",
+ "Epoch 936/1000 \t Train Err: 2.0755\n",
+ "Epoch 937/1000 \t Train Err: 2.0014\n",
+ "Epoch 938/1000 \t Train Err: 2.0443\n",
+ "Epoch 939/1000 \t Train Err: 2.0184\n",
+ "Epoch 940/1000 \t Train Err: 2.0192\n",
+ "Epoch 941/1000 \t Train Err: 2.0248\n",
+ "Epoch 942/1000 \t Train Err: 2.0124\n",
+ "Epoch 943/1000 \t Train Err: 2.0101\n",
+ "Epoch 944/1000 \t Train Err: 2.0024\n",
+ "Epoch 945/1000 \t Train Err: 2.0011\n",
+ "Epoch 946/1000 \t Train Err: 1.9871\n",
+ "Epoch 947/1000 \t Train Err: 1.9816\n",
+ "Epoch 948/1000 \t Train Err: 1.9875\n",
+ "Epoch 949/1000 \t Train Err: 2.0660\n",
+ "Epoch 950/1000 \t Train Err: 2.0591\n",
+ "Epoch 951/1000 \t Train Err: 2.0214\n",
+ "Epoch 952/1000 \t Train Err: 2.0312\n",
+ "Epoch 953/1000 \t Train Err: 2.0470\n",
+ "Epoch 954/1000 \t Train Err: 2.0365\n",
+ "Epoch 955/1000 \t Train Err: 2.0143\n",
+ "Epoch 956/1000 \t Train Err: 2.0104\n",
+ "Epoch 957/1000 \t Train Err: 2.0289\n",
+ "Epoch 958/1000 \t Train Err: 2.0097\n",
+ "Epoch 959/1000 \t Train Err: 1.9998\n",
+ "Epoch 960/1000 \t Train Err: 2.0095\n",
+ "Epoch 961/1000 \t Train Err: 2.0110\n",
+ "Epoch 962/1000 \t Train Err: 2.0009\n",
+ "Epoch 963/1000 \t Train Err: 1.9930\n",
+ "Epoch 964/1000 \t Train Err: 2.0003\n",
+ "Epoch 965/1000 \t Train Err: 1.9912\n",
+ "Epoch 966/1000 \t Train Err: 1.9859\n",
+ "Epoch 967/1000 \t Train Err: 1.9843\n",
+ "Epoch 968/1000 \t Train Err: 1.9828\n",
+ "Epoch 969/1000 \t Train Err: 1.9776\n",
+ "Epoch 970/1000 \t Train Err: 1.9790\n",
+ "Epoch 971/1000 \t Train Err: 1.9697\n",
+ "Epoch 972/1000 \t Train Err: 1.9671\n",
+ "Epoch 973/1000 \t Train Err: 1.9673\n",
+ "Epoch 974/1000 \t Train Err: 1.9585\n",
+ "Epoch 975/1000 \t Train Err: 1.9605\n",
+ "Epoch 976/1000 \t Train Err: 1.9537\n",
+ "Epoch 977/1000 \t Train Err: 1.9529\n",
+ "Epoch 978/1000 \t Train Err: 1.9477\n",
+ "Epoch 979/1000 \t Train Err: 1.9485\n",
+ "Epoch 980/1000 \t Train Err: 1.9376\n",
+ "Epoch 981/1000 \t Train Err: 1.9426\n",
+ "Epoch 982/1000 \t Train Err: 1.9416\n",
+ "Epoch 983/1000 \t Train Err: 1.9334\n",
+ "Epoch 984/1000 \t Train Err: 1.9249\n",
+ "Epoch 985/1000 \t Train Err: 1.9216\n",
+ "Epoch 986/1000 \t Train Err: 1.9268\n",
+ "Epoch 987/1000 \t Train Err: 1.9630\n",
+ "Epoch 988/1000 \t Train Err: 2.0237\n",
+ "Epoch 989/1000 \t Train Err: 2.0037\n",
+ "Epoch 990/1000 \t Train Err: 1.9824\n",
+ "Epoch 991/1000 \t Train Err: 1.9718\n",
+ "Epoch 992/1000 \t Train Err: 1.9726\n",
+ "Epoch 993/1000 \t Train Err: 1.9536\n",
+ "Epoch 994/1000 \t Train Err: 1.9662\n",
+ "Epoch 995/1000 \t Train Err: 1.9492\n",
+ "Epoch 996/1000 \t Train Err: 1.9482\n",
+ "Epoch 997/1000 \t Train Err: 1.9375\n",
+ "Epoch 998/1000 \t Train Err: 1.9492\n",
+ "Epoch 999/1000 \t Train Err: 1.9351\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.train()\n",
+ "for epoch in range(NEPOCHS):\n",
+ " optimizer.zero_grad()\n",
+ " data, labels = mkbatch_rtl(BSZ)\n",
+ " # shift labels to prevent cheating\n",
+ " shifted_labels = torch.roll(labels, 1, dims=1)\n",
+ " shifted_labels[:, 0] = VOCAB_SIZE # start token\n",
+ " outputs = model(data, shifted_labels).permute((0, 2, 1))\n",
+ " loss = criterion(outputs, labels)\n",
+ " train_loss = loss.item()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " train_err.append(train_loss)\n",
+ "\n",
+ " with open('loss', 'a') as f:\n",
+ " f.write(f\"{train_loss}\\n\")\n",
+ " print(f\"Epoch {epoch}/{NEPOCHS} \\t Train Err: {train_loss:.4f}\")"
+ ]
+ },
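+ {
+ "cell_type": "markdown",
+ "id": "plateau-note",
+ "metadata": {},
+ "source": [
+ "*Added note:* the long plateau at a loss of 2.3026 is exactly ln 10, the cross-entropy of guessing each of the 10 output digits uniformly at random; the model only begins to beat chance around epoch 680. The cell below is just that arithmetic."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "plateau-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "\n",
+ "# Cross-entropy of a uniform distribution over VOCAB_SIZE classes.\n",
+ "print(math.log(VOCAB_SIZE))  # 2.302585..."
+ ]
+ },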
+ {
+ "cell_type": "code",
+ "execution_count": 96,
+ "execution_state": "idle",
+ "id": "a3c41150-4541-4722-83a7-e7ad937f6c4f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[3, 8, 4, 0]], device='cuda:0') tensor([[7, 8]], device='cuda:0')\n",
+ "tensor([-4.4248e+00, -1.0567e+00, 1.2971e+00, -2.0221e+00, -6.6597e-01,\n",
+ " -2.6027e+00, -1.5254e-02, 8.1894e+00, -1.6939e-03, -1.2252e+00],\n",
+ " device='cuda:0')\n",
+ "tensor([-3.7663, -1.7898, -1.4273, 1.9667, -2.3513, -4.7138, -2.2421, 3.6817,\n",
+ " 8.9049, 3.1622], device='cuda:0')\n",
+ "tensor([[7, 8]], device='cuda:0', dtype=torch.int32) tensor([[7, 8]], device='cuda:0')\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.eval()\n",
+ "data, labels = mkbatch_rtl(1)\n",
+ "print(data, labels)\n",
+ "with torch.no_grad():\n",
+ " ans = torch.zeros((1, NUM_LEN), dtype=torch.int, device=device)\n",
+ " ans[0, 0] = VOCAB_SIZE\n",
+ " for i in range(NUM_LEN):\n",
+ " outputs = model(data, ans)\n",
+ " print(outputs[0, i])\n",
+ " # break\n",
+ " ans[0, (i + 1) % NUM_LEN] = torch.argmax(outputs[0, i])\n",
+ "ans = torch.roll(ans, -1, dims=1)\n",
+ "print(ans, labels)"
+ ]
+ },
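+ {
+ "cell_type": "markdown",
+ "id": "accuracy-note",
+ "metadata": {},
+ "source": [
+ "*Added note:* the cell above greedily decodes a single example. Below is a sketch (not part of the original run) that applies the same decoding to a larger batch and reports exact-match accuracy; it reuses the start token `VOCAB_SIZE` and the roll trick from the cells above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "accuracy-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    n = 1024\n",
+ "    data, labels = mkbatch_rtl(n)\n",
+ "    ans = torch.full((n, NUM_LEN), VOCAB_SIZE, dtype=torch.long, device=device)\n",
+ "    for i in range(NUM_LEN):\n",
+ "        outputs = model(data, ans)\n",
+ "        # The prediction for position i becomes the decoder input at i + 1;\n",
+ "        # the last step wraps onto the start-token slot, undone by the roll.\n",
+ "        ans[:, (i + 1) % NUM_LEN] = torch.argmax(outputs[:, i], dim=1)\n",
+ "    preds = torch.roll(ans, -1, dims=1)\n",
+ "    acc = (preds == labels).all(dim=1).float().mean().item()\n",
+ "    print(f\"Exact-match accuracy on {n} random sums: {acc:.3f}\")"
+ ]
+ },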
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "execution_state": "idle",
+ "id": "1843b944-bab5-40ee-b26e-5d3b87ea9454",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "FileNotFoundError",
+ "evalue": "[Errno 2] No such file or directory: 'add-ltr-loss'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[32], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmath\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43madd-ltr-loss\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 5\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;28mrange\u001b[39m(NEPOCHS), \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28;01mlambda\u001b[39;00m x: math\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;28mfloat\u001b[39m(x)), f\u001b[38;5;241m.\u001b[39mreadlines())))\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124madd-rtl-loss\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n",
+ "File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/IPython/core/interactiveshell.py:324\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 322\u001b[0m )\n\u001b[0;32m--> 324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'add-ltr-loss'"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "with open(\"add-ltr-loss\") as f:\n",
+ " plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))\n",
+ "with open(\"add-rtl-loss\") as f:\n",
+ " plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))\n",
+ "plt.show()"
+ ]
+ },
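+ {
+ "cell_type": "markdown",
+ "id": "loss-files-note",
+ "metadata": {},
+ "source": [
+ "*Added note:* the training cell logs to a file named `loss`, while the plot compares `add-ltr-loss` and `add-rtl-loss`; presumably `loss` is renamed by hand after an LTR run and an RTL run, which is why the cell above fails with `FileNotFoundError` until both files exist."
+ ]
+ },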
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b97b349f-f20b-441d-8c7f-1724e8cf30cc",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}