aboutsummaryrefslogtreecommitdiff
path: root/transformer_shortest_paths.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'transformer_shortest_paths.ipynb')
-rw-r--r--transformer_shortest_paths.ipynb881
1 files changed, 42 insertions, 839 deletions
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb
index 71a40fe..fcb24c7 100644
--- a/transformer_shortest_paths.ipynb
+++ b/transformer_shortest_paths.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 2,
"execution_state": "idle",
"metadata": {
"colab": {
@@ -49,7 +49,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 3,
"execution_state": "idle",
"metadata": {
"id": "lylOX2POPwFL"
@@ -306,7 +306,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"execution_state": "idle",
"metadata": {
"id": "tLOWhg_CeWzH"
@@ -356,7 +356,7 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 11,
"execution_state": "idle",
"metadata": {
"colab": {
@@ -370,7 +370,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Trainable parameters in the model: 505K\n"
+ "Training data: 131M\n",
+ "Trainable parameters in the model: 605K\n"
]
}
],
@@ -380,11 +381,11 @@
"MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
"NEPOCHS = 1000\n",
"BSZ = 2048 # Batch size\n",
- "BPE = 32 # Batches per epoch\n",
+ "BPE = 64 # Batches per epoch\n",
"LR = 0.003\n",
"WD = 0.002\n",
"NHEADS = 4\n",
- "NLAYERS = 10\n",
+ "NLAYERS = 12\n",
"DROPOUT = 0.2\n",
"model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
" output_dim=1, num_heads=NHEADS,\n",
@@ -395,6 +396,7 @@
"optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)\n",
"\n",
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f\"Training data: {NEPOCHS*BSZ*BPE//10**6}M\")\n",
"print(f\"Trainable parameters in the model: {trainable_params//1000}K\")"
]
},
@@ -467,8 +469,8 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "execution_state": "running",
+ "execution_count": 12,
+ "execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -482,833 +484,18 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 1/1000 \t Train Err: 80.9485 \t Test Err: 83.3727\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 2/1000 \t Train Err: 81.9558 \t Test Err: 80.7205\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 3/1000 \t Train Err: 81.2221 \t Test Err: 80.6387\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 4/1000 \t Train Err: 81.8502 \t Test Err: 80.0444\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 5/1000 \t Train Err: 81.5031 \t Test Err: 83.7185\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 6/1000 \t Train Err: 81.3043 \t Test Err: 81.4035\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.22it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 7/1000 \t Train Err: 81.0616 \t Test Err: 83.7366\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.28it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 8/1000 \t Train Err: 81.5992 \t Test Err: 81.2875\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 9/1000 \t Train Err: 81.3813 \t Test Err: 80.2028\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 10/1000 \t Train Err: 81.5702 \t Test Err: 82.8906\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 11/1000 \t Train Err: 80.6410 \t Test Err: 81.6353\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
+ " 0%| | 0/64 [00:00<?, ?it/s]\n"
]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 12/1000 \t Train Err: 81.0706 \t Test Err: 81.4791\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 13/1000 \t Train Err: 81.0538 \t Test Err: 81.0688\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 14/1000 \t Train Err: 81.4753 \t Test Err: 85.5978\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 15/1000 \t Train Err: 81.2319 \t Test Err: 81.5276\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 16/1000 \t Train Err: 82.0405 \t Test Err: 80.4760\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 17/1000 \t Train Err: 81.2955 \t Test Err: 80.1790\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:10<00:00, 3.05it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 18/1000 \t Train Err: 81.3618 \t Test Err: 81.0788\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 19/1000 \t Train Err: 81.4784 \t Test Err: 82.8825\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 20/1000 \t Train Err: 80.7994 \t Test Err: 81.8424\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 21/1000 \t Train Err: 80.9150 \t Test Err: 80.6047\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 22/1000 \t Train Err: 81.7054 \t Test Err: 78.2826\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 23/1000 \t Train Err: 81.6376 \t Test Err: 83.0617\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 24/1000 \t Train Err: 81.1639 \t Test Err: 79.6304\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 25/1000 \t Train Err: 81.9200 \t Test Err: 82.3950\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 26/1000 \t Train Err: 81.0736 \t Test Err: 83.0353\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 27/1000 \t Train Err: 81.8939 \t Test Err: 80.7981\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.35it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 28/1000 \t Train Err: 80.9842 \t Test Err: 80.3877\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 29/1000 \t Train Err: 81.6111 \t Test Err: 82.5336\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 30/1000 \t Train Err: 81.5480 \t Test Err: 82.7556\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 31/1000 \t Train Err: 81.2413 \t Test Err: 82.6558\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 32/1000 \t Train Err: 81.1720 \t Test Err: 82.2116\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 33/1000 \t Train Err: 81.7244 \t Test Err: 79.6762\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 34/1000 \t Train Err: 81.4536 \t Test Err: 84.7001\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.30it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 35/1000 \t Train Err: 82.1729 \t Test Err: 82.0201\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 36/1000 \t Train Err: 81.7041 \t Test Err: 83.0776\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 37/1000 \t Train Err: 81.2599 \t Test Err: 82.2269\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 38/1000 \t Train Err: 81.6489 \t Test Err: 81.6412\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 39/1000 \t Train Err: 81.4363 \t Test Err: 81.9661\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.32it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 40/1000 \t Train Err: 81.0156 \t Test Err: 78.0546\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.36it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 41/1000 \t Train Err: 81.2752 \t Test Err: 82.3804\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 42/1000 \t Train Err: 81.1951 \t Test Err: 81.7494\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 43/1000 \t Train Err: 81.4909 \t Test Err: 81.9628\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 44/1000 \t Train Err: 81.5728 \t Test Err: 78.5453\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 45/1000 \t Train Err: 81.9706 \t Test Err: 81.1184\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 46/1000 \t Train Err: 81.1537 \t Test Err: 81.5044\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.37it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 47/1000 \t Train Err: 80.8373 \t Test Err: 82.4630\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 48/1000 \t Train Err: 81.3666 \t Test Err: 82.1752\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 49/1000 \t Train Err: 81.1630 \t Test Err: 82.7047\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 50/1000 \t Train Err: 81.3882 \t Test Err: 85.8777\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 51/1000 \t Train Err: 81.4415 \t Test Err: 83.4058\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.38it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 52/1000 \t Train Err: 81.2446 \t Test Err: 82.6877\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 53/1000 \t Train Err: 81.3113 \t Test Err: 82.0156\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.31it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 54/1000 \t Train Err: 81.3483 \t Test Err: 81.1088\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 55/1000 \t Train Err: 81.3773 \t Test Err: 81.1178\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.35it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 56/1000 \t Train Err: 81.0823 \t Test Err: 83.9259\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.34it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 57/1000 \t Train Err: 81.6416 \t Test Err: 81.8139\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.34it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 58/1000 \t Train Err: 81.9228 \t Test Err: 81.7897\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|███████████████████████████████████████████| 32/32 [00:09<00:00, 3.33it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 59/1000 \t Train Err: 81.3041 \t Test Err: 79.5053\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 25%|███████████ | 8/32 [00:02<00:07, 3.32it/s]"
+ "ename": "NameError",
+ "evalue": "name 'mkbatch' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[12], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(BPE)):\n\u001b[0;32m----> 8\u001b[0m batch_src, batch_labels, batch_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[43mmkbatch\u001b[49m(BSZ)\n\u001b[1;32m 9\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 10\u001b[0m output \u001b[38;5;241m=\u001b[39m model(batch_src, batch_padding_mask)\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'mkbatch' is not defined"
]
}
],
@@ -1332,6 +519,8 @@
" \n",
" test_err.append(test_loss)\n",
" train_err.append(train_loss)\n",
+ " with open(\"loss\", \"a\") as f:\n",
+ " f.write(f\"{train_loss} {test_loss}\\n\")\n",
" print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
" ax.plot(train_err, label='Train', color='blue')\n",
" ax.plot(test_err, label='Test', color='red')\n",
@@ -1345,17 +534,17 @@
},
{
"cell_type": "code",
- "execution_count": 80,
+ "execution_count": 57,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "55.06520214080811"
+ "177.23538208007812"
]
},
- "execution_count": 80,
+ "execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
@@ -1489,10 +678,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "execution_state": "running",
+ "execution_count": 13,
+ "execution_state": "idle",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'N_TUNE_EPOCHS' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[13], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m tune_train_err \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 2\u001b[0m tune_test_err \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[43mN_TUNE_EPOCHS\u001b[49m):\n\u001b[1;32m 5\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m 6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'N_TUNE_EPOCHS' is not defined"
+ ]
+ }
+ ],
"source": [
"tune_train_err = []\n",
"tune_test_err = []\n",
@@ -1513,6 +714,8 @@
" \n",
" tune_test_err.append(test_loss)\n",
" tune_train_err.append(train_loss)\n",
+ " with open(\"tune_loss\", \"a\") as f:\n",
+ " f.write(f\"{train_loss} {test_loss}\\n\")\n",
" ax.plot(tune_train_err, label='Train', color='blue')\n",
" ax.plot(tune_test_err, label='Test', color='red')\n",
" ax.set_xlabel('Epochs')\n",