Diffstat (limited to 'blog.md')
 blog.md | 37 +++++++++++++++++++++++++------------
 1 file changed, 25 insertions(+), 12 deletions(-)
@@ -1,6 +1,6 @@
 ---
 build: pandoc blog.md --katex -s --filter pandoc-include -o index.html
-mkzip: zip project.zip index.html img
+mkzip: zip project.zip index.html img -r
 title: "Off-Distribution Generalization of Transformers for Shortest Paths"
 ---

@@ -36,14 +36,14 @@ Anthony Wang, Alek Westover, Kevin Zhao

 ### Motivation

-Neural networks are capable of impressive feats of off-distribution generalization. For instance, a program trained to convert sketches of cats into realistic looking pictures of cats can draw a cat with three eyes, if the sketch has three eyes, even if such an image never occurred in its training data. However, neural networks also often learn non-robust features that cause them to perform poorly off-distribution (e.g., adversarial examples for an image classifier). In this project, we will investigate the question of when transformers generalize off-distribution via a case study on a simple synthetic task. More specifically, the goal of our project is to make progress towards answering the following question:
+Neural networks are capable of impressive feats of off-distribution generalization. For instance, we discussed in class a program trained to convert sketches of cats into realistic-looking pictures of cats that was able to draw a cat with three eyes when given a sketch with three eyes, even though there were no three-eyed cats in the training data. However, neural networks also often learn non-robust features that cause them to perform poorly off-distribution (e.g., adversarial examples for an image classifier). In this project, we will investigate the question of when transformers generalize off-distribution via a case study on a simple synthetic task. More specifically, the goal of our project is to make progress towards answering the following question:

 > Suppose you train a model to be truthful on task A, and to produce coherent answers on a related task B.
 >
 > Will the model learn to be truthful on task B?

-We stumbled upon this question while reading a [blog post written by Paul Chrisiano](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX)
-(who was influential in developing RLHF). His blog post inspired us to explore off-distribution generalization. A more specific version of our question is below:
+We first heard about this general question from a [blog post written by Paul Christiano](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX)
+(who was influential in developing RLHF), though of course many people discuss off-distribution generalization. Christiano gave an example of a toy setting in which one could investigate transformer generalization capabilities:

 > Suppose we trained a language model on next token prediction in both English and French, and then fine-tuned the model to be truthful in English. Would the model then learn to answer questions truthfully in French? It has been shown [^1] that this doesn't happen by default, but you could hope that the truthfulness training in English would generalize to French.

@@ -247,21 +247,32 @@ Such **phase change** phenomena are ubiquitous in deep learning and for simple a

 ### Fine Tuning Results

-After receiving our initial results, we fine-tuned for $100$ epochs with the same hyperparameters other than a learning rate and weight decay of $10^{-5}$. We used weight decay to incentivize the model to find simpler or more general representations for the fine-tuning task. We achieved a final loss of $0.01472$. Here are our fine-tuning loss curves and 2D histogram.
+Directly using the "pre-trained" model that is good at computing the distance between vertex $1$ and vertex $2$ to compute the distance to other vertices does very poorly: we get an MSE loss of $0.5531$, and the following confusion matrix.

-$include`raw="html"` plots/tune-loss.html
+$include`raw="html"` plots/untune-hist.html

-$include`raw="html"` plots/tune-hist.html
+Note that, while this confusion matrix might initially look okay, it is actually extremely bad because the model is basically just guessing integers to get moderately low loss.

-If we run the un-fine-tuned model on the evaluations, we get the following histogram and an abysmal loss of $0.5531$.
+However, we'll now show that with a very small amount of fine-tuning, the model can perform very well at computing the shortest path to vertices that lie on the shortest path between vertex $1$ and vertex $2$ --- as we hypothesized!
+Specifically, we fine-tuned for $100$ epochs (roughly $100$ times fewer epochs than we used for pre-training) with the same hyperparameters other than a learning rate and weight decay of $10^{-5}$, on the task of predicting the distance from vertex $1$ to a target vertex other than vertex $2$ that lies on the shortest path between vertex $1$ and vertex $2$. We used weight decay to incentivize the model to find simpler or more general representations for the fine-tuning task. We achieved a final loss of $0.01472$. Here are our fine-tuning loss curves and 2D histogram.

-$include`raw="html"` plots/untune-hist.html
+$include`raw="html"` plots/tune-loss.html

-The plot at first looks great because the correct answers are all yellow, but that's because the model is simply ignoring the target vertex and almost always outputting an integer. It often gets the answer correct by sheer luck, but when it doesn't, the difference is instead an integer which makes the loss very high.
+$include`raw="html"` plots/tune-hist.html
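+
+To make the fine-tuning procedure concrete, here is a minimal sketch of the training loop described above. This is an illustrative sketch rather than our actual training code: `model` and `sample_onpath_batch` are hypothetical placeholders, and the choice of AdamW as the optimizer is an assumption; the essential details are the small learning rate and weight decay of $10^{-5}$ and the MSE loss.
+
+```python
+import torch
+
+def fine_tune(model, sample_onpath_batch, epochs=100, lr=1e-5, wd=1e-5):
+    """Fine-tune the pre-trained model on the on-path distance task."""
+    # AdamW is shown as one plausible optimizer choice (an assumption)
+    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+    loss_fn = torch.nn.MSELoss()  # distances are regressed, so we use MSE
+    for _ in range(epochs):
+        # one representative batch per epoch shown for brevity; each batch
+        # asks for the distance from vertex 1 to a target vertex (other than
+        # vertex 2) lying on the shortest path between vertices 1 and 2
+        tokens, true_dist = sample_onpath_batch()
+        pred_dist = model(tokens)
+        loss = loss_fn(pred_dist, true_dist)
+        opt.zero_grad()
+        loss.backward()
+        opt.step()
+    return model
+```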

 ### Test Results

-Lastly, we evaluated our fine-tuned model on several kinds of input to see when it generalizes.
+Lastly, we ran our fine-tuned model on some tasks that were not part of its training regimen, and investigated in which settings it was able to generalize.
+
+Here is a description of the tasks that we ran our fine-tuned model on, and a summary of the results.
+
+| Name | Description | Generalization Behavior |
+|----------|----------|----------|
+| On-path Small Target | Compute distance to a target vertex $t \in [3,7]$ on the shortest path between vertices $1$ and $2$ in graphs of size $[3,15]$ | Good generalization |
+| On-path Target | Compute distance to a target vertex $t \in [3,15]$ on the shortest path between vertices $1$ and $2$ in graphs of size $[3,15]$ | Moderate generalization |
+| Any Target in Small Graphs | Compute distance to a target vertex $t \in [3,7]$ in graphs of size $[3,7]$ | High loss |
+| Any Small Target in Large Graphs | Compute distance to a target vertex $t \in [3,7]$ in graphs of size $[3,15]$ | High loss |
+| Any Target in Large Graphs | Compute distance to a target vertex $t \in [3,15]$ in graphs of size $[3,15]$ | High loss |

 $include`raw="html"` plots/test-onpath-smalltarget.html
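+
+To make the evaluation setup concrete, here is a rough sketch of how instances for tasks like those in the table above could be sampled. This is an illustration rather than our actual data-generation code: the use of `networkx`, the Erdős-Rényi edge probability, and the function name `sample_instance` are all assumptions.
+
+```python
+import random
+
+import networkx as nx
+
+def sample_instance(n_range, target_range, on_path_only):
+    """Sample a graph, a target vertex t, and the true distance from 1 to t."""
+    while True:
+        n = random.randint(*n_range)        # graph size, inclusive range
+        G = nx.erdos_renyi_graph(n, p=0.4)  # edge probability is an assumption
+        G = nx.relabel_nodes(G, {i: i + 1 for i in range(n)})  # 1-indexed vertices
+        if not nx.has_path(G, 1, 2):
+            continue  # resample until vertices 1 and 2 are connected
+        on_path = set(nx.shortest_path(G, 1, 2))
+        lo, hi = target_range
+        candidates = [v for v in G.nodes
+                      if v not in (1, 2) and lo <= v <= hi
+                      and (not on_path_only or v in on_path)
+                      and nx.has_path(G, 1, v)]
+        if candidates:
+            t = random.choice(candidates)
+            return G, t, nx.shortest_path_length(G, 1, t)
+
+# For example, "On-path Small Target" corresponds to:
+# G, t, d = sample_instance(n_range=(3, 15), target_range=(3, 7), on_path_only=True)
+```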
@@ -283,7 +294,9 @@ These three tests achieved a loss of $61.44$, $47.06$, and $72.87$ respectively.

 ## Conclusion

-In this post, we've investigated solving the shortest paths problem with a transformer and when off-distribution generalization occurs. We explicitly constructed a transformer for BFS to show that a simple solution does indeed exist. Through experimentation, we also trained a transformer to solve shortest paths, although its internal mechanisms may be different than BFS. When the transformer learns to solve this problem, it also implicitly builds good internal representations of the input data that allow for efficient fine tuning to find shortest paths between other vertices. Finally, we've seen that some off-distribution generalization does occur, and some doesn't. The main conceptual take-away from our work is that it's hard to predict when models will and won't generalize.
+In this post, we've investigated the off-distribution generalization behavior of transformers trained on the shortest paths problem.
+We demonstrated mathematically the existence of a transformer that computes shortest paths, and also found such a transformer from scratch via gradient descent.
+We showed that a transformer trained to compute shortest paths between two specific vertices $v_1, v_2$ can be efficiently fine-tuned to compute shortest paths to other vertices that lie on the shortest $v_1$-$v_2$ path, suggesting that our transformer's learned representations implicitly carry rich information about the graph. Finally, we showed that the transformer was able to generalize off-distribution quite well in some settings, but less well in others. The main conceptual takeaway from our work is that it's hard to predict when models will and won't generalize.

 ## Appendix