aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--transformer_shortest_paths.ipynb259
1 files changed, 7 insertions, 252 deletions
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb
index fe0223c..0adb5db 100644
--- a/transformer_shortest_paths.ipynb
+++ b/transformer_shortest_paths.ipynb
@@ -359,66 +359,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "T",
- "r",
- "a",
- "i",
- "n",
- "i",
- "n",
- "g",
- " ",
- "d",
- "a",
- "t",
- "a",
- ":",
- " ",
- "2",
- "6",
- "2",
- "M",
- "\n",
- "T",
- "r",
- "a",
- "i",
- "n",
- "a",
- "b",
- "l",
- "e",
- " ",
- "p",
- "a",
- "r",
- "a",
- "m",
- "e",
- "t",
- "e",
- "r",
- "s",
- " ",
- "i",
- "n",
- " ",
- "t",
- "h",
- "e",
- " ",
- "m",
- "o",
- "d",
- "e",
- "l",
- ":",
- " ",
- "6",
- "0",
- "5",
- "K",
- "\n"
+ "Training data: 839M\n",
+ "Trainable parameters in the model: 804K\n"
]
}
],
@@ -427,12 +369,12 @@
"VOCAB_SIZE = 1 + MAX_VTXS # one more than the max number of vertices\n",
"MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
"NEPOCHS = 100\n",
- "BSZ = 10244 # Batch size\n",
- "BPE = 256 # Batches per epoch\n",
+ "BSZ = 8196 # Batch size\n",
+ "BPE = 1024 # Batches per epoch\n",
"LR = 0.003\n",
"WD = 0.002\n",
"NHEADS = 4\n",
- "NLAYERS = 12\n",
+ "NLAYERS = 16\n",
"DROPOUT = 0.2\n",
"model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
" output_dim=1, num_heads=NHEADS,\n",
@@ -481,7 +423,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "4eb2d57cb7c948da8e9bd201f70d8d19",
+ "model_id": "fc7a660a4efe4342bf4af23cecdea1b1",
"version_major": 2,
"version_minor": 0
},
@@ -529,194 +471,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- " ",
- "9",
- "3",
- "%",
- "|",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "█",
- "▌",
- " ",
- " ",
- " ",
- " ",
- " ",
- " ",
- " ",
- " ",
- " ",
- " ",
- "|",
- " ",
- "2",
- "3",
- "8",
- "/",
- "2",
- "5",
- "6",
- " ",
- "[",
- "0",
- "1",
- ":",
- "1",
- "3",
- "<",
- "0",
- "0",
- ":",
- "0",
- "5",
- ",",
- " ",
- " ",
- "3",
- ".",
- "5",
- "3",
- "i",
- "t",
- "/",
- "s",
- "]"
+ " 0%| | 0/1024 [00:00<?, ?it/s]"
]
}
],