aboutsummaryrefslogtreecommitdiff
path: root/index.html
diff options
context:
space:
mode:
Diffstat (limited to 'index.html')
-rw-r--r--index.html447
1 files changed, 222 insertions, 225 deletions
diff --git a/index.html b/index.html
index efec318..fd5f74f 100644
--- a/index.html
+++ b/index.html
@@ -4,7 +4,7 @@
<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
- <title>6.7960 Project: Investigating Off-Distribution Generalization of Transformers</title>
+ <title>Discovering Graph Algorithms Using Transformers</title>
<style>
html {
color: #1a1a1a;
@@ -271,74 +271,51 @@
</head>
<body>
<header id="title-block-header">
-<h1 class="title">6.7960 Project: Investigating Off-Distribution
-Generalization of Transformers</h1>
+<h1 class="title">Discovering Graph Algorithms Using Transformers</h1>
</header>
<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 -->
<div style="text-align:center">
<p>Anthony Wang, Alek Westover, Kevin Zhao</p>
<p>{xy,alekw,kevinmz}@mit.edu</p>
</div>
-<h2 id="goals">Goals</h2>
-<p>Recently, LLMs have been developing very fast, and with that comes
-the concern of aligning the models to output true and productive
-statements. One common approach for ensuring this is to have a human in
-the loop rewarding the model for true outputs (e.g. RLHF), but one
-drawback to this problem is that humans can be poor judges of
-truthfulness. As LLMs become more capable, there might not even exist
-experts that are good judges of whether the model’s outputs, such as
-difficult mathematical proofs, are truthful. So, we’d like to propose a
-potential solution to this issue via <strong>off-distribution
-generalization</strong> - applying human-like intuition to solve
-problems not in the dataset. Paul Christiano <a
-href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">proposed
-an experiment</a> about shortest paths in a graph; our project is
-essentially to implement Christiano’s proposed experiment. To the best
-of our knowledge, although there has been research in applying machine
-learning for different variations of graph searches <span
-class="citation" data-cites="10.5555/3666122.3666260">(<a
-href="#ref-10.5555/3666122.3666260" role="doc-biblioref">Zang et al.
-2024</a>)</span>, no one has done our exact experiment yet.</p>
-<p>It is generally desirable for LLMs to output true statements. A
-current approach for ensuring this is to have a human in the loop
-rewarding the model for true outputs (e.g. RLHF); however, humans can be
-poor judges of truthfulness. We enjoy many cognitive biases and might
-employ superficial heuristics when judging truthfulness. A further
-challenge is that as LLMs develop further, there might not even exist
-experts that can correctly judge the accuracy and truthfulness of
-sophisticated outputs such as difficult mathematical proofs.</p>
-<p>One approach to solving this problem is to reward an LLM for truthful
-behavior on simple inputs, and then hoping that the LLM generalizes its
-truthful behavior for more complex inputs where humans cannot provide
-helpful labels. Deep learning models often perform remarkable feats of
-off-distribution generalization – for instance, a model trained to
-transform hand drawn cats into images of cats might be able to handle a
-“cat” with three eyes in an intuitive way. We might hope that
-generalizing truthfully is simple, thus promoted by “Occam’s Razor”, and
-aim to investigate that with this project.</p>
-<p>COMMENT FROM KEVIN – synthesize from intorduction</p>
-<h3 id="task">Task</h3>
-<p>We will use a synthetic task to test our hypothesis that models will
-generalize truthfully off-distribution. The synthetic task is computing
-the distance between various vertices in an input graph. Our experiment
-will have three parts:</p>
-<ol type="1">
-<li>Pre-train a transformer to predict the distance between two fixed
-vertices <span class="math inline">s,t</span> on graphs with <span
-class="math inline">n\in [8, 32)</span> vertices.</li>
-<li>Fine-tune a transformer to predict the distances between <span
-class="math inline">s,t&#39;</span> for any <span
-class="math inline">t&#39;</span> which is on the shortest path from
-<span class="math inline">s</span> to <span
-class="math inline">t</span>, but only do fine-tuning on graphs with
-<span class="math inline">n\in [8,16)</span> vertices.</li>
-<li>Test whether the transformer can accurately predict the distances
-between <span class="math inline">s,t&#39;</span> for any <span
-class="math inline">t&#39;</span> on the shortest path from <span
-class="math inline">s</span> to <span class="math inline">t</span> for
-graphs with <span class="math inline">n\in [16,32)</span> vertices.</li>
-</ol>
-<h3 id="related-work">Related Work</h3>
+<h2 id="motivation">Motivation</h2>
+<p>Transformers–the architecture that powers LLMs–can do incredible
+feats: trained on hundreds of gigabytes of raw text, they can learn to
+hold natural conversations, reason about the physical world, and write
+code. Skeptics argue that LLMs are simply memorizing their datasets
+without gaining any deeper understanding. For instance, GPT’s o1 model,
+achieving 90th percentile on Codeforces, struggles with simple but
+bizarre algorithms problems such as “find the subarray of a 2D array
+with the minimum average”. In this project, we hope to explore
+<strong>when off-distribution generalization happens in a
+transformer</strong>. Paul Christiano proposed an experiment <a
+href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">here</a>
+about shortest paths in a graph to investigate this, so we decided to
+become the first to implement his experiment and put transformers’
+generalization abilities to the test.</p>
+<p>LLMs are notorious for making up complete nonsense, so we also hope
+that our project can shed light on when truthfulness generalizes. It’s
+generally desirable for LLMs to output true statements. One current
+approach for ensuring this is to have a human in the loop rewarding the
+model for true outputs (e.g. RLHF). However, humans can be poor judges
+of truthfulness and have many cognitive biases and superficial
+heuristics. A further challenge is that as LLMs become more capable,
+there might not even exist experts that are good judges of whether the
+models outputs, such as difficult mathematical proofs, are truthful. For
+instance, most Task Rabbit workers would probably be hard pressed to
+evaluate whether a difficult mathematical proof produced by an LLM is
+true. The entire mathematical community has been known on occasion to <a
+href="https://en.wikipedia.org/wiki/Grunwald%E2%80%93Wang_theorem">believe
+false statements for many years</a>.</p>
+<p>One possible solution is to reward an LLM for truthful behavior on
+simple inputs, and then hope that the LLM generalizes its truthful
+behavior for more complex inputs where humans cannot provide helpful
+labels. Deep learning models can be remarkably good at off-distribution
+generalization–for instance, a model trained to transform hand drawn
+cats into images of cats might be able to handle a “cat” with three eyes
+in an intuitive way. We might hope that generalizing truthfully is
+simple, thus promoted by “Occam’s Razor”.</p>
+<h2 id="related-work">Related Work</h2>
<p>COMMENT FROM ALEK – please remove all mentions of graph neural
networks – that is BS: there is no actual reason why you’d ever use a
Neural network to solve shortest paths, the point of choosing a
@@ -367,170 +344,113 @@ href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy
2023</a>)</span>. We understand from his paper how GNN optimization may
also be useful in researching novel diseases.</p></li>
</ul>
-<h2 id="methods">Methods</h2>
+<h2 id="task">Task</h2>
+<p>Our synthetic task is simple: compute the distance between various
+vertices in an input graph. To test off-distribution generalization, our
+experiment has three steps.</p>
+<ol type="1">
+<li><strong>Pre-train</strong> a transformer to predict the distance
+between vertices <span class="math inline">1</span> and <span
+class="math inline">2</span> in graphs with <span class="math inline">n
+\in [3,15]</span> vertices.</li>
+</ol>
+<div style="text-align:center">
+<p><img src="img/train.svg" /></p>
+</div>
+<ol start="2" type="1">
+<li><strong>Fine-tune</strong> a transformer to predict the distances
+between vertex <span class="math inline">1</span> to <span
+class="math inline">t</span> for any <span class="math inline">t</span>
+on the shortest path from <span class="math inline">1</span> to <span
+class="math inline">2</span>, but only do fine-tuning on graphs with
+<span class="math inline">n \in [3,7]</span> vertices.</li>
+</ol>
+<div style="text-align:center">
+<p><img src="img/finetune.svg" /></p>
+</div>
+<ol start="3" type="1">
+<li><strong>Test</strong> whether the transformer can accurately predict
+the distances between <span class="math inline">1</span> to <span
+class="math inline">t</span> for any <span class="math inline">t \leq
+7</span> on the shortest path from <span class="math inline">1</span> to
+<span class="math inline">2</span> for graphs with <span
+class="math inline">n \in [3,15]</span> vertices.</li>
+</ol>
+<div style="text-align:center">
+<p><img src="img/test.svg" /></p>
+</div>
<h3 id="algorithm-for-shortest-paths">Algorithm for Shortest Paths</h3>
-<p>The standard algorithm to find the shortest path in a graph between a
-source numbered as <span class="math inline">u</span> and sink numbered
-as <span class="math inline">v</span> is <strong>breadth-first search
-(BFS)</strong>. The BFS algorithm maintains a mapping of visited
-vertices to their distances with respect to <span
-class="math inline">u</span>, and each run of the algorithm goes through
-all the vertices newly visited in the previous run, and for each vertex,
-visits any of its unvisited neighbors. The algorithm terminates once
-either <span class="math inline">v</span> is visited or the set of newly
-visited vertices in a single run is empty.</p>
-<p>We will use this algorithm to verify the accuracy of our machine
-learning approach. Given <span class="math inline">V</span> vertices and
-<span class="math inline">E</span> edges, the runtime of this algorithm
-is thus <span class="math inline">O(V + E)</span>; however, a machine
-learning approach may do better in time through parallelism, although at
-the expense of using much more memory.</p>
+<p>The standard algorithm to find the shortest path in a graph between
+vertices <span class="math inline">u</span> and <span
+class="math inline">v</span> is <strong>breadth-first search
+(BFS)</strong>, taught in every intro algorithms class. Initially, BFS
+starts at <span class="math inline">u</span>, and at each phase,
+explores a farther layer of vertices from <span
+class="math inline">u</span>. During a phase, BFS goes through every
+vertex in the current layer and adds any of their unvisited neighbors to
+the next layer. The algorithm terminates once we reach <span
+class="math inline">v</span> or if the next layer is empty. For a graph
+with <span class="math inline">V</span> vertices and <span
+class="math inline">E</span> edges, the runtime of BFS is <span
+class="math inline">O(V + E)</span>. BFS gives us an easy and fast way
+to find the ground truth answer for any graph, so that we can verify the
+accuracy of our machine learning approach.</p>
+<p>We hope that our model can learn BFS or some other simple,
+generalizable algorithm for shortest paths, because the model can’t just
+pull some magic number out of a hat but intuitively needs to count or
+list the vertices on the shortest path from <span
+class="math inline">1</span> to <span class="math inline">2</span>. In
+fact, we will show how to hand-craft a set of weights to implement BFS
+in a transformer, so it’s indeed theoretically possible for a
+transformer to achieve 100% accuracy.</p>
<h3 id="data">Data</h3>
-<p>We will represent an <span class="math inline">n</span> vertex, <span
+<p>We’ll represent an <span class="math inline">n</span> vertex, <span
class="math inline">m</span> edge unweighted, undirected graph as
sequence of the endpoints of the <span class="math inline">m</span>
edges, so <span
class="math inline">[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]</span> represents a
-graph with the edges <span class="math inline">\{(a_i,b_i)\}</span> for
-<span class="math inline">1 \leq i \leq m</span>. We will pad all
-sequences to be the same length using the padding token 0.</p>
-<p>The full input to our model will additionally add the target vertex
-after the padding tokens. The model is tasked with predicting the length
-of the shortest path between vertex 1 and the target vertex <span
-class="math inline">t</span>. If no such path exists, we define the
-length to be <span class="math inline">n+1</span> which represents
-infinity. For example, an input-output pair for our model could look
-like <span class="math inline">[1, 3, 3, 2, 0, 0, 0, 0, 2]</span> and
-<span class="math inline">2</span> respectively.</p>
-<p>We have three separate datasets.</p>
-<ul>
+graph with the edges <span class="math inline">\{(a_i,b_i)\}</span>
+where <span class="math inline">a_i &lt; b_i</span> for <span
+class="math inline">1 \leq i \leq m</span>. All sequences are padded to
+the same length using the padding token <span
+class="math inline">0</span>.</p>
+<p>The full input to our model additionally includes the target vertex
+<span class="math inline">t</span> after the padding tokens. The label
+to an input is the length of the shortest path from <span
+class="math inline">1</span> to <span class="math inline">t</span>. If
+no such path exists, we define the length to be <span
+class="math inline">n+1</span> which represents infinity. For example,
+the input <span class="math inline">[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0,
+3]</span> has the label <span class="math inline">2</span>.</p>
+<div style="text-align:center">
+<p><img src="img/finetune.svg" /></p>
+</div>
+<p>We have three datasets for each step.</p>
+<ol type="1">
<li><strong>Pre-train data</strong>: For each <span
-class="math inline">n \in [8,32)</span>, we will generate several graphs
-on <span class="math inline">n</span> vertices. We generate these graphs
-by inserting <span class="math inline">2n</span> random edges into the
-graph. We always set the target vertex to be <span
-class="math inline">2</span> here.</li>
+class="math inline">n \in [3,15]</span>, we generated an equal number of
+graphs on <span class="math inline">n</span> vertices, with <span
+class="math inline">t = 2</span>. Each graph was created by choosing
+<span class="math inline">n</span> random edges.</li>
<li><strong>Fine-tune data</strong>: For each <span
-class="math inline">n \in [8,16)</span>, we will generate several graphs
-on <span class="math inline">n</span> vertices. We generate these graphs
-by inserting <span class="math inline">2n</span> random edges into the
-graph. We select the target vertex to be a random vertex on the shortest
-path from <span class="math inline">1</span> to <span
-class="math inline">2</span>.</li>
-<li><strong>Generalization testing data</strong>: The same as the
-fine-tune data, except we sample <span class="math inline">n \in
-[16,32)</span> instead.</li>
-</ul>
+class="math inline">n \in [3,7]</span>, we generated an equal number of
+graphs on <span class="math inline">n</span> vertices each with a random
+<span class="math inline">t</span> on the shortest path from <span
+class="math inline">1</span> to <span class="math inline">2</span>.
+Again, each graph was created by choosing <span
+class="math inline">n</span> random edges.</li>
+<li><strong>Generalization test data</strong>: The same as the fine-tune
+data, except we sample <span class="math inline">n \in [3,15]</span> and
+<span class="math inline">t \leq 7</span>.</li>
+</ol>
<p>We wrote some Python code to generate the data during the training
-loop, but Python is slow and the data generation wasted a lot of time
-during training. To get around this, we pre-generated the data before
-training and made our Python code multithreaded to speed it up.</p>
-<h3 id="architecture">Architecture</h3>
-<p>TODO: honestly not much to say here since it’s a pretty typical
-arch</p>
-<p>We plan to use a standard transformer architecture. We will ensure
-that the number of layers in our transformer is at least the diameter of
-the graph. By doing this, we ensure that there is an extremely simple
-circuit — namely BFS — that the transformer could in theory learn to
-perform the task. Note that if the transformer actually learns a simple
-circuit to perform this task, then it seems more likely to generalize
-well. This is also our intuition for why it should be possible to fine
-tune on a small amount of data for finding shortest paths to other
-vertices besides <span class="math inline">2</span> – it seems like the
-model should be computing these other distances as intermediate values
-in its computation to find the distance to vertex <span
-class="math inline">2</span>.</p>
-<h3 id="embeddings">Embeddings</h3>
-<p>Since the order of the edges in the input does not matter, we did not
-use positional encodings. Each edge <span
-class="math inline">(u,v)</span> where <span class="math inline">u &lt;
-v</span> is embedded to a dimension of <span
-class="math inline">d</span> where the first <span
-class="math inline">\frac{d}{2}</span> elements are the learned
-embedding of <span class="math inline">u</span> and the last <span
-class="math inline">\frac{d}{2}</span> elements are the learned
-embedding of <span class="math inline">v</span>. For the target vertex
-<span class="math inline">t</span>, we also embedded to dimension <span
-class="math inline">d</span>, where the first <span
-class="math inline">\frac{d}{2}</span> elements are the learned
-embedding of <span class="math inline">t</span> and the last <span
-class="math inline">\frac{d}{2}</span> are a learned embedding of a
-special token.</p>
-<h2 id="training">Training</h2>
-<p>For our model, we used a model dimension of 64, four layers, and two
-heads per layer, for a total of 200545 parameters in bfloat16 which
-corresponds to around 3.2e6 bits. The number of possible graphs on 15
-vertices generated using our procedure is approximately</p>
-<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} =
-1.59\cdot10^{18}.</span></p>
-<p>This is because there are <span
-class="math inline">\binom{15}{2}</span> choices for each of the 15
-edges and we don’t care about the order of the edges. This is only an
-approximation because some edges might be duplicated. Each graph has an
-answer between 1 and 15 which requires around 4 bits, so memorizing all
-the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} =
-6.36\cdot10^{18}</span>, which is <span
-class="math inline">2\cdot10^{12}</span> times larger than our model
-size.</p>
-<p>We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a
-batch size of 131072 for 8000 unique randomly generated batches. Our
-final MSE loss was approximately 0.3555.</p>
-<p><img src="training-loss.png" /></p>
-<p><img src="training-2d-histogram.png" /></p>
-<p>One pattern we consistently noticed during training is that the model
-often gets stuck and plateaus for many epochs before rapidly decreasing.
-For instance, this happened between epochs 100 and 300 in the graph
-above:</p>
-<p><img src="grokking.png" /></p>
-<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p>
-<p>TODO: training curves for 1, 2, 3 length paths</p>
-<h3
-id="potential-mathematical-approaches-to-shortest-paths-delete-this">Potential
-Mathematical Approaches to Shortest Paths? Delete this?</h3>
-<p>Another way one can think of the shortest path of a graph is using a
-<em>matrix</em> to record which vertices are connected. Given vertices
-numbered <span class="math inline">1</span> to <span
-class="math inline">V</span>, we denote the <strong>adjacency
-matrix</strong> <span class="math inline">\textbf{M}</span> of
-dimensions <span class="math inline">V \times V</span> as the matrix
-with element <span class="math inline">\textbf{M}_{i, j} = 1</span> if
-vertices <span class="math inline">i</span> and <span
-class="math inline">j</span> are connected by an edge and <span
-class="math inline">\textbf{M}_{i, j} = 0</span> if they are not. Now,
-we note that (1) For all <span class="math inline">k</span>, <span
-class="math inline">(\textbf{M}+I)^k_{i, j} = 0</span> if and only if
-there exists no path from the vertex numbered <span
-class="math inline">i</span> to the vertex numbered <span
-class="math inline">j</span> that is distance <span
-class="math inline">k</span> or less due to Markov matrix processes. As
-a result, if the distance between vertices numbered <span
-class="math inline">i</span> and <span class="math inline">j</span> is
-<span class="math inline">d</span>, then <span
-class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) =
-1</span> if <span class="math inline">k \ge d</span> and <span
-class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) =
-0</span> if <span class="math inline">k &lt; d</span>.</p>
-<p>With this information, because the distance between any two vertices
-is at most <span class="math inline">V-1</span> in a graph with <span
-class="math inline">V</span> vertices, we note that the
-<em>distance</em> matrix turns out to be simply <span
-class="math display">\textbf{D} = \textbf{1}_{V \times V} \cdot V -
-\Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j},
-1\right).</span> The runtime to compute this is <span
-class="math inline">O(V)</span>, although it will take more space to
-compute all powers of <span class="math inline">\textbf{M}</span>.</p>
-<h2 id="fine-tuning-results">Fine tuning results</h2>
-<p>After receiving our initial results, we fine-tuned with a learning
-rate of 1e-5, also with MSE and the same batch size. Our final results
-are shown in the images below.</p>
-<p><img src="fine-tuning-loss.png" /></p>
-<p><img src="fine-tuning-2d-histogram.png" /></p>
-<p><img src="test-2d-histogram.png" /></p>
-<p>Memorization? Do some math here to compute how many bits required to
-memorize 1, 2, 3</p>
+loop, but Python is excruciatingly slow and data generation wasted a lot
+of training time. Our solution was to pre-generate the data before
+training using a multithreaded version of our Python code.</p>
<h2
id="complicated-explicit-transformer-formula-for-shortest-paths">Complicated
explicit transformer formula for shortest paths</h2>
+<p>TODO: Kevin or Alek</p>
<div class="sourceCode" id="cb1"><pre class="sourceCode py"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span>
<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">16</span></span>
<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
@@ -629,9 +549,6 @@ explicit transformer formula for shortest paths</h2>
<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a> canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span>
<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a> final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> final_output</span></code></pre></div>
-<p>It looked like the fine tuning results weren’t as meaningful because
-TODO: ALEK</p>
-<h2 id="customizing-a-transformer">Customizing a Transformer</h2>
<p>After much deliberation, we decided the next step for us was to
customize a transformer, writing it ourselves. We observed that we
wished for the transformer to do similar work as a BFS. As a result, we
@@ -686,8 +603,95 @@ class="math inline">i</span>th entry is nonzero.</p>
<p>This iterates through all the vertices, and at the end, we may see
what run we are on to update <span
class="math inline">V_{\text{OUT}}</span>.</p>
+<h2 id="perturbing-the-weights">Perturbing the Weights</h2>
+<p>SLT folks like to think about geometry of loss landscape CITE So we
+did an experiment where we mess with the weights.</p>
+<p>Findings: XXX</p>
+<h2 id="our-model">Our Model</h2>
+<h3 id="architecture">Architecture</h3>
+<p>We used a standard transformer architecture. To ensure that it can in
+theory learn BFS, we ensured that the number of layers in our
+transformer exceeds the diameter of the input graphs.</p>
+<p>Since the order of the edges in the input doesn’t matter, we did not
+use positional encodings. Each edge <span
+class="math inline">(a,b)</span> is embedded to dimension <span
+class="math inline">d</span> where the first <span
+class="math inline">\frac{d}{2}</span> elements are the learned
+embedding of <span class="math inline">a</span> and the last <span
+class="math inline">\frac{d}{2}</span> elements are the learned
+embedding of <span class="math inline">b</span>. For the target vertex
+<span class="math inline">t</span>, we pair it with the special token
+<span class="math inline">TARGET</span> and embed <span
+class="math inline">(t,TARGET)</span> in the same way.</p>
+<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
+<p><img src="img/embeddings.svg" /></p>
+<h3 id="training">Training</h3>
+<p>To match the BFS transformer as closely as possible, we used a model
+dimension of <span class="math inline">64</span>, <span
+class="math inline">11</span> layers, and <span
+class="math inline">2</span> heads per layer, for a total of 550433
+parameters. In 32-bit float precision, that corresponds to around <span
+class="math inline">1.76\cdot10^6</span> bits. The number of possible
+graphs on 15 vertices generated using our procedure is approximately</p>
+<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} =
+1.59\cdot10^{18}.</span></p>
+<p>This is because there are <span
+class="math inline">\binom{15}{2}</span> choices for each of the 15
+edges and we don’t care about the order of the edges. This is only an
+approximation because some edges might be duplicated. Each graph has an
+answer between 1 and 15 which requires around 4 bits, so memorizing all
+the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} =
+6.36\cdot10^{18}</span> bits, which is <span
+class="math inline">3.61\cdot10^{12}</span> times larger than our model
+size.</p>
+<p>To train the model, we used MSE loss, the Adam optimizer, a learning
+rate of <span class="math inline">3\cdot10^{-4}</span>, and a batch size
+of <span class="math inline">2^{15}</span> for one billion randomly
+generated graphs. A training run takes roughly eight hours to run on a
+Radeon 7900 XTX graphics card. Our final MSE loss was <span
+class="math inline">0.000555</span>.</p>
+<p>TODO: use https://mpld3.github.io/index.html to make interactive
+plots</p>
+<p><img src="training-loss.png" /></p>
+<p><img src="training-2d-histogram.png" /></p>
+<p>One pattern we consistently noticed during training is that the model
+often gets stuck and plateaus for many epochs before rapidly decreasing.
+For instance, this happened between epochs 100 and 300 in the graph
+above:</p>
+<p><img src="grokking.png" /></p>
+<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p>
+<p>TODO: cite Neel Nanda grokking modular addition</p>
+<p>TODO: CRAZY!!! training curves for 1, 2, 3 length paths</p>
+<p>One pitfall we encountered during training is that we initially used
+bfloat16 to save VRAM, but our huge batch size caused loss-of-precision
+problems and made training very difficult. It took us two weeks to debug
+this until we found that switching to float32 improved training
+significantly.</p>
+<h2 id="fine-tuning-results">Fine tuning results</h2>
+<p>After receiving our initial results, we fine-tuned with a learning
+rate of 1e-5, also with MSE and the same batch size. Our final results
+are shown in the images below.</p>
+<p><img src="fine-tuning-loss.png" /></p>
+<p><img src="fine-tuning-2d-histogram.png" /></p>
+<p><img src="test-2d-histogram.png" /></p>
+<p>TODO: get new graphs</p>
+<p>It’s pretty good!!!</p>
+<p>Can only generalize to target vertices from 2 to 7 since 8 through 15
+didn’t appear in the fine-tune data</p>
+<p>but this still means it</p>
<h2 id="conclusion">Conclusion</h2>
+<p>however, a machine learning approach may do better in time through
+parallelism, although at the expense of using much more memory.
+<strong>TODO: ALEK: this is BS. If you want a parallel algorithm for
+BFS, here’s one
+https://en.wikipedia.org/wiki/Parallel_single-source_shortest_path_algorithm</strong></p>
<p>just do bfs lol</p>
+<p><strong>Future Work</strong> There are a couple of other things that
+we could try to learn shortest paths better and maybe see more
+generalization. - Chain of thought - Train model to output a path, not
+just the distance. Give it partial points for outputting anything that
+is a legitimate path (i.e., consists of real edges) and more points for
+getting the distance correct.</p>
<h2 class="unnumbered" id="references">References</h2>
<div id="refs" class="references csl-bib-body hanging-indent"
data-entry-spacing="0" role="list">
@@ -707,13 +711,6 @@ Exact Epidemiological Modeling of Pandemic Diseases.”</span> <em>IEEE
Trans. Pattern Anal. Mach. Intell.</em> 45 (8): 9836–45. <a
href="https://doi.org/10.1109/TPAMI.2023.3256421">https://doi.org/10.1109/TPAMI.2023.3256421</a>.
</div>
-<div id="ref-10.5555/3666122.3666260" class="csl-entry" role="listitem">
-Zang, Xiao, Miao Yin, Jinqi Xiao, Saman Zonouz, and Bo Yuan. 2024.
-<span>“GraphMP: Graph Neural Network-Based Motion Planning with
-Efficient Graph Search.”</span> In <em>Proceedings of the 37th
-International Conference on Neural Information Processing Systems</em>.
-NIPS ’23. Red Hook, NY, USA: Curran Associates Inc.
-</div>
</div>
</body>
</html>