diff options
Diffstat (limited to 'transformer_shortest_paths.ipynb')
-rw-r--r-- | transformer_shortest_paths.ipynb | 259 |
1 files changed, 7 insertions, 252 deletions
diff --git a/transformer_shortest_paths.ipynb b/transformer_shortest_paths.ipynb index fe0223c..0adb5db 100644 --- a/transformer_shortest_paths.ipynb +++ b/transformer_shortest_paths.ipynb @@ -359,66 +359,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "T", - "r", - "a", - "i", - "n", - "i", - "n", - "g", - " ", - "d", - "a", - "t", - "a", - ":", - " ", - "2", - "6", - "2", - "M", - "\n", - "T", - "r", - "a", - "i", - "n", - "a", - "b", - "l", - "e", - " ", - "p", - "a", - "r", - "a", - "m", - "e", - "t", - "e", - "r", - "s", - " ", - "i", - "n", - " ", - "t", - "h", - "e", - " ", - "m", - "o", - "d", - "e", - "l", - ":", - " ", - "6", - "0", - "5", - "K", - "\n" + "Training data: 839M\n", + "Trainable parameters in the model: 804K\n" ] } ], @@ -427,12 +369,12 @@ "VOCAB_SIZE = 1 + MAX_VTXS # one more than the max number of vertices\n", "MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n", "NEPOCHS = 100\n", - "BSZ = 10244 # Batch size\n", - "BPE = 256 # Batches per epoch\n", + "BSZ = 8196 # Batch size\n", + "BPE = 1024 # Batches per epoch\n", "LR = 0.003\n", "WD = 0.002\n", "NHEADS = 4\n", - "NLAYERS = 12\n", + "NLAYERS = 16\n", "DROPOUT = 0.2\n", "model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n", " output_dim=1, num_heads=NHEADS,\n", @@ -481,7 +423,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4eb2d57cb7c948da8e9bd201f70d8d19", + "model_id": "fc7a660a4efe4342bf4af23cecdea1b1", "version_major": 2, "version_minor": 0 }, @@ -529,194 +471,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " ", - "9", - "3", - "%", - "|", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "█", - "▌", - " ", - " ", - " ", - " ", - " ", - " ", - " ", - " ", - " ", - " ", - "|", - " ", - "2", - "3", - "8", - "/", - "2", - "5", - "6", - " ", - "[", - "0", - "1", - ":", - "1", - "3", - "<", - "0", - "0", - ":", - "0", - "5", - ",", - " ", - " ", - "3", - ".", - "5", - "3", - "i", - "t", - "/", - "s", - "]" + " 0%| | 0/1024 [00:00<?, ?it/s]" ] } ], |