diff options
-rw-r--r-- | transformer_shortest_paths.ipynb | 881 |
1 files changed, 42 insertions, 839 deletions
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb index 71a40fe..fcb24c7 100644 --- a/transformer_shortest_paths.ipynb +++ b/transformer_shortest_paths.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "execution_state": "idle", "metadata": { "colab": { @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "execution_state": "idle", "metadata": { "id": "lylOX2POPwFL" @@ -306,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "execution_state": "idle", "metadata": { "id": "tLOWhg_CeWzH" @@ -356,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 11, "execution_state": "idle", "metadata": { "colab": { @@ -370,7 +370,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Trainable parameters in the model: 505K\n" + "Training data: 131M\n", + "Trainable parameters in the model: 605K\n" ] } ], @@ -380,11 +381,11 @@ "MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n", "NEPOCHS = 1000\n", "BSZ = 2048 # Batch size\n", - "BPE = 32 # Batches per epoch\n", + "BPE = 64 # Batches per epoch\n", "LR = 0.003\n", "WD = 0.002\n", "NHEADS = 4\n", - "NLAYERS = 10\n", + "NLAYERS = 12\n", "DROPOUT = 0.2\n", "model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n", " output_dim=1, num_heads=NHEADS,\n", @@ -395,6 +396,7 @@ "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)\n", "\n", "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "print(f\"Training data: {NEPOCHS*BSZ*BPE//10**6}M\")\n", "print(f\"Trainable parameters in the model: {trainable_params//1000}K\")" ] }, @@ -467,8 +469,8 @@ }, { "cell_type": "code", - "execution_count": null, - "execution_state": "running", + "execution_count": 12, + "execution_state": "idle", "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -482,833 +484,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/1000 \t Train Err: 80.9485 \t Test Err: 83.3727\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 2/1000 \t Train Err: 81.9558 \t Test Err: 80.7205\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 3/1000 \t Train Err: 81.2221 \t Test Err: 80.6387\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4/1000 \t Train Err: 81.8502 \t Test Err: 80.0444\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 5/1000 \t Train Err: 81.5031 \t Test Err: 83.7185\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 6/1000 \t Train Err: 81.3043 \t Test Err: 81.4035\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.22it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 7/1000 \t Train Err: 81.0616 \t Test Err: 83.7366\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.28it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 8/1000 \t Train Err: 81.5992 \t Test Err: 81.2875\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 9/1000 \t Train Err: 81.3813 \t Test Err: 80.2028\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 10/1000 \t Train Err: 81.5702 \t Test Err: 82.8906\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 11/1000 \t Train Err: 80.6410 \t Test Err: 81.6353\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" + " 0%| | 0/64 [00:00<?, ?it/s]\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 12/1000 \t Train Err: 81.0706 \t Test Err: 81.4791\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 13/1000 \t Train Err: 81.0538 \t Test Err: 81.0688\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 14/1000 \t Train Err: 81.4753 \t Test Err: 85.5978\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 15/1000 \t Train Err: 81.2319 \t Test Err: 81.5276\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 16/1000 \t Train Err: 82.0405 \t Test Err: 80.4760\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 17/1000 \t Train Err: 81.2955 \t Test Err: 80.1790\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:10<00:00, 3.05it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 18/1000 \t Train Err: 81.3618 \t Test Err: 81.0788\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 19/1000 \t Train Err: 81.4784 \t Test Err: 82.8825\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 20/1000 \t Train Err: 80.7994 \t Test Err: 81.8424\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 21/1000 \t Train Err: 80.9150 \t Test Err: 80.6047\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 22/1000 \t Train Err: 81.7054 \t Test Err: 78.2826\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 23/1000 \t Train Err: 81.6376 \t Test Err: 83.0617\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 24/1000 \t Train Err: 81.1639 \t Test Err: 79.6304\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 25/1000 \t Train Err: 81.9200 \t Test Err: 82.3950\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 26/1000 \t Train Err: 81.0736 \t Test Err: 83.0353\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 27/1000 \t Train Err: 81.8939 \t Test Err: 80.7981\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.35it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 28/1000 \t Train Err: 80.9842 \t Test Err: 80.3877\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 29/1000 \t Train Err: 81.6111 \t Test Err: 82.5336\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 30/1000 \t Train Err: 81.5480 \t Test Err: 82.7556\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 31/1000 \t Train Err: 81.2413 \t Test Err: 82.6558\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 32/1000 \t Train Err: 81.1720 \t Test Err: 82.2116\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 33/1000 \t Train Err: 81.7244 \t Test Err: 79.6762\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 34/1000 \t Train Err: 81.4536 \t Test Err: 84.7001\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.30it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 35/1000 \t Train Err: 82.1729 \t Test Err: 82.0201\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 36/1000 \t Train Err: 81.7041 \t Test Err: 83.0776\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 37/1000 \t Train Err: 81.2599 \t Test Err: 82.2269\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 38/1000 \t Train Err: 81.6489 \t Test Err: 81.6412\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 39/1000 \t Train Err: 81.4363 \t Test Err: 81.9661\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 40/1000 \t Train Err: 81.0156 \t Test Err: 78.0546\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 41/1000 \t Train Err: 81.2752 \t Test Err: 82.3804\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 42/1000 \t Train Err: 81.1951 \t Test Err: 81.7494\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 43/1000 \t Train Err: 81.4909 \t Test Err: 81.9628\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 44/1000 \t Train Err: 81.5728 \t Test Err: 78.5453\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 45/1000 \t Train Err: 81.9706 \t Test Err: 81.1184\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 46/1000 \t Train Err: 81.1537 \t Test Err: 81.5044\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 47/1000 \t Train Err: 80.8373 \t Test Err: 82.4630\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 48/1000 \t Train Err: 81.3666 \t Test Err: 82.1752\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 49/1000 \t Train Err: 81.1630 \t Test Err: 82.7047\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 50/1000 \t Train Err: 81.3882 \t Test Err: 85.8777\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 51/1000 \t Train Err: 81.4415 \t Test Err: 83.4058\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 52/1000 \t Train Err: 81.2446 \t Test Err: 82.6877\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 53/1000 \t Train Err: 81.3113 \t Test Err: 82.0156\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 54/1000 \t Train Err: 81.3483 \t Test Err: 81.1088\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 55/1000 \t Train Err: 81.3773 \t Test Err: 81.1178\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.35it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 56/1000 \t Train Err: 81.0823 \t Test Err: 83.9259\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.34it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 57/1000 \t Train Err: 81.6416 \t Test Err: 81.8139\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.34it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 58/1000 \t Train Err: 81.9228 \t Test Err: 81.7897\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 59/1000 \t Train Err: 81.3041 \t Test Err: 79.5053\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 25%|███████████ | 8/32 [00:02<00:07, 3.32it/s]" + "ename": "NameError", + "evalue": "name 'mkbatch' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(BPE)):\n\u001b[0;32m----> 8\u001b[0m batch_src, batch_labels, batch_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[43mmkbatch\u001b[49m(BSZ)\n\u001b[1;32m 9\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 10\u001b[0m output \u001b[38;5;241m=\u001b[39m model(batch_src, batch_padding_mask)\n", + "\u001b[0;31mNameError\u001b[0m: name 'mkbatch' is not defined" ] } ], @@ -1332,6 +519,8 @@ " \n", " test_err.append(test_loss)\n", " train_err.append(train_loss)\n", + " with open(\"loss\", \"a\") as f:\n", + " f.write(f\"{train_loss} {test_loss}\\n\")\n", " print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n", " ax.plot(train_err, label='Train', color='blue')\n", " ax.plot(test_err, label='Test', color='red')\n", @@ -1345,17 +534,17 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 57, "execution_state": "idle", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "55.06520214080811" + "177.23538208007812" ] }, - "execution_count": 80, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } @@ -1489,10 +678,22 @@ }, { "cell_type": "code", - "execution_count": null, - "execution_state": "running", + "execution_count": 13, + "execution_state": "idle", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'N_TUNE_EPOCHS' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[13], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m tune_train_err \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 2\u001b[0m tune_test_err \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[43mN_TUNE_EPOCHS\u001b[49m):\n\u001b[1;32m 5\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n", + "\u001b[0;31mNameError\u001b[0m: name 'N_TUNE_EPOCHS' is not defined" + ] + } + ], "source": [ "tune_train_err = []\n", "tune_test_err = []\n", @@ -1513,6 +714,8 @@ " \n", " tune_test_err.append(test_loss)\n", " tune_train_err.append(train_loss)\n", + " with open(\"tune_loss\", \"a\") as f:\n", + " f.write(f\"{train_loss} {test_loss}\\n\")\n", " ax.plot(tune_train_err, label='Train', color='blue')\n", " ax.plot(tune_test_err, label='Test', color='red')\n", " ax.set_xlabel('Epochs')\n", |