diff options
Diffstat (limited to 'blog.md')
-rw-r--r-- | blog.md | 36 |
1 files changed, 22 insertions, 14 deletions
@@ -19,6 +19,12 @@ body { display: block; margin: auto; } + +/* Super gross hack to fix stupid colorbars clipping out bug */ +/* :has requires a relatively modern browser */ +.mpld3-staticpaths:has(:nth-child(10)) { + transform: rotate(180deg) translate(-520px,-423px); +} </style> <div style="text-align:center"> @@ -172,7 +178,7 @@ The full input to our model additionally includes the target vertex $t$ after th We have a dataset for each step. -1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges. To balance the dataset, we removed some graphs with shorter paths to obtain the distribution shown below. +1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges. To balance the dataset, we removed some graphs with shorter paths to obtain the distribution shown below. (For any plot, click on the four arrows icon in the bottom left to scroll and pan.) $include`raw="html"` plots/train-dist.html @@ -189,7 +195,9 @@ We used a standard transformer architecture. To ensure that it can in theory lea Since the order of the edges in the input doesn't matter, we did not use positional encodings. Each edge $(a,b)$ is embedded to dimension $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $a$ and the last $\frac{d}{2}$ elements are the learned embedding of $b$. For the target vertex $t$, we pair it with the special token $TARGET$ and embed $(t,TARGET)$ in the same way. <!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm --> +<div style="text-align:center"> ![](img/embeddings.svg) +</div> ### Specifications @@ -200,12 +208,12 @@ For our training run, we used the following specifications: | Model dimension | $64$ | | Layers | $11$ | | Heads | $2$ | -| Epochs | $200$ | +| Epochs | $9600$ | | Batch size | $2^{15}$ | | Dropout | $0$ | -| Total training examples | $2\cdot10^8$ | +| Total training examples | $3.14\cdot10^8$ | | Parameters | $550433$ | -| Learning rate | $5\cdot10^{-4}$ | +| Learning rate | $5\cdot10^{-4}$ for $6400$ epochs, $10^{-4}$ for $3200$ epochs | | Loss | MSE | | Optimizer | Adam | @@ -213,40 +221,40 @@ The number of bits required to store the model parameters in float32 is around $ $$\frac{\binom{15}{2}^{15}}{15!} \approx 1.59\cdot10^{18}.$$ This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15 which requires around 4 bits, so memorizing all the answers requires requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$ bits, which is $3.61\cdot10^{12}$ times larger than our model size. This implies that in order to get really low loss, our model needs to do something other than brute memorization. -A single training run takes roughly two hours to run on a Radeon 7900 XTX graphics card. +A single training run takes roughly three hours on a Radeon 7900 XTX graphics card. One pitfall we encountered during training is that we initially used bfloat16 to save VRAM, but our huge batch size caused loss-of-precision problems and made training very difficult. It took us two weeks to debug this until we found that switching to float32 improved training significantly. ### Training Results -Here is our training loss curve: +Here is our training loss curve, with a final loss of $0.001623$. The discontinuity at epoch $6400$ is from switching the learning rate. $include`raw="html"` plots/train-loss.html -Here is a "confusion matrix"-like plot, showing for each true label, the probability distribution over classifications that the model gives: +Here is a "confusion matrix"-like 2D histogram plot, showing for each true label, the probability distribution over classifications that the model gives. $include`raw="html"` plots/train-hist.html -One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph below: +One pattern we consistently noticed during training is that the model sometimes gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened in the first $50$ epochs in the graph below. $include`raw="html"` plots/train-lens.html -Such **phase change** phenomena are ubiquitous in deep learning and for simple algorithmic tasks we can sometimes interpret them [^2]. - -**TODO ANT: training curves for 1, 2, 3 length paths. and add some text about why this is an interesting finding.** - - +Such **phase change** phenomena are ubiquitous in deep learning and for simple algorithmic tasks we can sometimes interpret them [^2]. In our case, the loss is initially very high for graphs of path length $1,2,3,\infty = 15$ because the model always guesses the same mean answer to minimize MSE. However, at around epoch $50$, the model suddenly learns how to solve graphs with a path length of $1$ by checking if the edge $(1, 2)$ exists in the input. This happens again at around epoch $80$ when the model begins to learn how to solve lengths $2$ and $3$ as well and soon learns to separate the two cases as well. ## Generalization ### Fine Tuning Results -After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown in the images below. +After receiving our initial results, we fine-tuned for $100$ epochs with the same hyperparameters other than a learning rate of $10^{-5}$. We achieved a final loss of $0.002139$. Here are our fine-tuning loss curves and 2D histogram. $include`raw="html"` plots/tune-loss.html $include`raw="html"` plots/tune-hist.html +## Test Results + +Lastly, we evaluated our fine-tuned model on several kinds of input to see when it generalizes. + $include`raw="html"` plots/test-onpath-smalltarget.html $include`raw="html"` plots/test-onpath-largetarget.html |