diff options
author | SIPB | 2024-12-08 01:58:46 -0500 |
---|---|---|
committer | SIPB | 2024-12-08 01:58:46 -0500 |
commit | 1ea92f93e0dd5a8254fd3b0df5cc84ccdb4e505b (patch) | |
tree | de49b21ef6c3a8985e8ba8e3f73da625df15641c /blog.md | |
parent | 77ad008124f9825ac80c7328923827a2212d600b (diff) |
More work on blog.md, move images to img folder
Diffstat (limited to 'blog.md')
-rw-r--r-- | blog.md | 181 |
1 files changed, 109 insertions, 72 deletions
@@ -1,7 +1,7 @@ --- build: pandoc blog.md --citeproc --katex -s -o index.html mkzip: zip project.zip index.html *.png -title: "6.7960 Project: Investigating Off-Distribution Generalization of Transformers" +title: "Discovering Graph Algorithms Using Transformers" bibliography: blog.bib link-citations: true --- @@ -14,25 +14,15 @@ Anthony Wang, Alek Westover, Kevin Zhao {xy,alekw,kevinmz}\@mit.edu </div> -## Goals +## Motivation -Recently, LLMs have been developing very fast, and with that comes the concern of aligning the models to output true and productive statements. One common approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF), but one drawback to this problem is that humans can be poor judges of truthfulness. As LLMs become more capable, there might not even exist experts that are good judges of whether the model's outputs, such as difficult mathematical proofs, are truthful. So, we'd like to propose a potential solution to this issue via **off-distribution generalization** - applying human-like intuition to solve problems not in the dataset. Paul Christiano [proposed an experiment](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX) about shortest paths in a graph; our project is essentially to implement Christiano's proposed experiment. To the best of our knowledge, although there has been research in applying machine learning for different variations of graph searches [@10.5555/3666122.3666260], no one has done our exact experiment yet. +Transformers--the architecture that powers LLMs--can do incredible feats: trained on hundreds of gigabytes of raw text, they can learn to hold natural conversations, reason about the physical world, and write code. Skeptics argue that LLMs are simply memorizing their datasets without gaining any deeper understanding. For instance, GPT's o1 model, achieving 90th percentile on Codeforces, struggles with simple but bizarre algorithms problems such as "find the subarray of a 2D array with the minimum average". In this project, we hope to explore **when off-distribution generalization happens in a transformer**. Paul Christiano proposed an experiment [here](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX) about shortest paths in a graph to investigate this, so we decided to become the first to implement his experiment and put transformers' generalization abilities to the test. -It is generally desirable for LLMs to output true statements. A current approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF); however, humans can be poor judges of truthfulness. We enjoy many cognitive biases and might employ superficial heuristics when judging truthfulness. A further challenge is that as LLMs develop further, there might not even exist experts that can correctly judge the accuracy and truthfulness of sophisticated outputs such as difficult mathematical proofs. +LLMs are notorious for making up complete nonsense, so we also hope that our project can shed light on when truthfulness generalizes. It's generally desirable for LLMs to output true statements. One current approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF). However, humans can be poor judges of truthfulness and have many cognitive biases and superficial heuristics. A further challenge is that as LLMs become more capable, there might not even exist experts that are good judges of whether the models outputs, such as difficult mathematical proofs, are truthful. For instance, most Task Rabbit workers would probably be hard pressed to evaluate whether a difficult mathematical proof produced by an LLM is true. The entire mathematical community has been known on occasion to [believe false statements for many years](https://en.wikipedia.org/wiki/Grunwald%E2%80%93Wang_theorem). -One approach to solving this problem is to reward an LLM for truthful behavior on simple inputs, and then hoping that the LLM generalizes its truthful behavior for more complex inputs where humans cannot provide helpful labels. Deep learning models often perform remarkable feats of off-distribution generalization -- for instance, a model trained to transform hand drawn cats into images of cats might be able to handle a "cat" with three eyes in an intuitive way. We might hope that generalizing truthfully is simple, thus promoted by "Occam's Razor", and aim to investigate that with this project. +One possible solution is to reward an LLM for truthful behavior on simple inputs, and then hope that the LLM generalizes its truthful behavior for more complex inputs where humans cannot provide helpful labels. Deep learning models can be remarkably good at off-distribution generalization--for instance, a model trained to transform hand drawn cats into images of cats might be able to handle a "cat" with three eyes in an intuitive way. We might hope that generalizing truthfully is simple, thus promoted by "Occam's Razor". -COMMENT FROM KEVIN -- synthesize from intorduction - -### Task - -We will use a synthetic task to test our hypothesis that models will generalize truthfully off-distribution. The synthetic task is computing the distance between various vertices in an input graph. Our experiment will have three parts: - -1. Pre-train a transformer to predict the distance between two fixed vertices $s,t$ on graphs with $n\in [8, 32)$ vertices. -2. Fine-tune a transformer to predict the distances between $s,t'$ for any $t'$ which is on the shortest path from $s$ to $t$, but only do fine-tuning on graphs with $n\in [8,16)$ vertices. -3. Test whether the transformer can accurately predict the distances between $s,t'$ for any $t'$ on the shortest path from $s$ to $t$ for graphs with $n\in [16,32)$ vertices. - -### Related Work +## Related Work COMMENT FROM ALEK -- please remove all mentions of graph neural networks -- that is BS: there is no actual reason why you'd ever use a Neural network to solve shortest paths, the point of choosing a synthetic task is because there is a **simple ground truth** which makes it easy to evaluate whether or not our model is performing correctly. We'd also hoped that the simplicity of the task would make it more feasible to do with a limited compute budget, but apparently this task was too hard for our architecture. @@ -44,82 +34,58 @@ There has been some research into the algorithmic optimization of GNNs and how t - Tutsoy uses a graph-theory-based approach to model the epidemiological characteristics of infectious diseases, such as COVID-19 [@10.1109/TPAMI.2023.3256421]. We understand from his paper how GNN optimization may also be useful in researching novel diseases. -## Methods - -### Algorithm for Shortest Paths - -The standard algorithm to find the shortest path in a graph between a source numbered as $u$ and sink numbered as $v$ is **breadth-first search (BFS)**. The BFS algorithm maintains a mapping of visited vertices to their distances with respect to $u$, and each run of the algorithm goes through all the vertices newly visited in the previous run, and for each vertex, visits any of its unvisited neighbors. The algorithm terminates once either $v$ is visited or the set of newly visited vertices in a single run is empty. - -We will use this algorithm to verify the accuracy of our machine learning approach. Given $V$ vertices and $E$ edges, the runtime of this algorithm is thus $O(V + E)$; however, a machine learning approach may do better in time through parallelism, although at the expense of using much more memory. - -### Data - -We will represent an $n$ vertex, $m$ edge unweighted, undirected graph as sequence of the endpoints of the $m$ edges, so $[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]$ represents a graph with the edges $\{(a_i,b_i)\}$ for $1 \leq i \leq m$. We will pad all sequences to be the same length using the padding token 0. - -The full input to our model will additionally add the target vertex after the padding tokens. The model is tasked with predicting the length of the shortest path between vertex 1 and the target vertex $t$. If no such path exists, we define the length to be $n+1$ which represents infinity. For example, an input-output pair for our model could look like $[1, 3, 3, 2, 0, 0, 0, 0, 2]$ and $2$ respectively. - -We have three separate datasets. - -- **Pre-train data**: For each $n \in [8,32)$, we will generate several graphs on $n$ vertices. We generate these graphs by inserting $2n$ random edges into the graph. We always set the target vertex to be $2$ here. -- **Fine-tune data**: For each $n \in [8,16)$, we will generate several graphs on $n$ vertices. We generate these graphs by inserting $2n$ random edges into the graph. We select the target vertex to be a random vertex on the shortest path from $1$ to $2$. -- **Generalization testing data**: The same as the fine-tune data, except we sample $n \in [16,32)$ instead. - -We wrote some Python code to generate the data during the training loop, but Python is slow and the data generation wasted a lot of time during training. To get around this, we pre-generated the data before training and made our Python code multithreaded to speed it up. - -### Architecture - -TODO: honestly not much to say here since it's a pretty typical arch - -We plan to use a standard transformer architecture. We will ensure that the number of layers in our transformer is at least the diameter of the graph. By doing this, we ensure that there is an extremely simple circuit --- namely BFS --- that the transformer could in theory learn to perform the task. Note that if the transformer actually learns a simple circuit to perform this task, then it seems more likely to generalize well. This is also our intuition for why it should be possible to fine tune on a small amount of data for finding shortest paths to other vertices besides $2$ -- it seems like the model should be computing these other distances as intermediate values in its computation to find the distance to vertex $2$. - -### Embeddings - -Since the order of the edges in the input does not matter, we did not use positional encodings. Each edge $(u,v)$ where $u < v$ is embedded to a dimension of $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $u$ and the last $\frac{d}{2}$ elements are the learned embedding of $v$. For the target vertex $t$, we also embedded to dimension $d$, where the first $\frac{d}{2}$ elements are the learned embedding of $t$ and the last $\frac{d}{2}$ are a learned embedding of a special token. - -## Training - -For our model, we used a model dimension of 64, four layers, and two heads per layer, for a total of 200545 parameters in bfloat16 which corresponds to around 3.2e6 bits. The number of possible graphs on 15 vertices generated using our procedure is approximately -$$\frac{\binom{15}{2}^{15}}{15!} = 1.59\cdot10^{18}.$$ - -This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15 which requires around 4 bits, so memorizing all the answers requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$, which is $2\cdot10^{12}$ times larger than our model size. +## Task +Our synthetic task is simple: compute the distance between various vertices in an input graph. To test off-distribution generalization, our experiment has three steps. -We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a batch size of 131072 for 8000 unique randomly generated batches. Our final MSE loss was approximately 0.3555. +1. **Pre-train** a transformer to predict the distance between vertices $1$ and $2$ in graphs with $n \in [3,15]$ vertices. -![](training-loss.png) +<div style="text-align:center"> +![](img/train.svg) +</div> -![](training-2d-histogram.png) +2. **Fine-tune** a transformer to predict the distances between vertex $1$ to $t$ for any $t$ on the shortest path from $1$ to $2$, but only do fine-tuning on graphs with $n \in [3,7]$ vertices. -One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph above: +<div style="text-align:center"> +![](img/finetune.svg) +</div> -![](grokking.png) +3. **Test** whether the transformer can accurately predict the distances between $1$ to $t$ for any $t \leq 7$ on the shortest path from $1$ to $2$ for graphs with $n \in [3,15]$ vertices. -"grokking" hypothesis: it's memorizing all length 2 paths? +<div style="text-align:center"> +![](img/test.svg) +</div> -TODO: training curves for 1, 2, 3 length paths +### Algorithm for Shortest Paths -### Potential Mathematical Approaches to Shortest Paths? Delete this? +The standard algorithm to find the shortest path in a graph between vertices $u$ and $v$ is **breadth-first search (BFS)**, taught in every intro algorithms class. Initially, BFS starts at $u$, and at each phase, explores a farther layer of vertices from $u$. During a phase, BFS goes through every vertex in the current layer and adds any of their unvisited neighbors to the next layer. The algorithm terminates once we reach $v$ or if the next layer is empty. For a graph with $V$ vertices and $E$ edges, the runtime of BFS is $O(V + E)$. BFS gives us an easy and fast way to find the ground truth answer for any graph, so that we can verify the accuracy of our machine learning approach. -Another way one can think of the shortest path of a graph is using a *matrix* to record which vertices are connected. Given vertices numbered $1$ to $V$, we denote the **adjacency matrix** $\textbf{M}$ of dimensions $V \times V$ as the matrix with element $\textbf{M}_{i, j} = 1$ if vertices $i$ and $j$ are connected by an edge and $\textbf{M}_{i, j} = 0$ if they are not. Now, we note that (1) For all $k$, $(\textbf{M}+I)^k_{i, j} = 0$ if and only if there exists no path from the vertex numbered $i$ to the vertex numbered $j$ that is distance $k$ or less due to Markov matrix processes. As a result, if the distance between vertices numbered $i$ and $j$ is $d$, then $\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = 1$ if $k \ge d$ and $\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = 0$ if $k < d$. +We hope that our model can learn BFS or some other simple, generalizable algorithm for shortest paths, because the model can't just pull some magic number out of a hat but intuitively needs to count or list the vertices on the shortest path from $1$ to $2$. In fact, we will show how to hand-craft a set of weights to implement BFS in a transformer, so it's indeed theoretically possible for a transformer to achieve 100% accuracy. -With this information, because the distance between any two vertices is at most $V-1$ in a graph with $V$ vertices, we note that the *distance* matrix turns out to be simply $$\textbf{D} = \textbf{1}_{V \times V} \cdot V - \Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right).$$ The runtime to compute this is $O(V)$, although it will take more space to compute all powers of $\textbf{M}$. +### Data -## Fine tuning results +We'll represent an $n$ vertex, $m$ edge unweighted, undirected graph as sequence of the endpoints of the $m$ edges, so $[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]$ represents a graph with the edges $\{(a_i,b_i)\}$ where $a_i < b_i$ for $1 \leq i \leq m$. All sequences are padded to the same length using the padding token $0$. -After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown in the images below. +The full input to our model additionally includes the target vertex $t$ after the padding tokens. The label to an input is the length of the shortest path from $1$ to $t$. If no such path exists, we define the length to be $n+1$ which represents infinity. For example, the input $[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0, 3]$ has the label $2$. -![](fine-tuning-loss.png) +<div style="text-align:center"> +![](img/finetune.svg) +</div> -![](fine-tuning-2d-histogram.png) +We have three datasets for each step. -![](test-2d-histogram.png) +1. **Pre-train data**: For each $n \in [3,15]$, we generated an equal number of graphs on $n$ vertices, with $t = 2$. Each graph was created by choosing $n$ random edges. +2. **Fine-tune data**: For each $n \in [3,7]$, we generated an equal number of graphs on $n$ vertices each with a random $t$ on the shortest path from $1$ to $2$. Again, each graph was created by choosing $n$ random edges. +3. **Generalization test data**: The same as the fine-tune data, except we sample $n \in [3,15]$ and $t \leq 7$. -Memorization? Do some math here to compute how many bits required to memorize 1, 2, 3 +We wrote some Python code to generate the data during the training loop, but Python is excruciatingly slow and data generation wasted a lot of training time. Our solution was to pre-generate the data before training using a multithreaded version of our Python code. ## Complicated explicit transformer formula for shortest paths +TODO: Kevin or Alek + ```py # Configuration NVTXS = 16 @@ -221,9 +187,7 @@ class SillyTransformer(nn.Module): return final_output ``` -It looked like the fine tuning results weren't as meaningful because TODO: ALEK -## Customizing a Transformer After much deliberation, we decided the next step for us was to customize a transformer, writing it ourselves. We observed that we wished for the transformer to do similar work as a BFS. As a result, we decided to work with the following transformer, for a graph with $n$ vertices $v_1, v_2, \cdots, v_n$: @@ -241,8 +205,81 @@ Specifically, we see that $\text{NBR}_{i}$ is a $n \times 1$ vector detailing wh This iterates through all the vertices, and at the end, we may see what run we are on to update $V_{\text{OUT}}$. +## Perturbing the Weights + +SLT folks like to think about geometry of loss landscape CITE +So we did an experiment where we mess with the weights. + +Findings: XXX + + +## Our Model + +### Architecture + +We used a standard transformer architecture. To ensure that it can in theory learn BFS, we ensured that the number of layers in our transformer exceeds the diameter of the input graphs. + +Since the order of the edges in the input doesn't matter, we did not use positional encodings. Each edge $(a,b)$ is embedded to dimension $d$ where the first $\frac{d}{2}$ elements are the learned embedding of $a$ and the last $\frac{d}{2}$ elements are the learned embedding of $b$. For the target vertex $t$, we pair it with the special token $TARGET$ and embed $(t,TARGET)$ in the same way. + +<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm --> +![](img/embeddings.svg) + +### Training + +To match the BFS transformer as closely as possible, we used a model dimension of $64$, $11$ layers, and $2$ heads per layer, for a total of 550433 parameters. In 32-bit float precision, that corresponds to around $1.76\cdot10^6$ bits. The number of possible graphs on 15 vertices generated using our procedure is approximately + +$$\frac{\binom{15}{2}^{15}}{15!} = 1.59\cdot10^{18}.$$ + +This is because there are $\binom{15}{2}$ choices for each of the 15 edges and we don't care about the order of the edges. This is only an approximation because some edges might be duplicated. Each graph has an answer between 1 and 15 which requires around 4 bits, so memorizing all the answers requires $4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}$ bits, which is $3.61\cdot10^{12}$ times larger than our model size. + +To train the model, we used MSE loss, the Adam optimizer, a learning rate of $3\cdot10^{-4}$, and a batch size of $2^{15}$ for one billion randomly generated graphs. A training run takes roughly eight hours to run on a Radeon 7900 XTX graphics card. Our final MSE loss was $0.000555$. + +TODO: use https://mpld3.github.io/index.html to make interactive plots + +![](training-loss.png) + +![](training-2d-histogram.png) + +One pattern we consistently noticed during training is that the model often gets stuck and plateaus for many epochs before rapidly decreasing. For instance, this happened between epochs 100 and 300 in the graph above: + +![](grokking.png) + +"grokking" hypothesis: it's memorizing all length 2 paths? + +TODO: cite Neel Nanda grokking modular addition + +TODO: CRAZY!!! training curves for 1, 2, 3 length paths + +One pitfall we encountered during training is that we initially used bfloat16 to save VRAM, but our huge batch size caused loss-of-precision problems and made training very difficult. It took us two weeks to debug this until we found that switching to float32 improved training significantly. + +## Fine tuning results + +After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown in the images below. + +![](fine-tuning-loss.png) + +![](fine-tuning-2d-histogram.png) + +![](test-2d-histogram.png) + +TODO: get new graphs + +It's pretty good!!! + +Can only generalize to target vertices from 2 to 7 since 8 through 15 didn't appear in the fine-tune data + +but this still means it + ## Conclusion +however, a machine learning approach may do better in time through parallelism, although at the expense of using much more memory. +**TODO: ALEK: this is BS. If you want a parallel algorithm for BFS, here's one https://en.wikipedia.org/wiki/Parallel_single-source_shortest_path_algorithm** + just do bfs lol +**Future Work** +There are a couple of other things that we could try to learn shortest paths better and maybe see more generalization. +- Chain of thought +- Train model to output a path, not just the distance. Give it partial points for outputting anything that is a legitimate path (i.e., consists of real edges) and more points for getting the distance correct. + ## References |