diff options
Diffstat (limited to 'blog.md')
-rw-r--r-- | blog.md | 105 |
1 files changed, 49 insertions, 56 deletions
@@ -4,6 +4,8 @@ mkzip: zip project.zip index.html img title: "Off-Distribution Generalization of Transformers for Shortest Paths" --- +<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 --> + <style> body { /* Shrink margins */ @@ -19,12 +21,6 @@ body { } </style> -<!-- -BTW we should run a spell checker on this at some point before submitting -note to self ---> -<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 --> - <div style="text-align:center"> Anthony Wang, Alek Westover, Kevin Zhao @@ -35,9 +31,10 @@ Anthony Wang, Alek Westover, Kevin Zhao ### Motivation -Neural networks are capable of impressive feats of off-distribution generalization. For instance, a program trained to convert sketches of cats into realistic looking pictures of cats can draw a cat with three eyes, if the sketch has three eyes, even if such an image never occured in its training data. However, neural networks also often learn non-robust features that cause it to perform very poorly off-distribution (e.g., adversarial examples for an image classifier). In this project, we will investigate the question of when transformers generalize off-distribution through via a case study on a simple synthetic task. More specifically, the goal of our project is to make progress towards answering the following question: +Neural networks are capable of impressive feats of off-distribution generalization. For instance, a program trained to convert sketches of cats into realistic looking pictures of cats can draw a cat with three eyes, if the sketch has three eyes, even if such an image never occurred in its training data. However, neural networks also often learn non-robust features that cause it to perform very poorly off-distribution (e.g., adversarial examples for an image classifier). In this project, we will investigate the question of when transformers generalize off-distribution through via a case study on a simple synthetic task. More specifically, the goal of our project is to make progress towards answering the following question: -> Suppose you train a model to be truthful on task A, and to produce coherent answers on a related task B.\ +> Suppose you train a model to be truthful on task A, and to produce coherent answers on a related task B. +> > Will the model learn to be truthful on task B? Here's a more specific version of this question, @@ -70,31 +67,11 @@ In order to investigate generalization of this type in transformers, we decided ![](img/test.svg) </div> -<!-- ![Experiment description](traintunegeneralize.png) --> - Obviously, if you actually want to compute shortest paths in a graph, you should use the standard graph algorithm BFS (breadth-first search). Our motivation for using this problem to study generalization was two-fold: -1. It's very simple to generate data and ground truth labels for this task. -2. There is a simple circuit which computes shortest paths (exactly), by doing something akin to BFS. In fact, we implemented this simple circuit by hand in the section [Handcrafted BFS Transformer](#handcrafted-bfs-transformer). Our hope was that our model can also find a simple circuit that generalized well. - - -<!-- -ANT: -RE t<7 -The model couuld've decided to generalize for t > 7 too, but it didn't -so id say -we tested generalization for all t in 3, 15 -it succeeded for t in 3,7 -failed for t in 8 to 15 -reason is because it could do "if statements" instead of realizing that it should just be looking at the dor product -mmm -i feel like the fact that it works for 3,7 and not 8,15 is kind of compelling -anyways this is just a comment explaining how im going to / think you should pitch the generalization section -(it couldve generalized all the way, but didnt, so what we learened is generalization sometimes happens but sometimes doesnt +1. It's very easy to generate data and ground truth labels for this task. -yeah that's the intuition -can't really do interpretability to figure out if that's actually what's going on ---> +2. There is a simple circuit which computes shortest paths (exactly), by doing something akin to BFS. In fact, we implemented this simple circuit by hand in the section [Handcrafted BFS Transformer](#handcrafted-bfs-transformer). Our hope was that our model can also find a simple circuit that generalized well. ### Findings @@ -118,7 +95,7 @@ The remainder of our post is structured as follows: 1. [Handcrafted BFS Transformer](#handcrafted-bfs-transformer): To start, we explicitly construct by hand weights for a transformer to solve the shortest paths problem. We also investigate the "robustness" of our construction by showing that, starting from our constructed solution it's possible to recover low loss after perturbing the weights via gradient descent. This gives a proof-of-concept that a transformer could encode a simple generalizable solution to the problem, and gives some reason to believe that a transformer could learn such a solution. -2. [Training](#training): Next, we discuss the training. First we describe the exact data generation process and model architecture. Then we discuss hyperparameter settings and optimization choices that were important to get the model to train efficiently. Finally, we investigate some interesting phenomena that occured during training, such as phase transitions in loss for different path lengths. +2. [Training](#training): Next, we discuss the training. First we describe the exact data generation process and model architecture. Then we discuss hyperparameter settings and optimization choices that were important to get the model to train efficiently. Finally, we investigate some interesting phenomena that occurred during training, such as phase transitions in loss for different path lengths. 3. [Generalization](#generalization): Finally we discuss the phenomena that we wanted to study -- generalization! We show that some generalization does occur, but other generalization doesn't. We explain why we think this happens. @@ -187,7 +164,7 @@ TODO: Switch to fancy mpld3 plots We'll represent an $n$ vertex, $m$ edge unweighted, undirected graph as sequence of the endpoints of the $m$ edges, so $[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]$ represents a graph with the edges $\{(a_i,b_i)\}$ where $a_i < b_i$ for $1 \leq i \leq m$. All sequences are padded to the same length using the padding token $0$. -The full input to our model additionally includes the target vertex $t$ after the padding tokens. The label to an input is the length of the shortest path from $1$ to $t$. If no such path exists, we define the length to be $n+1$ which represents infinity. For example, the input $[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0, 3]$ has the label $2$. +The full input to our model additionally includes the target vertex $t$ after the padding tokens. The label to an input is the length of the shortest path from $1$ to $t$. If no such path exists, we define the length to be $n$ which represents infinity. For example, the input $[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0, 3]$ has the label $2$. <div style="text-align:center"> ![](img/finetune.svg) @@ -197,8 +174,6 @@ We have a dataset for each step. 1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges. To balance the dataset, we removed some graphs with shorter paths to obtain the distribution shown below. -TODO: Add axes titles, switch to bar plot - $include`raw="html"` plots/train-dist.html 2. **Fine-tune data**: For each $n \in [3,7]$, we generated an equal number of graphs on $n$ vertices each with a random $t$ on the shortest path from $1$ to $2$. Again, each graph was created by choosing $n$ random edges. @@ -209,7 +184,6 @@ We wrote some Python code to generate the data during the training loop, but Pyt ### Architecture - We used a standard transformer architecture. To ensure that it can in theory learn BFS, we ensured that the number of layers in our transformer exceeds the diameter of the input graphs. Since the order of the edges in the input doesn't matter, we did not use positional encodings. Each edge $(a,b)$ is embedded to dimension $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $a$ and the last $\frac{d}{2}$ elements are the learned embedding of $b$. For the target vertex $t$, we pair it with the special token $TARGET$ and embed $(t,TARGET)$ in the same way. @@ -221,25 +195,25 @@ Since the order of the edges in the input doesn't matter, we did not use positio For our training run, we used the following specifications: -| Hyperparameter | Value | -| ------------------------ | --------------- | -| Model dimension | $64$ | -| Layers | $11$ | -| Heads | $2$ | -| Epochs | $100$ | -| Batch size | $2^{15}$ | -| Dropout | $0$ | -| Training examples | $10^8$ | -| Parameters | $550433$ | -| Learning rate | $5\cdot10^{-4}$ | -| Loss | MSE | -| Optimizer | Adam | +| Hyperparameter | Value | +| ----------------------- | --------------- | +| Model dimension | $64$ | +| Layers | $11$ | +| Heads | $2$ | +| Epochs | $200$ | +| Batch size | $2^{15}$ | +| Dropout | $0$ | +| Total training examples | $2\cdot10^8$ | +| Parameters | $550433$ | +| Learning rate | $5\cdot10^{-4}$ | +| Loss | MSE | +| Optimizer | Adam | The number of bits required to store the model parameters in float32 is around $1.76\cdot10^6$. The number of possible graphs on 15 vertices generated using our procedure is approximately $$\frac{\binom{15}{2}^{15}}{15!} \approx 1.59\cdot10^{18}.$$ This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15 which requires around 4 bits, so memorizing all the answers requires requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$ bits, which is $3.61\cdot10^{12}$ times larger than our model size. This implies that in order to get really low loss, our model needs to do something other than brute memorization. -A single training run takes roughly eight hours to run on a Radeon 7900 XTX graphics card. +A single training run takes roughly two hours to run on a Radeon 7900 XTX graphics card. One pitfall we encountered during training is that we initially used bfloat16 to save VRAM, but our huge batch size caused loss-of-precision problems and made training very difficult. It took us two weeks to debug this until we found that switching to float32 improved training significantly. @@ -249,19 +223,19 @@ Here is our training loss curve: $include`raw="html"` plots/train-loss.html -Here is a "confusion matrix"-like plot, showing for each true label, the probability distribution over classifications that the model gives. +Here is a "confusion matrix"-like plot, showing for each true label, the probability distribution over classifications that the model gives: $include`raw="html"` plots/train-hist.html -One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph above: +One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph below: -$include`raw="html"` plots/grokking.html +$include`raw="html"` plots/train-lens.html Such **phase change** phenomena are ubiquitous in deep learning and for simple algorithmic tasks we can sometimes interpret them [^2]. **TODO ANT: training curves for 1, 2, 3 length paths. and add some text about why this is an interesting finding.** -$include`raw="html"` plots/len-loss.html + ## Generalization @@ -273,14 +247,33 @@ $include`raw="html"` plots/tune-loss.html $include`raw="html"` plots/tune-hist.html -$include`raw="html"` plots/test-onpath-largetarget.html - $include`raw="html"` plots/test-onpath-smalltarget.html +$include`raw="html"` plots/test-onpath-largetarget.html + $include`raw="html"` plots/test-small-any.html -$include`raw="html"` plots/test-large-any.html +$include`raw="html"` plots/test-large-any-smalltarget.html + +$include`raw="html"` plots/test-large-any-largetarget.html + +<!-- +ANT: +RE t<7 +The model couuld've decided to generalize for t > 7 too, but it didn't +so id say +we tested generalization for all t in 3, 15 +it succeeded for t in 3,7 +failed for t in 8 to 15 +reason is because it could do "if statements" instead of realizing that it should just be looking at the dor product +mmm +i feel like the fact that it works for 3,7 and not 8,15 is kind of compelling +anyways this is just a comment explaining how im going to / think you should pitch the generalization section +(it couldve generalized all the way, but didnt, so what we learened is generalization sometimes happens but sometimes doesnt +yeah that's the intuition +can't really do interpretability to figure out if that's actually what's going on +--> ## Conclusion |