Diffstat (limited to 'index.html')
-rw-r--r-- | index.html | 447 |
1 files changed, 222 insertions, 225 deletions
@@ -4,7 +4,7 @@ <meta charset="utf-8" /> <meta name="generator" content="pandoc" /> <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" /> - <title>6.7960 Project: Investigating Off-Distribution Generalization of Transformers</title> + <title>Discovering Graph Algorithms Using Transformers</title> <style> html { color: #1a1a1a; @@ -271,74 +271,51 @@ </head> <body> <header id="title-block-header"> -<h1 class="title">6.7960 Project: Investigating Off-Distribution -Generalization of Transformers</h1> +<h1 class="title">Discovering Graph Algorithms Using Transformers</h1> </header> <!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 --> <div style="text-align:center"> <p>Anthony Wang, Alek Westover, Kevin Zhao</p> <p>{xy,alekw,kevinmz}@mit.edu</p> </div> -<h2 id="goals">Goals</h2> -<p>Recently, LLMs have been developing very fast, and with that comes -the concern of aligning the models to output true and productive -statements. One common approach for ensuring this is to have a human in -the loop rewarding the model for true outputs (e.g. RLHF), but one -drawback to this problem is that humans can be poor judges of -truthfulness. As LLMs become more capable, there might not even exist -experts that are good judges of whether the model’s outputs, such as -difficult mathematical proofs, are truthful. So, we’d like to propose a -potential solution to this issue via <strong>off-distribution -generalization</strong> - applying human-like intuition to solve -problems not in the dataset. Paul Christiano <a -href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">proposed -an experiment</a> about shortest paths in a graph; our project is -essentially to implement Christiano’s proposed experiment. To the best -of our knowledge, although there has been research in applying machine -learning for different variations of graph searches <span -class="citation" data-cites="10.5555/3666122.3666260">(<a -href="#ref-10.5555/3666122.3666260" role="doc-biblioref">Zang et al. -2024</a>)</span>, no one has done our exact experiment yet.</p> -<p>It is generally desirable for LLMs to output true statements. A -current approach for ensuring this is to have a human in the loop -rewarding the model for true outputs (e.g. RLHF); however, humans can be -poor judges of truthfulness. We enjoy many cognitive biases and might -employ superficial heuristics when judging truthfulness. A further -challenge is that as LLMs develop further, there might not even exist -experts that can correctly judge the accuracy and truthfulness of -sophisticated outputs such as difficult mathematical proofs.</p> -<p>One approach to solving this problem is to reward an LLM for truthful -behavior on simple inputs, and then hoping that the LLM generalizes its -truthful behavior for more complex inputs where humans cannot provide -helpful labels. Deep learning models often perform remarkable feats of -off-distribution generalization – for instance, a model trained to -transform hand drawn cats into images of cats might be able to handle a -“cat” with three eyes in an intuitive way. 
We might hope that
-generalizing truthfully is simple, thus promoted by “Occam’s Razor”, and
-aim to investigate that with this project.</p>
-<p>COMMENT FROM KEVIN – synthesize from intorduction</p>
-<h3 id="task">Task</h3>
-<p>We will use a synthetic task to test our hypothesis that models will
-generalize truthfully off-distribution. The synthetic task is computing
-the distance between various vertices in an input graph. Our experiment
-will have three parts:</p>
-<ol type="1">
-<li>Pre-train a transformer to predict the distance between two fixed
-vertices <span class="math inline">s,t</span> on graphs with <span
-class="math inline">n\in [8, 32)</span> vertices.</li>
-<li>Fine-tune a transformer to predict the distances between <span
-class="math inline">s,t'</span> for any <span
-class="math inline">t'</span> which is on the shortest path from
-<span class="math inline">s</span> to <span
-class="math inline">t</span>, but only do fine-tuning on graphs with
-<span class="math inline">n\in [8,16)</span> vertices.</li>
-<li>Test whether the transformer can accurately predict the distances
-between <span class="math inline">s,t'</span> for any <span
-class="math inline">t'</span> on the shortest path from <span
-class="math inline">s</span> to <span class="math inline">t</span> for
-graphs with <span class="math inline">n\in [16,32)</span> vertices.</li>
-</ol>
-<h3 id="related-work">Related Work</h3>
+<h2 id="motivation">Motivation</h2>
+<p>Transformers, the architecture that powers LLMs, can do incredible
+feats: trained on hundreds of gigabytes of raw text, they can learn to
+hold natural conversations, reason about the physical world, and write
+code. Skeptics argue that LLMs are simply memorizing their datasets
+without gaining any deeper understanding. For instance, OpenAI’s o1
+model, despite achieving 90th percentile on Codeforces, struggles with
+simple but bizarre algorithms problems such as “find the subarray of a
+2D array with the minimum average”. In this project, we hope to explore
+<strong>when off-distribution generalization happens in a
+transformer</strong>. Paul Christiano <a
+href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">proposed
+an experiment</a> about shortest paths in a graph to investigate this,
+so we decided to become the first to implement his experiment and put
+transformers’ generalization abilities to the test.</p>
+<p>LLMs are notorious for making up complete nonsense, so we also hope
+that our project can shed light on when truthfulness generalizes. It’s
+generally desirable for LLMs to output true statements. One current
+approach for ensuring this is to have a human in the loop rewarding the
+model for true outputs (e.g. RLHF). However, humans can be poor judges
+of truthfulness: we have many cognitive biases and often rely on
+superficial heuristics. A further challenge is that as LLMs become more
+capable, there might not even exist experts that are good judges of
+whether the model’s outputs, such as difficult mathematical proofs, are
+truthful. For instance, most TaskRabbit workers would probably be
+hard-pressed to evaluate whether a difficult mathematical proof
+produced by an LLM is true.
The entire mathematical community has been known on occasion to <a +href="https://en.wikipedia.org/wiki/Grunwald%E2%80%93Wang_theorem">believe +false statements for many years</a>.</p> +<p>One possible solution is to reward an LLM for truthful behavior on +simple inputs, and then hope that the LLM generalizes its truthful +behavior for more complex inputs where humans cannot provide helpful +labels. Deep learning models can be remarkably good at off-distribution +generalization–for instance, a model trained to transform hand drawn +cats into images of cats might be able to handle a “cat” with three eyes +in an intuitive way. We might hope that generalizing truthfully is +simple, thus promoted by “Occam’s Razor”.</p> +<h2 id="related-work">Related Work</h2> <p>COMMENT FROM ALEK – please remove all mentions of graph neural networks – that is BS: there is no actual reason why you’d ever use a Neural network to solve shortest paths, the point of choosing a @@ -367,170 +344,113 @@ href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy 2023</a>)</span>. We understand from his paper how GNN optimization may also be useful in researching novel diseases.</p></li> </ul> -<h2 id="methods">Methods</h2> +<h2 id="task">Task</h2> +<p>Our synthetic task is simple: compute the distance between various +vertices in an input graph. To test off-distribution generalization, our +experiment has three steps.</p> +<ol type="1"> +<li><strong>Pre-train</strong> a transformer to predict the distance +between vertices <span class="math inline">1</span> and <span +class="math inline">2</span> in graphs with <span class="math inline">n +\in [3,15]</span> vertices.</li> +</ol> +<div style="text-align:center"> +<p><img src="img/train.svg" /></p> +</div> +<ol start="2" type="1"> +<li><strong>Fine-tune</strong> a transformer to predict the distances +between vertex <span class="math inline">1</span> to <span +class="math inline">t</span> for any <span class="math inline">t</span> +on the shortest path from <span class="math inline">1</span> to <span +class="math inline">2</span>, but only do fine-tuning on graphs with +<span class="math inline">n \in [3,7]</span> vertices.</li> +</ol> +<div style="text-align:center"> +<p><img src="img/finetune.svg" /></p> +</div> +<ol start="3" type="1"> +<li><strong>Test</strong> whether the transformer can accurately predict +the distances between <span class="math inline">1</span> to <span +class="math inline">t</span> for any <span class="math inline">t \leq +7</span> on the shortest path from <span class="math inline">1</span> to +<span class="math inline">2</span> for graphs with <span +class="math inline">n \in [3,15]</span> vertices.</li> +</ol> +<div style="text-align:center"> +<p><img src="img/test.svg" /></p> +</div> <h3 id="algorithm-for-shortest-paths">Algorithm for Shortest Paths</h3> -<p>The standard algorithm to find the shortest path in a graph between a -source numbered as <span class="math inline">u</span> and sink numbered -as <span class="math inline">v</span> is <strong>breadth-first search -(BFS)</strong>. The BFS algorithm maintains a mapping of visited -vertices to their distances with respect to <span -class="math inline">u</span>, and each run of the algorithm goes through -all the vertices newly visited in the previous run, and for each vertex, -visits any of its unvisited neighbors. 
The algorithm terminates once -either <span class="math inline">v</span> is visited or the set of newly -visited vertices in a single run is empty.</p> -<p>We will use this algorithm to verify the accuracy of our machine -learning approach. Given <span class="math inline">V</span> vertices and -<span class="math inline">E</span> edges, the runtime of this algorithm -is thus <span class="math inline">O(V + E)</span>; however, a machine -learning approach may do better in time through parallelism, although at -the expense of using much more memory.</p> +<p>The standard algorithm to find the shortest path in a graph between +vertices <span class="math inline">u</span> and <span +class="math inline">v</span> is <strong>breadth-first search +(BFS)</strong>, taught in every intro algorithms class. Initially, BFS +starts at <span class="math inline">u</span>, and at each phase, +explores a farther layer of vertices from <span +class="math inline">u</span>. During a phase, BFS goes through every +vertex in the current layer and adds any of their unvisited neighbors to +the next layer. The algorithm terminates once we reach <span +class="math inline">v</span> or if the next layer is empty. For a graph +with <span class="math inline">V</span> vertices and <span +class="math inline">E</span> edges, the runtime of BFS is <span +class="math inline">O(V + E)</span>. BFS gives us an easy and fast way +to find the ground truth answer for any graph, so that we can verify the +accuracy of our machine learning approach.</p> +<p>We hope that our model can learn BFS or some other simple, +generalizable algorithm for shortest paths, because the model can’t just +pull some magic number out of a hat but intuitively needs to count or +list the vertices on the shortest path from <span +class="math inline">1</span> to <span class="math inline">2</span>. In +fact, we will show how to hand-craft a set of weights to implement BFS +in a transformer, so it’s indeed theoretically possible for a +transformer to achieve 100% accuracy.</p> <h3 id="data">Data</h3> -<p>We will represent an <span class="math inline">n</span> vertex, <span +<p>We’ll represent an <span class="math inline">n</span> vertex, <span class="math inline">m</span> edge unweighted, undirected graph as sequence of the endpoints of the <span class="math inline">m</span> edges, so <span class="math inline">[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]</span> represents a -graph with the edges <span class="math inline">\{(a_i,b_i)\}</span> for -<span class="math inline">1 \leq i \leq m</span>. We will pad all -sequences to be the same length using the padding token 0.</p> -<p>The full input to our model will additionally add the target vertex -after the padding tokens. The model is tasked with predicting the length -of the shortest path between vertex 1 and the target vertex <span -class="math inline">t</span>. If no such path exists, we define the -length to be <span class="math inline">n+1</span> which represents -infinity. For example, an input-output pair for our model could look -like <span class="math inline">[1, 3, 3, 2, 0, 0, 0, 0, 2]</span> and -<span class="math inline">2</span> respectively.</p> -<p>We have three separate datasets.</p> -<ul> +graph with the edges <span class="math inline">\{(a_i,b_i)\}</span> +where <span class="math inline">a_i < b_i</span> for <span +class="math inline">1 \leq i \leq m</span>. 
All sequences are padded to +the same length using the padding token <span +class="math inline">0</span>.</p> +<p>The full input to our model additionally includes the target vertex +<span class="math inline">t</span> after the padding tokens. The label +to an input is the length of the shortest path from <span +class="math inline">1</span> to <span class="math inline">t</span>. If +no such path exists, we define the length to be <span +class="math inline">n+1</span> which represents infinity. For example, +the input <span class="math inline">[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0, +3]</span> has the label <span class="math inline">2</span>.</p> +<div style="text-align:center"> +<p><img src="img/finetune.svg" /></p> +</div> +<p>We have three datasets for each step.</p> +<ol type="1"> <li><strong>Pre-train data</strong>: For each <span -class="math inline">n \in [8,32)</span>, we will generate several graphs -on <span class="math inline">n</span> vertices. We generate these graphs -by inserting <span class="math inline">2n</span> random edges into the -graph. We always set the target vertex to be <span -class="math inline">2</span> here.</li> +class="math inline">n \in [3,15]</span>, we generated an equal number of +graphs on <span class="math inline">n</span> vertices, with <span +class="math inline">t = 2</span>. Each graph was created by choosing +<span class="math inline">n</span> random edges.</li> <li><strong>Fine-tune data</strong>: For each <span -class="math inline">n \in [8,16)</span>, we will generate several graphs -on <span class="math inline">n</span> vertices. We generate these graphs -by inserting <span class="math inline">2n</span> random edges into the -graph. We select the target vertex to be a random vertex on the shortest -path from <span class="math inline">1</span> to <span -class="math inline">2</span>.</li> -<li><strong>Generalization testing data</strong>: The same as the -fine-tune data, except we sample <span class="math inline">n \in -[16,32)</span> instead.</li> -</ul> +class="math inline">n \in [3,7]</span>, we generated an equal number of +graphs on <span class="math inline">n</span> vertices each with a random +<span class="math inline">t</span> on the shortest path from <span +class="math inline">1</span> to <span class="math inline">2</span>. +Again, each graph was created by choosing <span +class="math inline">n</span> random edges.</li> +<li><strong>Generalization test data</strong>: The same as the fine-tune +data, except we sample <span class="math inline">n \in [3,15]</span> and +<span class="math inline">t \leq 7</span>.</li> +</ol> <p>We wrote some Python code to generate the data during the training -loop, but Python is slow and the data generation wasted a lot of time -during training. To get around this, we pre-generated the data before -training and made our Python code multithreaded to speed it up.</p> -<h3 id="architecture">Architecture</h3> -<p>TODO: honestly not much to say here since it’s a pretty typical -arch</p> -<p>We plan to use a standard transformer architecture. We will ensure -that the number of layers in our transformer is at least the diameter of -the graph. By doing this, we ensure that there is an extremely simple -circuit — namely BFS — that the transformer could in theory learn to -perform the task. Note that if the transformer actually learns a simple -circuit to perform this task, then it seems more likely to generalize -well. 
This is also our intuition for why it should be possible to fine -tune on a small amount of data for finding shortest paths to other -vertices besides <span class="math inline">2</span> – it seems like the -model should be computing these other distances as intermediate values -in its computation to find the distance to vertex <span -class="math inline">2</span>.</p> -<h3 id="embeddings">Embeddings</h3> -<p>Since the order of the edges in the input does not matter, we did not -use positional encodings. Each edge <span -class="math inline">(u,v)</span> where <span class="math inline">u < -v</span> is embedded to a dimension of <span -class="math inline">d</span> where the first <span -class="math inline">\frac{d}{2}</span> elements are the learned -embedding of <span class="math inline">u</span> and the last <span -class="math inline">\frac{d}{2}</span> elements are the learned -embedding of <span class="math inline">v</span>. For the target vertex -<span class="math inline">t</span>, we also embedded to dimension <span -class="math inline">d</span>, where the first <span -class="math inline">\frac{d}{2}</span> elements are the learned -embedding of <span class="math inline">t</span> and the last <span -class="math inline">\frac{d}{2}</span> are a learned embedding of a -special token.</p> -<h2 id="training">Training</h2> -<p>For our model, we used a model dimension of 64, four layers, and two -heads per layer, for a total of 200545 parameters in bfloat16 which -corresponds to around 3.2e6 bits. The number of possible graphs on 15 -vertices generated using our procedure is approximately</p> -<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} = -1.59\cdot10^{18}.</span></p> -<p>This is because there are <span -class="math inline">\binom{15}{2}</span> choices for each of the 15 -edges and we don’t care about the order of the edges. This is only an -approximation because some edges might be duplicated. Each graph has an -answer between 1 and 15 which requires around 4 bits, so memorizing all -the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} = -6.36\cdot10^{18}</span>, which is <span -class="math inline">2\cdot10^{12}</span> times larger than our model -size.</p> -<p>We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a -batch size of 131072 for 8000 unique randomly generated batches. Our -final MSE loss was approximately 0.3555.</p> -<p><img src="training-loss.png" /></p> -<p><img src="training-2d-histogram.png" /></p> -<p>One pattern we consistently noticed during training is that the model -often gets stuck and plateaus for many epochs before rapidly decreasing. -For instance, this happened between epochs 100 and 300 in the graph -above:</p> -<p><img src="grokking.png" /></p> -<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p> -<p>TODO: training curves for 1, 2, 3 length paths</p> -<h3 -id="potential-mathematical-approaches-to-shortest-paths-delete-this">Potential -Mathematical Approaches to Shortest Paths? Delete this?</h3> -<p>Another way one can think of the shortest path of a graph is using a -<em>matrix</em> to record which vertices are connected. 
Given vertices -numbered <span class="math inline">1</span> to <span -class="math inline">V</span>, we denote the <strong>adjacency -matrix</strong> <span class="math inline">\textbf{M}</span> of -dimensions <span class="math inline">V \times V</span> as the matrix -with element <span class="math inline">\textbf{M}_{i, j} = 1</span> if -vertices <span class="math inline">i</span> and <span -class="math inline">j</span> are connected by an edge and <span -class="math inline">\textbf{M}_{i, j} = 0</span> if they are not. Now, -we note that (1) For all <span class="math inline">k</span>, <span -class="math inline">(\textbf{M}+I)^k_{i, j} = 0</span> if and only if -there exists no path from the vertex numbered <span -class="math inline">i</span> to the vertex numbered <span -class="math inline">j</span> that is distance <span -class="math inline">k</span> or less due to Markov matrix processes. As -a result, if the distance between vertices numbered <span -class="math inline">i</span> and <span class="math inline">j</span> is -<span class="math inline">d</span>, then <span -class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = -1</span> if <span class="math inline">k \ge d</span> and <span -class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = -0</span> if <span class="math inline">k < d</span>.</p> -<p>With this information, because the distance between any two vertices -is at most <span class="math inline">V-1</span> in a graph with <span -class="math inline">V</span> vertices, we note that the -<em>distance</em> matrix turns out to be simply <span -class="math display">\textbf{D} = \textbf{1}_{V \times V} \cdot V - -\Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j}, -1\right).</span> The runtime to compute this is <span -class="math inline">O(V)</span>, although it will take more space to -compute all powers of <span class="math inline">\textbf{M}</span>.</p> -<h2 id="fine-tuning-results">Fine tuning results</h2> -<p>After receiving our initial results, we fine-tuned with a learning -rate of 1e-5, also with MSE and the same batch size. Our final results -are shown in the images below.</p> -<p><img src="fine-tuning-loss.png" /></p> -<p><img src="fine-tuning-2d-histogram.png" /></p> -<p><img src="test-2d-histogram.png" /></p> -<p>Memorization? Do some math here to compute how many bits required to -memorize 1, 2, 3</p> +loop, but Python is excruciatingly slow and data generation wasted a lot +of training time. 
Our solution was to pre-generate the data before +training using a multithreaded version of our Python code.</p> <h2 id="complicated-explicit-transformer-formula-for-shortest-paths">Complicated explicit transformer formula for shortest paths</h2> +<p>TODO: Kevin or Alek</p> <div class="sourceCode" id="cb1"><pre class="sourceCode py"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span> <span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">16</span></span> <span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span> @@ -629,9 +549,6 @@ explicit transformer formula for shortest paths</h2> <span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a> canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span> <span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a> final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span> <span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> final_output</span></code></pre></div> -<p>It looked like the fine tuning results weren’t as meaningful because -TODO: ALEK</p> -<h2 id="customizing-a-transformer">Customizing a Transformer</h2> <p>After much deliberation, we decided the next step for us was to customize a transformer, writing it ourselves. We observed that we wished for the transformer to do similar work as a BFS. As a result, we @@ -686,8 +603,95 @@ class="math inline">i</span>th entry is nonzero.</p> <p>This iterates through all the vertices, and at the end, we may see what run we are on to update <span class="math inline">V_{\text{OUT}}</span>.</p> +<h2 id="perturbing-the-weights">Perturbing the Weights</h2> +<p>SLT folks like to think about geometry of loss landscape CITE So we +did an experiment where we mess with the weights.</p> +<p>Findings: XXX</p> +<h2 id="our-model">Our Model</h2> +<h3 id="architecture">Architecture</h3> +<p>We used a standard transformer architecture. To ensure that it can in +theory learn BFS, we ensured that the number of layers in our +transformer exceeds the diameter of the input graphs.</p> +<p>Since the order of the edges in the input doesn’t matter, we did not +use positional encodings. Each edge <span +class="math inline">(a,b)</span> is embedded to dimension <span +class="math inline">d</span> where the first <span +class="math inline">\frac{d}{2}</span> elements are the learned +embedding of <span class="math inline">a</span> and the last <span +class="math inline">\frac{d}{2}</span> elements are the learned +embedding of <span class="math inline">b</span>. 
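<p>A minimal sketch of this embedding scheme is shown below, assuming
the flat token sequence has already been grouped into (vertex, vertex)
pairs; it also covers the pairing with the special <code>TARGET</code>
token described next. All names and sizes here are illustrative
assumptions, not our exact implementation.</p>
<pre class="sourceCode python"><code>import torch
import torch.nn as nn

# Assumed constants: vertex tokens 1..15, 0 is padding, one extra TARGET token.
NVTXS = 15
TARGET = NVTXS + 1
D_MODEL = 64

class PairEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # one shared table for vertices, the padding token, and the TARGET token
        self.tok = nn.Embedding(NVTXS + 2, D_MODEL // 2)

    def forward(self, pairs):
        # pairs: (batch, seq, 2) holding (a, b) edge pairs and a final (t, TARGET) pair
        left = self.tok(pairs[..., 0])           # (batch, seq, d/2)
        right = self.tok(pairs[..., 1])          # (batch, seq, d/2)
        return torch.cat([left, right], dim=-1)  # (batch, seq, d), no positional encoding

# Example: the edges of [1,3, 3,4, 2,4, 2,3] plus padding and the target pair.
emb = PairEmbedding()
example = torch.tensor([[[1, 3], [3, 4], [2, 4], [2, 3], [0, 0], [0, 0], [3, TARGET]]])
print(emb(example).shape)  # torch.Size([1, 7, 64])</code></pre>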
For the target vertex
+<span class="math inline">t</span>, we pair it with the special token
+<span class="math inline">TARGET</span> and embed <span
+class="math inline">(t,TARGET)</span> in the same way.</p>
+<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
+<p><img src="img/embeddings.svg" /></p>
+<h3 id="training">Training</h3>
+<p>To match the BFS transformer as closely as possible, we used a model
+dimension of <span class="math inline">64</span>, <span
+class="math inline">11</span> layers, and <span
+class="math inline">2</span> heads per layer, for a total of 550433
+parameters. In 32-bit float precision, that corresponds to around <span
+class="math inline">1.76\cdot10^7</span> bits. The number of possible
+graphs on 15 vertices generated using our procedure is approximately</p>
+<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} =
+1.59\cdot10^{18}.</span></p>
+<p>This is because there are <span
+class="math inline">\binom{15}{2}</span> choices for each of the 15
+edges and we don’t care about the order of the edges. This is only an
+approximation because some edges might be duplicated. Each graph has an
+answer between 1 and 15 which requires around 4 bits, so memorizing all
+the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} =
+6.36\cdot10^{18}</span> bits, which is <span
+class="math inline">3.61\cdot10^{11}</span> times larger than our model
+size.</p>
+<p>To train the model, we used MSE loss, the Adam optimizer, a learning
+rate of <span class="math inline">3\cdot10^{-4}</span>, and a batch size
+of <span class="math inline">2^{15}</span> for one billion randomly
+generated graphs. A training run takes roughly eight hours on a
+Radeon 7900 XTX graphics card. Our final MSE loss was <span
+class="math inline">0.000555</span>.</p>
+<p>TODO: use https://mpld3.github.io/index.html to make interactive
+plots</p>
+<p><img src="training-loss.png" /></p>
+<p><img src="training-2d-histogram.png" /></p>
+<p>One pattern we consistently noticed during training is that the loss
+often gets stuck and plateaus for many epochs before rapidly decreasing.
+For instance, this happened between epochs 100 and 300 in the graph
+above:</p>
+<p><img src="grokking.png" /></p>
+<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p>
+<p>TODO: cite Neel Nanda grokking modular addition</p>
+<p>TODO: CRAZY!!! training curves for 1, 2, 3 length paths</p>
+<p>One pitfall we encountered during training is that we initially used
+bfloat16 to save VRAM, but our huge batch size caused loss-of-precision
+problems and made training very difficult. It took us two weeks to debug
+this until we found that switching to float32 improved training
+significantly.</p>
+<h2 id="fine-tuning-results">Fine-tuning results</h2>
+<p>After pre-training, we fine-tuned with a learning rate of 1e-5, also
+with MSE loss and the same batch size. Our final results are shown in
+the images below.</p>
+<p><img src="fine-tuning-loss.png" /></p>
+<p><img src="fine-tuning-2d-histogram.png" /></p>
+<p><img src="test-2d-histogram.png" /></p>
+<p>TODO: get new graphs</p>
+<p>The results are quite good. The model can only generalize to target
+vertices 2 through 7, since targets 8 through 15 never appear in the
+fine-tune data, but this still means it learned to predict distances to
+intermediate vertices on graphs far larger than the ones it was
+fine-tuned on.</p>
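<p>For reference, the sketch below shows the overall shape of the two
training phases described above: MSE loss on the predicted distance,
the Adam optimizer, and a learning rate of 3e-4 for pre-training and
1e-5 for fine-tuning. The <code>model</code>,
<code>pretrain_batches</code>, and <code>finetune_batches</code>
objects are hypothetical stand-ins rather than our actual code.</p>
<pre class="sourceCode python"><code>import torch

def run_phase(model, batches, lr, device="cuda"):
    """One training phase: regress the predicted distance onto the BFS label with MSE."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    model.train()
    for seqs, dists in batches:            # seqs: tokenized graphs, dists: BFS labels
        seqs, dists = seqs.to(device), dists.to(device)
        pred = model(seqs)                 # predicted distance, shape (batch,)
        loss = loss_fn(pred, dists.float())
        opt.zero_grad()
        loss.backward()
        opt.step()

# Phase 1: pre-train on graphs with n in [3, 15], target vertex fixed to 2.
# run_phase(model, pretrain_batches, lr=3e-4)
# Phase 2: fine-tune on graphs with n in [3, 7], random targets on the 1-2 shortest path.
# run_phase(model, finetune_batches, lr=1e-5)</code></pre>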
<h2 id="conclusion">Conclusion</h2>
+<p>However, a machine learning approach may do better in time through
+parallelism, although at the expense of using much more memory.
+<strong>TODO: ALEK: this is BS. If you want a parallel algorithm for
+BFS, here’s one:
+https://en.wikipedia.org/wiki/Parallel_single-source_shortest_path_algorithm</strong></p>
<p>just do bfs lol</p>
+<p><strong>Future Work</strong> There are a couple of other things that
+we could try to learn shortest paths better and maybe see more
+generalization:</p>
+<ul>
+<li>Chain of thought.</li>
+<li>Train the model to output a path, not just the distance. Give it
+partial points for outputting anything that is a legitimate path (i.e.,
+consists of real edges) and more points for getting the distance
+correct.</li>
+</ul>
<h2 class="unnumbered" id="references">References</h2>
<div id="refs" class="references csl-bib-body hanging-indent"
data-entry-spacing="0" role="list">
@@ -707,13 +711,6 @@ Exact Epidemiological Modeling of Pandemic Diseases.”</span> <em>IEEE
Trans. Pattern Anal. Mach. Intell.</em> 45 (8): 9836–45. <a
href="https://doi.org/10.1109/TPAMI.2023.3256421">https://doi.org/10.1109/TPAMI.2023.3256421</a>.
</div>
-<div id="ref-10.5555/3666122.3666260" class="csl-entry" role="listitem">
-Zang, Xiao, Miao Yin, Jinqi Xiao, Saman Zonouz, and Bo Yuan. 2024.
-<span>“GraphMP: Graph Neural Network-Based Motion Planning with
-Efficient Graph Search.”</span> In <em>Proceedings of the 37th
-International Conference on Neural Information Processing Systems</em>.
-NIPS ’23. Red Hook, NY, USA: Curran Associates Inc.
-</div>
</div>
</body>
</html>