aboutsummaryrefslogtreecommitdiff
path: root/blog.md
diff options
context:
space:
mode:
authorSIPB2024-12-10 21:45:59 -0500
committerSIPB2024-12-10 21:45:59 -0500
commit0e78a4196e4a8da98a8c66d80380e847f9a45302 (patch)
tree5712a6a960afa260d6bc75ce3b34d80f04a19a29 /blog.md
parentae08546afa147c7bcafccf54da4b4df2b2845925 (diff)
adding some stuff
Diffstat (limited to 'blog.md')
-rw-r--r--blog.md75
1 files changed, 36 insertions, 39 deletions
diff --git a/blog.md b/blog.md
index addbb4e..b5d82e3 100644
--- a/blog.md
+++ b/blog.md
@@ -10,7 +10,6 @@ title: "Off-Distribution Generalization of Transformers for Shortest Paths"
body {
/* Shrink margins */
max-width: 45em;
- /* font-size: 110%; */
}
/* Make sure toolbar plots aren't cut off */
@@ -37,19 +36,18 @@ Anthony Wang, Alek Westover, Kevin Zhao
### Motivation
-Neural networks are capable of impressive feats of off-distribution generalization. For instance, a program trained to convert sketches of cats into realistic looking pictures of cats can draw a cat with three eyes, if the sketch has three eyes, even if such an image never occurred in its training data. However, neural networks also often learn non-robust features that cause it to perform very poorly off-distribution (e.g., adversarial examples for an image classifier). In this project, we will investigate the question of when transformers generalize off-distribution through via a case study on a simple synthetic task. More specifically, the goal of our project is to make progress towards answering the following question:
+Neural networks are capable of impressive feats of off-distribution generalization. For instance, a program trained to convert sketches of cats into realistic looking pictures of cats can draw a cat with three eyes, if the sketch has three eyes, even if such an image never occurred in its training data. However, neural networks also often learn non-robust features that cause them to perform poorly off-distribution (e.g., adversarial examples for an image classifier). In this project, we will investigate the question of when transformers generalize off-distribution via a case study on a simple synthetic task. More specifically, the goal of our project is to make progress towards answering the following question:
> Suppose you train a model to be truthful on task A, and to produce coherent answers on a related task B.
>
> Will the model learn to be truthful on task B?
-Here's a more specific version of this question,
-inspired by a [blog post written by Paul Chrisiano](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX)
-(who was influential in developing RLHF).
+We stumbled upon this question while reading a [blog post written by Paul Chrisiano](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX)
+(who was influential in developing RLHF). His blog post inspired us to explore off-distribution generalization. A more specific version of our question is below:
> Suppose we trained a language model on next token prediction in both English and French, and then fine-tuned the model to be truthful in English. Would the model then learn to answer questions truthfully in French? It has been shown [^1] that this doesn't happen by default, but you could hope that the truthfulness training in English would generalize to French.
-In the future, humans are likely to trust AIs with performing important and complex tasks. For some of these tasks it will be infeasible for humans to audit the answers -- either because human labor is too expensive, or because the AIs are too competent, so humans aren't skilled enough to perform an audit (for example, an amateur chess player might have a hard time telling apart a good chess move and a bad chess move). However, there will be some simpler tasks where we can audit the AI's and reward truthul answers. So, it's plausible that models trained to be truthful in some simple domains will learn truthfulness as a general lesson.
+In the future, humans are likely to trust AIs with performing important and complex tasks. For some of these tasks it will be infeasible for humans to audit the answers -- either because human labor is too expensive, or because the AIs are too competent, so humans aren't skilled enough to perform an audit (for example, an amateur chess player might have a hard time telling apart a good chess move and a bad chess move). However, there will be some simpler tasks where we can audit the AI's and reward truthful answers. So, it's plausible that models trained to be truthful in some simple domains will learn truthfulness as a general lesson.
### Task
@@ -77,7 +75,7 @@ Obviously, if you actually want to compute shortest paths in a graph, you should
1. It's very easy to generate data and ground truth labels for this task.
-2. There is a simple circuit which computes shortest paths (exactly), by doing something akin to BFS. In fact, we implemented this simple circuit by hand in the section [Handcrafted BFS Transformer](#handcrafted-bfs-transformer). Our hope was that our model can also find a simple circuit that generalized well.
+2. There is a simple circuit which computes shortest paths (exactly), by doing something akin to BFS. In fact, we implemented this simple circuit by hand in the section [Handcrafted BFS Transformer](#handcrafted-bfs-transformer). Our hope was that our model would also find a simple circuit that generalizes well.
### Findings
@@ -85,7 +83,7 @@ Our main findings were as follows:
1. There exists a simple set of weights that you can find by hand to achieve arbitrarily small MSE loss on shortest paths. Moreover, this set of weights is "robust" in the sense that, if you perturb some of these weights and run gradient descent, the model will recover low loss.
-2. Transformers can learn on their own to perform shortest paths. The training process consists of a sequence of **phase changes** where the loss would barely change for a large number of epochs, and then over a very small number of epochs, the loss would dramatically go down. (TODO: Is this just because we're bad at DL? now it shouldn't have as many phase changes. Or maybe just in the beginning?) In particular, these phase changes corresponded to times when the model decided to, e.g., go from dramatically over-classifying paths as being length 1 (resulting in great performance on tasks of length 1 but poor performance on other tasks) to spreading these length 1 guesses out to other distances.
+2. Transformers can learn on their own to perform shortest paths. The training process consists of a sequence of **phase changes** where the loss would barely change for a large number of epochs, and then over a very small number of epochs, the loss would dramatically go down. In particular, these phase changes corresponded to times when the model decided to, e.g., go from dramatically over-classifying paths as being length 1 (resulting in great performance on tasks of length 1 but poor performance on other tasks) to spreading these length 1 guesses out to other distances.
3. A model trained to compute the distance between two specific vertices can be fine-tuned efficiently to compute the shortest distance to other vertices.
@@ -131,11 +129,11 @@ More specifically, we will create the following embeddings for our input sequenc
As you can see, the hidden dimension of the model will be $4n+2$. The hidden dimension is broken up into 5 sections, some of which are only used in the answer token, and some of which are only used in the vertex tokens. Here are what the different parts of the hidden dimension represent:
-1. Flags: booleans indicating whether a vertex is the answer token or not. If desired, you can think of this as part of a "positional embedding". The flags will not be updated as we move through the transformer.
-2. Neighbors: a boolean vector (of dimension $n$) describing the neighbors of each vertex. (We just put zeros here for the answer token). This also won't update through the course of running the transformer.
-3. Reach: a boolean vector representing the vertices that each vertex can reach. This will update over the course of a forward pass of the transformer. Specifically, at layer $i$ of the transformer we will extend the reach of each vertex from vertices which are distance $i+1$ away, to vertices at distance $i+2$ away.
-4. Out: this starts as zero for all tokens. Once vertex $1$ can reach vertex $2$ (i.e., we are in a layer deep enough that the reach vector for $v_{1}$'s token indicates that it can reach vertex $2$) we'll start writing $1$ on the answer token's out part of its embedding.
-5. Self: a 1-hot encoding specifying which vertex each token corresponds to. You should think of this as a positional embedding (along with the flags).
+1. Flags: Booleans indicating whether a vertex is the answer token or not. If desired, you can think of this as part of a "positional embedding". The flags will not be updated as we move through the transformer.
+2. Neighbors: A boolean vector (of dimension $n$) describing the neighbors of each vertex. (We just put zeros here for the answer token). This also won't update through the course of running the transformer.
+3. Reach: A boolean vector representing the vertices that each vertex can reach. This will update over the course of a forward pass of the transformer. Specifically, at layer $i$ of the transformer we will extend the reach of each vertex from vertices which are distance $i+1$ away, to vertices at distance $i+2$ away.
+4. Out: This starts as zero for all tokens. Once vertex $1$ can reach vertex $2$ (i.e., we are in a layer deep enough that the reach vector for $v_{1}$'s token indicates that it can reach vertex $2$) we'll start writing $1$ on the answer token's out part of its embedding.
+5. Self: A 1-hot encoding specifying which vertex each token corresponds to. You should think of this as a positional embedding (along with the flags).
The transformer will consist of $n$ layers, each of which has $n+1$ heads. We've included the code for the handcoded transformer in the [Appendix](#appendix) -- possibly it's helpful to look at the code in conjunction with the written explanation below.
@@ -152,17 +150,17 @@ The value that vertex w would give to vertex v is simply a 1-hot vector, with a
After running the heads, each vertex's embedding is updated to integrate the new vertices that it can reach into its embedding. Note that we can think of embeddings as being updated because transformers have a residual stream of embeddings that we make modifications to.
-Finally, the last head will be in charge of noticing whether vertex 1 has reached vertex 2 yet, and if so recording this fact. For the full details, see [appendix](#appendix).
+Finally, the last head will be in charge of noticing whether vertex 1 has reached vertex 2 yet, and if so recording this fact. For the full details, see the [Appendix](#appendix).
### Robustness
The field of Singular Learning Theory (SLT; see Liam Carroll's Master's thesis "Phase Transitions in Neural Networks" for an introduction) aims to understand model training and loss-landscape geometry. In efforts to better understand the loss landscape of the shortest paths loss function according to the tokens used in our hand coded implementation of the shortest paths transformers, we decided to start at a good setting of the parameters, and then perturb the weights, and see if the model can subsequently achieve low loss. The intuition for why this is a good approach at measuring "how attractive of a loss basin" we have is that this experiment is similar to the Local Learning Coefficient from SLT. (see Lau, Edmund, Zach Furman, George Wang, Daniel Murfet, and Susan Wei. "The Local Learning Coefficient: A Singularity-Aware Complexity Measure"). We found that, perturbing the weights led to high loss, but gradient descent was able to recover low loss, indicating that the solution is somewhat "findable" by gradient descent.
-
-TODO: Switch to fancy mpld3 plots
+<div style="text-align:center">
![perturb.png](perturb.png)
-![perturb-loss.png](perturb-loss.png)
+![perturb-loss.png](perturb-loss.png)
+</div>
## Training
@@ -182,9 +180,13 @@ We have a dataset for each step.
$include`raw="html"` plots/train-dist.html
-2. **Fine-tune data**: For each $n \in [3,7]$, we generated an equal number of graphs on $n$ vertices each with a random $t$ on the shortest path from $1$ to $2$. Again, each graph was created by choosing $n$ random edges.
+2. **Fine-tune data**: For each $n \in [3,7]$, we generated an equal number of graphs on $n$ vertices each with a random $t$ on the shortest path from $1$ to $2$. Again, each graph was created by choosing $n$ random edges. We also balanced the dataset to achieve the distribution below.
+
+$include`raw="html"` plots/tune-dist.html
+
+3. **Generalization test data**: The same as the fine-tune data, except we sample $n \in [3,15]$ and $t \leq 7$. The dataset had the following distribution of labels.
-3. **Generalization test data**: The same as the fine-tune data, except we sample $n \in [3,15]$ and $t \leq 7$.
+$include`raw="html"` plots/test-dist.html
We wrote some Python code to generate the data during the training loop, but Python is excruciatingly slow and data generation wasted a lot of training time. Our solution was to pre-generate the data before training using a multithreaded version of our Python code.
@@ -196,7 +198,7 @@ Since the order of the edges in the input doesn't matter, we did not use positio
<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
<div style="text-align:center">
-![](img/embeddings.svg)
+![](img/embed_path.png)
</div>
### Specifications
@@ -245,12 +247,18 @@ Such **phase change** phenomena are ubiquitous in deep learning and for simple a
### Fine Tuning Results
-After receiving our initial results, we fine-tuned for $100$ epochs with the same hyperparameters other than a learning rate of $10^{-5}$. We achieved a final loss of $0.002139$. Here are our fine-tuning loss curves and 2D histogram.
+After receiving our initial results, we fine-tuned for $100$ epochs with the same hyperparameters other than a learning rate and weight decay of $10^{-5}$. We used weight decay to incentivize the model to find simpler or more general representations for the fine-tuning task. We achieved a final loss of $0.01472$. Here are our fine-tuning loss curves and 2D histogram.
$include`raw="html"` plots/tune-loss.html
$include`raw="html"` plots/tune-hist.html
+If we run the un-fine-tuned model on the evaluations, we get the following histogram and an abysmal loss of $0.5531$.
+
+$include`raw="html"` plots/untune-hist.html
+
+The plot at first looks great because the correct answers are all yellow, but that's because the model is simply ignoring the target vertex and almost always outputting an integer. It often gets the answer correct by sheer luck, but when it doesn't, the difference is instead an integer which makes the loss very high.
+
## Test Results
Lastly, we evaluated our fine-tuned model on several kinds of input to see when it generalizes.
@@ -259,30 +267,19 @@ $include`raw="html"` plots/test-onpath-smalltarget.html
$include`raw="html"` plots/test-onpath-largetarget.html
+For on-path targets in large graphs, the model is able to generalize somewhat well, but generalizes better when the target is small--$3$ through $7$--likely because those were in the fine-tuning data. For small targets, it achieves a loss of $0.1020$ compared to $0.1451$ for any target. It's notable that the model still generalizes to large targets, although to a lesser extent, which indicates that generalization happens on a spectrum and the model probably did learn somewhat general representations.
+
+We tried to do interpretability on the model, but it was too large and complicated to understand, so many of these explanations are just our intuition and may not correspond at all to how the model actually works.
+
+We also tested the model where the target is allowed to be any vertex, not just a vertex on the shortest path.
+
$include`raw="html"` plots/test-small-any.html
$include`raw="html"` plots/test-large-any-smalltarget.html
$include`raw="html"` plots/test-large-any-largetarget.html
-<!--
-ANT:
-RE t<7
-The model couuld've decided to generalize for t > 7 too, but it didn't
-so id say
-we tested generalization for all t in 3, 15
-it succeeded for t in 3,7
-failed for t in 8 to 15
-reason is because it could do "if statements" instead of realizing that it should just be looking at the dor product
-mmm
-i feel like the fact that it works for 3,7 and not 8,15 is kind of compelling
-anyways this is just a comment explaining how im going to / think you should pitch the generalization section
-(it couldve generalized all the way, but didnt, so what we learened is generalization sometimes happens but sometimes doesnt
-
-yeah that's the intuition
-can't really do interpretability to figure out if that's actually what's going on
--->
-
+These three tests achieved a loss of $61.44$, $47.06$, and $72.87$ respectively.
## Conclusion