Diffstat (limited to 'notebooks/addition.ipynb')
-rw-r--r-- | notebooks/addition.ipynb | 836
1 file changed, 836 insertions, 0 deletions
diff --git a/notebooks/addition.ipynb b/notebooks/addition.ipynb
new file mode 100644
index 0000000..01fddff
--- /dev/null
+++ b/notebooks/addition.ipynb
@@ -0,0 +1,836 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "execution_state": "idle",
+   "id": "ecaeb29e-fbbe-4876-86ad-9fbadea989b0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import random\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 110,
+   "execution_state": "idle",
+   "id": "84a82827-8947-4a26-a485-56f5b1eadb4c",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[4, 9, 4, 1, 8, 2],\n",
+       "         [0, 6, 5, 9, 1, 4],\n",
+       "         [4, 9, 1, 0, 5, 5],\n",
+       "         [5, 2, 4, 9, 1, 8],\n",
+       "         [2, 7, 6, 8, 0, 1]], device='cuda:0'),\n",
+       " tensor([[5, 7, 7],\n",
+       "         [9, 7, 9],\n",
+       "         [4, 4, 7],\n",
+       "         [4, 4, 2],\n",
+       "         [0, 8, 7]], device='cuda:0'))"
+      ]
+     },
+     "execution_count": 110,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "NUM_LEN = 3\n",
+    "\n",
+    "def pad(a):\n",
+    "    # Zero-pad to NUM_LEN digits; overflowing sums keep only their last\n",
+    "    # NUM_LEN digits, so labels are effectively (a + b) mod 10**NUM_LEN.\n",
+    "    s = str(a)\n",
+    "    if len(s) > NUM_LEN:\n",
+    "        return s[-NUM_LEN:]\n",
+    "    return s.zfill(NUM_LEN)\n",
+    "\n",
+    "def mkbatch_ltr(size):\n",
+    "    # Most-significant-digit-first encoding: data is the digits of a, then b.\n",
+    "    data = []\n",
+    "    labels = []\n",
+    "    for i in range(size):\n",
+    "        a = random.randrange(0, 10**NUM_LEN)\n",
+    "        b = random.randrange(0, 10**NUM_LEN)\n",
+    "        c = a + b\n",
+    "        data.append(list(map(int, pad(a) + pad(b))))\n",
+    "        labels.append(list(map(int, pad(c))))\n",
+    "    return torch.tensor(data, device=device), torch.tensor(labels, device=device)\n",
+    "\n",
+    "def mkbatch_rtl(size):\n",
+    "    # Reversed encoding: least significant digit first, so carries flow\n",
+    "    # in the same direction the decoder generates.\n",
+    "    data, labels = mkbatch_ltr(size)\n",
+    "    return torch.flip(data, (1,)), torch.flip(labels, (1,))\n",
+    "\n",
+    "mkbatch_rtl(5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 111,
+   "execution_state": "idle",
+   "id": "d50dce44-57b7-4d4d-895a-c2275c04234c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class TransformerModel(nn.Module):\n",
+    "    def __init__(self, input_dim, model_dim, output_dim, nheads, nenclayers, ndeclayers):\n",
+    "        super().__init__()\n",
+    "        # The embedding is one channel narrower than the model: the last\n",
+    "        # channel carries a raw positional index (see forward).\n",
+    "        self.emb = nn.Embedding(input_dim, model_dim - 1)\n",
+    "        self.trans = nn.Transformer(d_model=model_dim, nhead=nheads, dim_feedforward=4 * model_dim,\n",
+    "                                    num_encoder_layers=nenclayers, num_decoder_layers=ndeclayers,\n",
+    "                                    dropout=0, batch_first=True)\n",
+    "        self.output = nn.Linear(model_dim, output_dim)\n",
+    "\n",
+    "    def forward(self, data, labels):\n",
+    "        bsz = data.size(0)\n",
+    "        # Encoder positions repeat per operand (digit place within a and b);\n",
+    "        # decoder positions are absolute.\n",
+    "        data_pos = (torch.arange(2 * NUM_LEN, device=device) % NUM_LEN).expand(bsz, -1)\n",
+    "        labels_pos = (torch.arange(NUM_LEN, device=device)).expand(bsz, -1)\n",
+    "        data_emb = torch.cat((self.emb(data), data_pos.unsqueeze(2)), 2)\n",
+    "        labels_emb = torch.cat((self.emb(labels), labels_pos.unsqueeze(2)), 2)\n",
+    "        return self.output(self.trans(data_emb, labels_emb, tgt_mask=TGT_MASK, tgt_is_causal=True))"
+   ]
+  },
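A quick sanity check of the data format (a sketch, not part of the notebook; pad is redefined so the snippet stands alone): pad truncates overflowing sums to their last NUM_LEN digits, so each label is (a + b) mod 10**NUM_LEN, and mkbatch_rtl flips operands and label so the least significant digit, where carries originate, comes first.

    NUM_LEN = 3

    def pad(a):
        s = str(a)
        return s[-NUM_LEN:] if len(s) > NUM_LEN else s.zfill(NUM_LEN)

    a, b = 123, 988                               # 123 + 988 = 1111
    data_ltr = list(map(int, pad(a) + pad(b)))    # [1, 2, 3, 9, 8, 8]
    label_ltr = list(map(int, pad(a + b)))        # [1, 1, 1]: 1111 mod 1000, carry-out dropped
    data_rtl, label_rtl = data_ltr[::-1], label_ltr[::-1]
    assert data_rtl == [8, 8, 9, 3, 2, 1] and label_rtl == [1, 1, 1]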
+  {
+   "cell_type": "code",
+   "execution_count": 118,
+   "execution_state": "idle",
+   "id": "ddad4059-b06e-4eb3-a55a-5a4a842cdd7a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Training data: 32768K\n",
+      "Trainable parameters in the model: 1251\n"
+     ]
+    }
+   ],
+   "source": [
+    "MODEL_DIM = 4  # Dimension of model\n",
+    "VOCAB_SIZE = 10\n",
+    "NEPOCHS = 1000\n",
+    "BSZ = 2**15  # Batch size\n",
+    "NHEADS = 1\n",
+    "NENCLAYERS = 2\n",
+    "NDECLAYERS = 2\n",
+    "\n",
+    "LR = 1e-2\n",
+    "\n",
+    "# Keep the causal mask on the same device as the model's inputs.\n",
+    "TGT_MASK = nn.Transformer.generate_square_subsequent_mask(NUM_LEN).to(device)\n",
+    "model = TransformerModel(VOCAB_SIZE + 1, MODEL_DIM, VOCAB_SIZE, NHEADS, NENCLAYERS, NDECLAYERS).to(device)\n",
+    "\n",
+    "criterion = nn.CrossEntropyLoss()\n",
+    "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
+    "\n",
+    "train_err = []\n",
+    "open('loss', 'w').close()  # truncate the loss log from any previous run\n",
+    "\n",
+    "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+    "print(f\"Training data: {NEPOCHS*BSZ//10**3}K\")\n",
+    "print(f\"Trainable parameters in the model: {trainable_params}\")"
+   ]
+  },
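For reference, the mask built above (a sketch assuming a recent PyTorch, where generate_square_subsequent_mask is a static method): for NUM_LEN = 3 it is a 3x3 additive mask with -inf above the diagonal, so decoder position i can attend only to positions <= i.

    import torch.nn as nn

    mask = nn.Transformer.generate_square_subsequent_mask(3)
    print(mask)
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])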
+  {
+   "cell_type": "code",
+   "execution_count": 119,
+   "execution_state": "idle",
+   "id": "689f2e44-da84-43ea-b539-414d6f5c37e3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0/1000 \t Train Err: 2.4793\n",
+      "Epoch 1/1000 \t Train Err: 2.4310\n",
+      "Epoch 2/1000 \t Train Err: 2.3800\n",
+      "Epoch 3/1000 \t Train Err: 2.3493\n",
+      "Epoch 4/1000 \t Train Err: 2.3288\n",
+      "Epoch 5/1000 \t Train Err: 2.3202\n",
+      "Epoch 8/1000 \t Train Err: 2.3095\n",
+      "Epoch 11/1000 \t Train Err: 2.3029\n",
+      "[... epochs 12-61 omitted: loss flat between 2.3026 and 2.3068 ...]\n",
+      "Epoch 62/1000 \t Train Err: 2.3026\n",
+      "[... epochs 63-499 truncated in the original page ...]\n",
+      "Epoch 500/1000 \t Train Err: 2.3026\n",
+      "[... epochs 501-672 omitted: loss pinned at ~2.3026 ...]\n",
+      "Epoch 673/1000 \t Train Err: 2.3024\n",
+      "Epoch 683/1000 \t Train Err: 2.3018\n",
+      "Epoch 693/1000 \t Train Err: 2.2988\n",
+      "[... roughly every 20th epoch shown below; intervening values follow the same noisy descent, with brief spikes near epochs 720, 775-790, 824-835, 886-910, 934-957 and 987-991 ...]\n",
+      "Epoch 700/1000 \t Train Err: 2.2955\n",
+      "Epoch 720/1000 \t Train Err: 2.2974\n",
+      "Epoch 740/1000 \t Train Err: 2.2674\n",
+      "Epoch 760/1000 \t Train Err: 2.2372\n",
+      "Epoch 780/1000 \t Train Err: 2.2241\n",
+      "Epoch 800/1000 \t Train Err: 2.1807\n",
+      "Epoch 820/1000 \t Train Err: 2.1447\n",
+      "Epoch 840/1000 \t Train Err: 2.1515\n",
+      "Epoch 860/1000 \t Train Err: 2.0767\n",
+      "Epoch 880/1000 \t Train Err: 2.0243\n",
+      "Epoch 888/1000 \t Train Err: 2.2692\n",
+      "Epoch 900/1000 \t Train Err: 2.1103\n",
+      "Epoch 920/1000 \t Train Err: 2.0099\n",
+      "Epoch 940/1000 \t Train Err: 2.0192\n",
+      "Epoch 960/1000 \t Train Err: 2.0095\n",
+      "Epoch 980/1000 \t Train Err: 1.9376\n",
+      "Epoch 985/1000 \t Train Err: 1.9216\n",
+      "Epoch 988/1000 \t Train Err: 2.0237\n",
+      "Epoch 996/1000 \t Train Err: 1.9482\n",
+      "Epoch 997/1000 \t Train Err: 1.9375\n",
+      "Epoch 998/1000 \t Train Err: 1.9492\n",
+      "Epoch 999/1000 \t Train Err: 1.9351\n"
+     ]
+    }
+   ],
+   "source": [
+    "model.train()\n",
+    "for epoch in range(NEPOCHS):\n",
+    "    optimizer.zero_grad()\n",
+    "    data, labels = mkbatch_rtl(BSZ)\n",
+    "    # Teacher forcing: shift the targets right and prepend a start token,\n",
+    "    # so position i predicts digit i without seeing it.\n",
+    "    shifted_labels = torch.roll(labels, 1, dims=1)\n",
+    "    shifted_labels[:, 0] = VOCAB_SIZE  # start token\n",
+    "    outputs = model(data, shifted_labels).permute((0, 2, 1))\n",
+    "    loss = criterion(outputs, labels)\n",
+    "    train_loss = loss.item()\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "\n",
+    "    train_err.append(train_loss)\n",
+    "\n",
+    "    with open('loss', 'a') as f:\n",
+    "        f.write(f\"{train_loss}\\n\")\n",
+    "    print(f\"Epoch {epoch}/{NEPOCHS} \\t Train Err: {train_loss:.4f}\")"
+   ]
+  },
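The long plateau at 2.3026 is chance level: the cross-entropy of a uniform guess over the 10 digit classes is ln 10, so the model learns nothing about addition until roughly epoch 680.

    import math
    print(math.log(10))  # 2.302585..., the plateau value in the log above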
+  {
+   "cell_type": "code",
+   "execution_count": 96,
+   "execution_state": "idle",
+   "id": "a3c41150-4541-4722-83a7-e7ad937f6c4f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[3, 8, 4, 0]], device='cuda:0') tensor([[7, 8]], device='cuda:0')\n",
+      "tensor([-4.4248e+00, -1.0567e+00,  1.2971e+00, -2.0221e+00, -6.6597e-01,\n",
+      "        -2.6027e+00, -1.5254e-02,  8.1894e+00, -1.6939e-03, -1.2252e+00],\n",
+      "       device='cuda:0')\n",
+      "tensor([-3.7663, -1.7898, -1.4273,  1.9667, -2.3513, -4.7138, -2.2421,  3.6817,\n",
+      "         8.9049,  3.1622], device='cuda:0')\n",
+      "tensor([[7, 8]], device='cuda:0', dtype=torch.int32) tensor([[7, 8]], device='cuda:0')\n"
+     ]
+    }
+   ],
+   "source": [
+    "model.eval()\n",
+    "data, labels = mkbatch_rtl(1)\n",
+    "print(data, labels)\n",
+    "with torch.no_grad():\n",
+    "    ans = torch.zeros((1, NUM_LEN), dtype=torch.int, device=device)\n",
+    "    ans[0, 0] = VOCAB_SIZE\n",
+    "    for i in range(NUM_LEN):\n",
+    "        outputs = model(data, ans)\n",
+    "        print(outputs[0, i])\n",
+    "        # break\n",
+    "        # The final step wraps onto index 0, overwriting the start token;\n",
+    "        # the roll below restores digit order.\n",
+    "        ans[0, (i + 1) % NUM_LEN] = torch.argmax(outputs[0, i])\n",
+    "ans = torch.roll(ans, -1, dims=1)\n",
+    "print(ans, labels)"
+   ]
+  },
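The recorded output above is from an earlier run with NUM_LEN = 2 (execution count 96 predates the training cell), which is why data has four digits and labels two. The same greedy decode extends to a batched exact-match accuracy check; a sketch under the notebook's definitions (eval_accuracy is a hypothetical helper, not in the notebook):

    def eval_accuracy(n=1024):
        # Greedy autoregressive decode, then compare whole answers digit for digit.
        model.eval()
        data, labels = mkbatch_rtl(n)
        with torch.no_grad():
            # Position 0 holds the start token; later slots are filled as we decode.
            # The causal mask means their initial value never influences a prediction.
            ans = torch.full((n, NUM_LEN), VOCAB_SIZE, dtype=torch.long, device=device)
            for i in range(NUM_LEN):
                outputs = model(data, ans)                 # (n, NUM_LEN, VOCAB_SIZE)
                ans[:, (i + 1) % NUM_LEN] = outputs[:, i].argmax(dim=1)
            ans = torch.roll(ans, -1, dims=1)              # undo the start-token offset
        return (ans == labels).all(dim=1).float().mean().item()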
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "execution_state": "idle",
+   "id": "1843b944-bab5-40ee-b26e-5d3b87ea9454",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "FileNotFoundError",
+     "evalue": "[Errno 2] No such file or directory: 'add-ltr-loss'",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "FileNotFoundError                         Traceback (most recent call last)",
+      "Cell In[32], line 4",
+      "      1 import math",
+      "      2 import matplotlib.pyplot as plt",
+      "----> 4 with open(\"add-ltr-loss\") as f:",
+      "      5     plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))",
+      "      6 with open(\"add-rtl-loss\") as f:",
+      "",
+      "File ~/.venv/lib64/python3.12/site-packages/IPython/core/interactiveshell.py:324, in _modified_open(file, *args, **kwargs)",
+      "    317 if file in {0, 1, 2}:",
+      "    318     raise ValueError(",
+      "    319         f\"IPython won't let you open fd={file} by default \"",
+      "    320         \"as it is likely to crash IPython. If you know what you are doing, \"",
+      "    321         \"you can use builtins' open.\"",
+      "    322     )",
+      "--> 324 return io_open(file, *args, **kwargs)",
+      "",
+      "FileNotFoundError: [Errno 2] No such file or directory: 'add-ltr-loss'"
+     ]
+    }
+   ],
+   "source": [
+    "import math\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "with open(\"add-ltr-loss\") as f:\n",
+    "    plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))\n",
+    "with open(\"add-rtl-loss\") as f:\n",
+    "    plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b97b349f-f20b-441d-8c7f-1724e8cf30cc",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
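The FileNotFoundError is expected on a fresh run: the training loop writes its losses to 'loss', while this cell compares two files presumably produced by renaming 'loss' after an mkbatch_ltr run and an mkbatch_rtl run (an assumption; the rename is not in the notebook). A sketch that plots whichever loss files exist:

    import math
    import os
    import matplotlib.pyplot as plt

    # Assumed workflow: after each run, rename 'loss' to a file named for the digit order used.
    for name in ("add-ltr-loss", "add-rtl-loss", "loss"):
        if os.path.exists(name):
            with open(name) as f:
                ys = [math.log(float(line)) for line in f]
            plt.plot(range(len(ys)), ys, label=name)
    plt.xlabel("epoch")
    plt.ylabel("log loss")
    plt.legend()
    plt.show()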