From 7f3f049e353f47e76daf3af2ebfffdaad96992e6 Mon Sep 17 00:00:00 2001
From: SIPB
Date: Mon, 9 Dec 2024 00:29:57 -0500
Subject: Add nearly finished code
---
blog.md | 36 +-
index.html | 522 ++-
loss | 7588 ---------------------------------
model.pth | Bin 1233035 -> 2257362 bytes
plots/test-dist.html | 43 +
plots/test-large-any-largetarget.html | 43 +
plots/test-large-any-smalltarget.html | 43 +
plots/test-onpath-largetarget.html | 43 +
plots/test-onpath-smalltarget.html | 43 +
plots/test-small-any.html | 43 +
plots/train-dist.html | 8 +-
plots/train-hist.html | 43 +
plots/train-lens.html | 43 +
plots/train-loss.html | 43 +
plots/tune-dist.html | 43 +
plots/tune-hist.html | 43 +
plots/tune-loss.html | 43 +
transformer_shortest_paths.ipynb | 2363 +++-------
18 files changed, 1520 insertions(+), 9513 deletions(-)
create mode 100644 plots/test-dist.html
create mode 100644 plots/test-large-any-largetarget.html
create mode 100644 plots/test-large-any-smalltarget.html
create mode 100644 plots/test-onpath-largetarget.html
create mode 100644 plots/test-onpath-smalltarget.html
create mode 100644 plots/test-small-any.html
create mode 100644 plots/train-hist.html
create mode 100644 plots/train-lens.html
create mode 100644 plots/train-loss.html
create mode 100644 plots/tune-dist.html
create mode 100644 plots/tune-hist.html
create mode 100644 plots/tune-loss.html
diff --git a/blog.md b/blog.md
index e00df4e..addbb4e 100644
--- a/blog.md
+++ b/blog.md
@@ -19,6 +19,12 @@ body {
display: block;
margin: auto;
}
+
+/* Super gross hack to fix stupid colorbars clipping out bug */
+/* :has requires a relatively modern browser */
+.mpld3-staticpaths:has(:nth-child(10)) {
+ transform: rotate(180deg) translate(-520px,-423px);
+}
@@ -172,7 +178,7 @@ The full input to our model additionally includes the target vertex $t$ after th
We have a dataset for each step.
-1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges. To balance the dataset, we removed some graphs with shorter paths to obtain the distribution shown below.
+1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges. To balance the dataset, we removed some graphs with shorter paths to obtain the distribution shown below; a short code sketch of this generation procedure appears after the plot. (For any plot, click the four arrows icon in the bottom left to scroll and pan.)
$include`raw="html"` plots/train-dist.html
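As a rough illustration of this generation procedure, here is a minimal sketch (the names and helpers are ours, not the exact notebook code): sample a graph by picking $n$ random edges, then label it with the shortest-path distance from vertex $1$ (the assumed source) to the target $t = 2$, using $15$ for unreachable pairs.

```python
import random
from collections import deque

def random_graph(n):
    """Sample a graph on vertices 1..n by choosing n random edges (duplicates allowed)."""
    return [tuple(random.sample(range(1, n + 1), 2)) for _ in range(n)]

def label(edges, s=1, t=2, unreachable=15):
    """BFS distance from s to t; unreachable pairs get the label 15."""
    adj = {}
    for a, b in edges:
        adj.setdefault(a, []).append(b)
        adj.setdefault(b, []).append(a)
    dist, queue = {s: 0}, deque([s])
    while queue:
        u = queue.popleft()
        for v in adj.get(u, []):
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist.get(t, unreachable)

# Example: one training pair for n = 8
edges = random_graph(8)
print(edges, label(edges))
```

Rebalancing then just means discarding some graphs whose label is small until the label distribution matches the plot above.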
@@ -189,7 +195,9 @@ We used a standard transformer architecture. To ensure that it can in theory lea
Since the order of the edges in the input doesn't matter, we did not use positional encodings. Each edge $(a,b)$ is embedded to dimension $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $a$ and the last $\frac{d}{2}$ elements are the learned embedding of $b$. For the target vertex $t$, we pair it with the special token $TARGET$ and embed $(t,TARGET)$ in the same way.
+
![](img/embeddings.svg)
+
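Here is a minimal PyTorch sketch of this embedding scheme (the 0-based vertex indexing and the extra TARGET slot are assumptions for illustration, not necessarily the notebook's exact layout):

```python
import torch
import torch.nn as nn

class PairEmbedding(nn.Module):
    """Embed each input pair (a, b) as [emb(a); emb(b)], with no positional encoding."""
    def __init__(self, n_vertices=15, d=64):
        super().__init__()
        self.emb = nn.Embedding(n_vertices + 1, d // 2)  # one extra slot for TARGET
        self.target_token = n_vertices                   # index of the special TARGET token

    def forward(self, pairs):
        # pairs: (batch, seq, 2) long tensor; edge tokens are (a, b) and the
        # final position is (t, TARGET)
        return torch.cat([self.emb(pairs[..., 0]), self.emb(pairs[..., 1])], dim=-1)

# Example: edges (0,1), (1,2) and target t = 2, on 0-indexed vertices
pe = PairEmbedding()
x = torch.tensor([[[0, 1], [1, 2], [2, pe.target_token]]])
print(pe(x).shape)  # torch.Size([1, 3, 64])
```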
### Specifications
@@ -200,12 +208,12 @@ For our training run, we used the following specifications:
| Model dimension | $64$ |
| Layers | $11$ |
| Heads | $2$ |
-| Epochs | $200$ |
+| Epochs | $9600$ |
| Batch size | $2^{15}$ |
| Dropout | $0$ |
-| Total training examples | $2\cdot10^8$ |
+| Total training examples | $3.14\cdot10^8$ |
| Parameters | $550433$ |
-| Learning rate | $5\cdot10^{-4}$ |
+| Learning rate | $5\cdot10^{-4}$ for $6400$ epochs, $10^{-4}$ for $3200$ epochs |
| Loss | MSE |
| Optimizer | Adam |
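For concreteness, here is roughly what a model with these specifications looks like in PyTorch. This is only a sketch: the feed-forward width ($256$) and the scalar output head are assumptions on our part, and the exact architecture lives in transformer_shortest_paths.ipynb.

```python
import torch
import torch.nn as nn

d_model, n_layers, n_heads = 64, 11, 2

layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_heads, dim_feedforward=256,  # feed-forward width assumed
    dropout=0.0, batch_first=True,
)
encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
head = nn.Linear(d_model, 1)           # regress the path length as a single scalar

params = list(encoder.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=5e-4)  # dropped to 1e-4 after 6400 epochs
loss_fn = nn.MSELoss()
```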
@@ -213,40 +221,40 @@ The number of bits required to store the model parameters in float32 is around $
$$\frac{\binom{15}{2}^{15}}{15!} \approx 1.59\cdot10^{18}.$$
This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15, which requires around 4 bits, so memorizing all the answers requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$ bits, which is about $3.61\cdot10^{11}$ times larger than our model size. This implies that in order to get really low loss, our model needs to do something other than brute memorization.
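This back-of-the-envelope count is easy to check directly (a quick sanity check, not part of the training code):

```python
from math import comb, factorial

num_graphs = comb(15, 2) ** 15 / factorial(15)  # unordered choices of 15 edges
answer_bits = 4 * num_graphs                    # ~4 bits per memorized answer
model_bits = 550433 * 32                        # float32 parameters
print(f"{num_graphs:.3g} graphs, ratio {answer_bits / model_bits:.3g}")
# -> 1.59e+18 graphs, ratio 3.61e+11
```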
-A single training run takes roughly two hours to run on a Radeon 7900 XTX graphics card.
+A single training run takes roughly three hours on a Radeon 7900 XTX graphics card.
One pitfall we encountered during training is that we initially used bfloat16 to save VRAM, but our huge batch size caused loss-of-precision problems that made training very difficult. We spent two weeks debugging this before finding that switching to float32 improved training significantly.
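As a toy illustration of the failure mode (not our actual training code): bfloat16 keeps only about 8 bits of mantissa, so small per-example contributions can be rounded away entirely once the accumulated value is large.

```python
import torch

acc = torch.tensor(1024.0, dtype=torch.bfloat16)
print(acc + 1.0)          # still 1024: the +1 is below bfloat16's resolution here
print(acc.float() + 1.0)  # 1025 in float32
```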
### Training Results
-Here is our training loss curve:
+Here is our training loss curve; the final loss is $0.001623$. The discontinuity at epoch $6400$ comes from dropping the learning rate from $5\cdot10^{-4}$ to $10^{-4}$.
$include`raw="html"` plots/train-loss.html
-Here is a "confusion matrix"-like plot, showing for each true label, the probability distribution over classifications that the model gives:
+Here is a "confusion matrix"-like 2D histogram plot, showing for each true label, the probability distribution over classifications that the model gives.
$include`raw="html"` plots/train-hist.html
-One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph below:
+One pattern we consistently noticed during training is that the model sometimes gets stuck and plateaus for many epochs before the loss rapidly decreases. For instance, this happened in the first $50$ epochs in the graph below.
$include`raw="html"` plots/train-lens.html
-Such **phase change** phenomena are ubiquitous in deep learning and for simple algorithmic tasks we can sometimes interpret them [^2].
-
-**TODO ANT: training curves for 1, 2, 3 length paths. and add some text about why this is an interesting finding.**
-
-
+Such **phase change** phenomena are ubiquitous in deep learning, and for simple algorithmic tasks we can sometimes interpret them [^2]. In our case, the loss is initially very high for graphs of path length $1$, $2$, $3$, and $\infty$ (encoded as $15$) because the model minimizes MSE by always guessing the mean label. At around epoch $50$, the model suddenly learns to solve graphs with path length $1$ by checking whether the edge $(1, 2)$ exists in the input. A similar jump happens around epoch $80$, when the model begins to solve lengths $2$ and $3$, and soon afterwards it learns to separate those two cases.
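(To see why a constant guess lands at the mean: for a constant prediction $c$ on labels $y$,
$$\frac{d}{dc}\,\mathbb{E}\big[(y-c)^2\big] = -2\,\mathbb{E}[y-c] = 0 \iff c = \mathbb{E}[y],$$
so before the model picks up any structure, the lowest-MSE output is the mean label.)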
## Generalization
### Fine Tuning Results
-After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown in the images below.
+After obtaining our initial results, we fine-tuned for $100$ epochs with the same hyperparameters except for a lower learning rate of $10^{-5}$. We achieved a final loss of $0.002139$. Here are our fine-tuning loss curves and 2D histogram.
$include`raw="html"` plots/tune-loss.html
$include`raw="html"` plots/tune-hist.html
+## Test Results
+
+Lastly, we evaluated our fine-tuned model on several kinds of input to see when it generalizes.
+
$include`raw="html"` plots/test-onpath-smalltarget.html
$include`raw="html"` plots/test-onpath-largetarget.html
diff --git a/index.html b/index.html
index 4d7b64d..0a28cf4 100644
--- a/index.html
+++ b/index.html
@@ -268,6 +268,12 @@ body {
display: block;
margin: auto;
}
+
+/* Super gross hack to fix stupid colorbars clipping out bug */
+/* :has requires a relatively modern browser */
+.mpld3-staticpaths:has(:nth-child(10)) {
+ transform: rotate(180deg) translate(-520px,-423px);
+}
Anthony Wang, Alek Westover, Kevin Zhao
@@ -628,7 +634,8 @@ graphs on
n vertices, with
t = 2. Each graph was created by choosing
n random edges. To balance the dataset,
we removed some graphs with shorter paths to obtain the distribution
-shown below.
+shown below. (For any plot, click on the four arrows icon in the bottom
+left to scroll and pan.)
@@ -636,7 +643,7 @@ shown below.
+Here is a “confusion matrix”-like 2D histogram plot, showing, for each
+true label, the probability distribution over classifications that the
+model gives.
+[embedded plot markup for plots/train-hist.html]
One pattern we consistently noticed during training is that the model
-often gets stuck and plateaus for many epochs before rapidly decreasing.
-For instance, this happened between epochs 100 and 300 in the graph
-below:
-$include`raw="html"` plots/train-lens.html
+sometimes gets stuck and plateaus for many epochs before the loss
+rapidly decreases. For instance, this happened in the first 50 epochs
+in the graph below.
+[embedded plot markup for plots/train-lens.html]
Such phase change phenomena are ubiquitous in deep learning, and for
simple algorithmic tasks we can sometimes interpret them [2].
-TODO ANT: training curves for 1, 2, 3 length paths. and add
-some text about why this is an interesting finding.
+In our case, the loss is initially very high for graphs of path length
+1, 2, 3, and ∞ (encoded as 15) because the model minimizes MSE by
+always guessing the mean label. At around epoch 50, the model suddenly
+learns to solve graphs with path length 1 by checking whether the edge
+(1, 2) exists in the input. A similar jump happens around epoch 80,
+when the model begins to solve lengths 2 and 3, and soon afterwards it
+learns to separate those two cases.
Generalization
Fine Tuning Results
-After receiving our initial results, we fine-tuned with a learning
-rate of 1e-5, also with MSE and the same batch size. Our final results
-are shown in the images below.
-$include`raw="html"` plots/tune-loss.html
-$include`raw="html"` plots/tune-hist.html
-$include`raw="html"` plots/test-onpath-smalltarget.html
-$include`raw="html"` plots/test-onpath-largetarget.html
-$include`raw="html"` plots/test-small-any.html
-$include`raw="html"` plots/test-large-any-smalltarget.html
-$include`raw="html"` plots/test-large-any-largetarget.html
+After obtaining our initial results, we fine-tuned for 100 epochs with
+the same hyperparameters except for a lower learning rate of 10^{-5}.
+We achieved a final loss of 0.002139. Here are our fine-tuning loss
+curves and 2D histogram.
+[embedded plot markup for plots/tune-loss.html and plots/tune-hist.html]
+Test Results
+Lastly, we evaluated our fine-tuned model on several kinds of input to
+see when it generalizes.
+[embedded plot markup for the test plots]