aboutsummaryrefslogtreecommitdiff
path: root/index.html
diff options
context:
space:
mode:
Diffstat (limited to 'index.html')
-rw-r--r--index.html1061
1 files changed, 694 insertions, 367 deletions
diff --git a/index.html b/index.html
index fd5f74f..b01beb3 100644
--- a/index.html
+++ b/index.html
@@ -4,7 +4,7 @@
<meta charset="utf-8" />
<meta name="generator" content="pandoc" />
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
- <title>Discovering Graph Algorithms Using Transformers</title>
+ <title>Off-Distribution Generalization of Transformers for Shortest Paths</title>
<style>
html {
color: #1a1a1a;
@@ -229,27 +229,7 @@
code span.va { color: #19177c; } /* Variable */
code span.vs { color: #4070a0; } /* VerbatimString */
code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */
- /* CSS for citations */
- div.csl-bib-body { }
- div.csl-entry {
- clear: both;
- margin-bottom: 0em;
- }
- .hanging-indent div.csl-entry {
- margin-left:2em;
- text-indent:-2em;
- }
- div.csl-left-margin {
- min-width:2em;
- float:left;
- }
- div.csl-right-inline {
- margin-left:2em;
- padding-left:1em;
- }
- div.csl-indent {
- margin-left: 2em;
- } </style>
+ </style>
<script defer=""
src="https://cdn.jsdelivr.net/npm/katex@0.15.1/dist/katex.min.js"></script>
<script>document.addEventListener("DOMContentLoaded", function () {
@@ -271,83 +251,78 @@
</head>
<body>
<header id="title-block-header">
-<h1 class="title">Discovering Graph Algorithms Using Transformers</h1>
+<h1 class="title">Off-Distribution Generalization of Transformers for
+Shortest Paths</h1>
</header>
+<style>
+body {
+ /* Shrink margins */
+ max-width: 45em;
+ /* font-size: 110%; */
+}
+
+/* Make sure toolbar plots aren't cut off */
+.mpld3-figure {
+ height: 480px;
+ display: block;
+ margin: auto;
+}
+</style>
+<!--
+BTW we should run a spell checker on this at some point before submitting
+note to self
+-->
<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 -->
<div style="text-align:center">
<p>Anthony Wang, Alek Westover, Kevin Zhao</p>
<p>{xy,alekw,kevinmz}@mit.edu</p>
</div>
-<h2 id="motivation">Motivation</h2>
-<p>Transformers–the architecture that powers LLMs–can do incredible
-feats: trained on hundreds of gigabytes of raw text, they can learn to
-hold natural conversations, reason about the physical world, and write
-code. Skeptics argue that LLMs are simply memorizing their datasets
-without gaining any deeper understanding. For instance, GPT’s o1 model,
-achieving 90th percentile on Codeforces, struggles with simple but
-bizarre algorithms problems such as “find the subarray of a 2D array
-with the minimum average”. In this project, we hope to explore
-<strong>when off-distribution generalization happens in a
-transformer</strong>. Paul Christiano proposed an experiment <a
-href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">here</a>
-about shortest paths in a graph to investigate this, so we decided to
-become the first to implement his experiment and put transformers’
-generalization abilities to the test.</p>
-<p>LLMs are notorious for making up complete nonsense, so we also hope
-that our project can shed light on when truthfulness generalizes. It’s
-generally desirable for LLMs to output true statements. One current
-approach for ensuring this is to have a human in the loop rewarding the
-model for true outputs (e.g. RLHF). However, humans can be poor judges
-of truthfulness and have many cognitive biases and superficial
-heuristics. A further challenge is that as LLMs become more capable,
-there might not even exist experts that are good judges of whether the
-models outputs, such as difficult mathematical proofs, are truthful. For
-instance, most Task Rabbit workers would probably be hard pressed to
-evaluate whether a difficult mathematical proof produced by an LLM is
-true. The entire mathematical community has been known on occasion to <a
-href="https://en.wikipedia.org/wiki/Grunwald%E2%80%93Wang_theorem">believe
-false statements for many years</a>.</p>
-<p>One possible solution is to reward an LLM for truthful behavior on
-simple inputs, and then hope that the LLM generalizes its truthful
-behavior for more complex inputs where humans cannot provide helpful
-labels. Deep learning models can be remarkably good at off-distribution
-generalization–for instance, a model trained to transform hand drawn
-cats into images of cats might be able to handle a “cat” with three eyes
-in an intuitive way. We might hope that generalizing truthfully is
-simple, thus promoted by “Occam’s Razor”.</p>
-<h2 id="related-work">Related Work</h2>
-<p>COMMENT FROM ALEK – please remove all mentions of graph neural
-networks – that is BS: there is no actual reason why you’d ever use a
-Neural network to solve shortest paths, the point of choosing a
-synthetic task is because there is a <strong>simple ground
-truth</strong> which makes it easy to evaluate whether or not our model
-is performing correctly. We’d also hoped that the simplicity of the task
-would make it more feasible to do with a limited compute budget, but
-apparently this task was too hard for our architecture.</p>
-<p>There has been some research into the algorithmic optimization of
-GNNs and how they may solve real-world issues; however, none of the
-related work targets using generic machine learning methods to solve
-graph problems.</p>
-<ul>
-<li><p>Cappart et al. has researched more into the Combinatorial
-Optimization of GNNs and developed algorithms for related tasks, thus
-facilitating machine learning <span class="citation"
-data-cites="DBLP:journals/corr/abs-2102-09544">(<a
-href="#ref-DBLP:journals/corr/abs-2102-09544"
-role="doc-biblioref">Cappart et al. 2021</a>)</span>. Their results are
-mostly algorithmic so we develop further by trading a bit of accuracy
-for much faster computation in such tasks.</p></li>
-<li><p>Tutsoy uses a graph-theory-based approach to model the
-epidemiological characteristics of infectious diseases, such as COVID-19
-<span class="citation" data-cites="10.1109/TPAMI.2023.3256421">(<a
-href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy
-2023</a>)</span>. We understand from his paper how GNN optimization may
-also be useful in researching novel diseases.</p></li>
-</ul>
-<h2 id="task">Task</h2>
-<p>Our synthetic task is simple: compute the distance between various
-vertices in an input graph. To test off-distribution generalization, our
-experiment has three steps.</p>
+<h2 id="introduction">Introduction</h2>
+<h3 id="motivation">Motivation</h3>
+<p>Neural networks are capable of impressive feats of off-distribution
+generalization. For instance, a program trained to convert sketches of
+cats into realistic looking pictures of cats can draw a cat with three
+eyes, if the sketch has three eyes, even if such an image never occured
+in its training data. However, neural networks also often learn
+non-robust features that cause it to perform very poorly
+off-distribution (e.g., adversarial examples for an image classifier).
+In this project, we will investigate the question of when transformers
+generalize off-distribution through via a case study on a simple
+synthetic task. More specifically, the goal of our project is to make
+progress towards answering the following question:</p>
+<blockquote>
+<p>Suppose you train a model to be truthful on task A, and to produce
+coherent answers on a related task B.<br />
+Will the model learn to be truthful on task B?</p>
+</blockquote>
+<p>Here’s a more specific version of this question, inspired by a <a
+href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">blog
+post written by Paul Chrisiano</a> (who was influential in developing
+RLHF).</p>
+<blockquote>
+<p>Suppose we trained a language model on next token prediction in both
+English and French, and then fine-tuned the model to be truthful in
+English. Would the model then learn to answer questions truthfully in
+French? It has been shown <a href="#fn1" class="footnote-ref"
+id="fnref1" role="doc-noteref"><sup>1</sup></a> that this doesn’t happen
+by default, but you could hope that the truthfulness training in English
+would generalize to French.</p>
+</blockquote>
+<p>In the future, humans are likely to trust AIs with performing
+important and complex tasks. For some of these tasks it will be
+infeasible for humans to audit the answers – either because human labor
+is too expensive, or because the AIs are too competent, so humans aren’t
+skilled enough to perform an audit (for example, an amateur chess player
+might have a hard time telling apart a good chess move and a bad chess
+move). However, there will be some simpler tasks where we can audit the
+AI’s and reward truthul answers. So, it’s plausible that models trained
+to be truthful in some simple domains will learn truthfulness as a
+general lesson.</p>
+<h3 id="task">Task</h3>
+<p>In order to investigate generalization of this type in transformers,
+we decided to focus on a simple synthetic task – shortest paths in a
+graph. More specifically, we ran the following experiment (see <a
+href="#data">Data</a> for more details):</p>
<ol type="1">
<li><strong>Pre-train</strong> a transformer to predict the distance
between vertices <span class="math inline">1</span> and <span
@@ -362,47 +337,289 @@ class="math inline">2</span> in graphs with <span class="math inline">n
between vertex <span class="math inline">1</span> to <span
class="math inline">t</span> for any <span class="math inline">t</span>
on the shortest path from <span class="math inline">1</span> to <span
-class="math inline">2</span>, but only do fine-tuning on graphs with
-<span class="math inline">n \in [3,7]</span> vertices.</li>
+class="math inline">2</span>, but only do fine-tuning on small graphs
+with <span class="math inline">n \in [3,7]</span> vertices. Our
+intuition was that, in training a model to find distances between
+vertices <span class="math inline">1</span> and <span
+class="math inline">2</span>, the model would also implicitly need to
+learn about distances from vertex <span class="math inline">1</span> to
+other vertices on this shortest path.</li>
</ol>
<div style="text-align:center">
<p><img src="img/finetune.svg" /></p>
</div>
<ol start="3" type="1">
-<li><strong>Test</strong> whether the transformer can accurately predict
-the distances between <span class="math inline">1</span> to <span
-class="math inline">t</span> for any <span class="math inline">t \leq
-7</span> on the shortest path from <span class="math inline">1</span> to
-<span class="math inline">2</span> for graphs with <span
-class="math inline">n \in [3,15]</span> vertices.</li>
+<li><strong>Test</strong> whether the transformer can accurately
+generalize to predicting the distance between <span
+class="math inline">1</span> to <span class="math inline">t</span> for
+any <span class="math inline">t \in [3,15]</span> on the shortest path
+from <span class="math inline">1</span> to <span
+class="math inline">2</span> for graphs with <span class="math inline">n
+\in [3,15]</span> vertices.</li>
</ol>
<div style="text-align:center">
<p><img src="img/test.svg" /></p>
</div>
-<h3 id="algorithm-for-shortest-paths">Algorithm for Shortest Paths</h3>
-<p>The standard algorithm to find the shortest path in a graph between
-vertices <span class="math inline">u</span> and <span
-class="math inline">v</span> is <strong>breadth-first search
-(BFS)</strong>, taught in every intro algorithms class. Initially, BFS
-starts at <span class="math inline">u</span>, and at each phase,
-explores a farther layer of vertices from <span
-class="math inline">u</span>. During a phase, BFS goes through every
-vertex in the current layer and adds any of their unvisited neighbors to
-the next layer. The algorithm terminates once we reach <span
-class="math inline">v</span> or if the next layer is empty. For a graph
-with <span class="math inline">V</span> vertices and <span
-class="math inline">E</span> edges, the runtime of BFS is <span
-class="math inline">O(V + E)</span>. BFS gives us an easy and fast way
-to find the ground truth answer for any graph, so that we can verify the
-accuracy of our machine learning approach.</p>
-<p>We hope that our model can learn BFS or some other simple,
-generalizable algorithm for shortest paths, because the model can’t just
-pull some magic number out of a hat but intuitively needs to count or
-list the vertices on the shortest path from <span
-class="math inline">1</span> to <span class="math inline">2</span>. In
-fact, we will show how to hand-craft a set of weights to implement BFS
-in a transformer, so it’s indeed theoretically possible for a
-transformer to achieve 100% accuracy.</p>
+<!-- ![Experiment description](traintunegeneralize.png) -->
+<p>Obviously, if you actually want to compute shortest paths in a graph,
+you should use the standard graph algorithm BFS (breadth-first search).
+Our motivation for using this problem to study generalization was
+two-fold:</p>
+<ol type="1">
+<li>It’s very simple to generate data and ground truth labels for this
+task.</li>
+<li>There is a simple circuit which computes shortest paths (exactly),
+by doing something akin to BFS. In fact, we implemented this simple
+circuit by hand in the section <a
+href="#handcrafted-bfs-transformer">Handcrafted BFS Transformer</a>. Our
+hope was that our model can also find a simple circuit that generalized
+well.</li>
+</ol>
+<!--
+ANT:
+RE t<7
+The model couuld've decided to generalize for t > 7 too, but it didn't
+so id say
+we tested generalization for all t in 3, 15
+it succeeded for t in 3,7
+failed for t in 8 to 15
+reason is because it could do "if statements" instead of realizing that it should just be looking at the dor product
+mmm
+i feel like the fact that it works for 3,7 and not 8,15 is kind of compelling
+anyways this is just a comment explaining how im going to / think you should pitch the generalization section
+(it couldve generalized all the way, but didnt, so what we learened is generalization sometimes happens but sometimes doesnt
+
+yeah that's the intuition
+can't really do interpretability to figure out if that's actually what's going on
+-->
+<h3 id="findings">Findings</h3>
+<p>Our main findings were as follows:</p>
+<ol type="1">
+<li><p>There exists a simple set of weights that you can find by hand to
+achieve arbitrarily small MSE loss on shortest paths. Moreover, this set
+of weights is “robust” in the sense that, if you perturb some of these
+weights and run gradient descent, the model will recover low
+loss.</p></li>
+<li><p>Transformers can learn on their own to perform shortest paths.
+The training process consists of a sequence of <strong>phase
+changes</strong> where the loss would barely change for a large number
+of epochs, and then over a very small number of epochs, the loss would
+dramatically go down. (TODO: Is this just because we’re bad at DL? now
+it shouldn’t have as many phase changes. Or maybe just in the
+beginning?) In particular, these phase changes corresponded to times
+when the model decided to, e.g., go from dramatically over-classifying
+paths as being length 1 (resulting in great performance on tasks of
+length 1 but poor performance on other tasks) to spreading these length
+1 guesses out to other distances.</p></li>
+<li><p>A model trained to compute the distance between two specific
+vertices can be fine-tuned efficiently to compute the shortest distance
+to other vertices.</p></li>
+<li><p>A model trained to compute the distance between vertices <span
+class="math inline">1</span> and <span class="math inline">2</span> on
+large graphs, and fine tuned to compute the distance between vertices 1
+and any vertex (on a shortest <span class="math inline">1</span>-<span
+class="math inline">2</span> path) in small graphs, has the following
+generalization properties:</p>
+<ul>
+<li>In larger graphs, the model can still compute the shortest distance
+between <span class="math inline">1</span> and any <span
+class="math inline">t</span> whose label existed in the smaller graphs,
+namely <span class="math inline">t \in [3,7]</span>.</li>
+<li>However, the model fails when the target vertex doesn’t exist in
+smaller graphs.</li>
+</ul></li>
+</ol>
+<p>The main conceptual implication of this work is that, while
+off-distribution generalization sometimes does occur, sometimes it also
+doesn’t, and it seems tricky to predict when off-distribution
+generalization will happen.</p>
+<h3 id="post-outline">Post Outline</h3>
+<p>The remainder of our post is structured as follows:</p>
+<ol type="1">
+<li><p><a href="#handcrafted-bfs-transformer">Handcrafted BFS
+Transformer</a>: To start, we explicitly construct by hand weights for a
+transformer to solve the shortest paths problem. We also investigate the
+“robustness” of our construction by showing that, starting from our
+constructed solution it’s possible to recover low loss after perturbing
+the weights via gradient descent. This gives a proof-of-concept that a
+transformer could encode a simple generalizable solution to the problem,
+and gives some reason to believe that a transformer could learn such a
+solution.</p></li>
+<li><p><a href="#training">Training</a>: Next, we discuss the training.
+First we describe the exact data generation process and model
+architecture. Then we discuss hyperparameter settings and optimization
+choices that were important to get the model to train efficiently.
+Finally, we investigate some interesting phenomena that occured during
+training, such as phase transitions in loss for different path
+lengths.</p></li>
+<li><p><a href="#generalization">Generalization</a>: Finally we discuss
+the phenomena that we wanted to study – generalization! We show that
+some generalization does occur, but other generalization doesn’t. We
+explain why we think this happens.</p></li>
+</ol>
+<h2 id="handcrafted-bfs-transformer">Handcrafted BFS Transformer</h2>
+<p>In this section we describe an explicit set of weights, made (with
+love) by hand, that result in a transformer that achieves MSE less than
+<span class="math inline">10^{-6}</span> for the shortest paths problem.
+We’ll start by describing the (somewhat involved) constructions, and
+then investigate the geometry of the loss landscape near our explicit
+solution. The main reason for doing this is to demonstrate the existence
+of a simple transformer computing shortest paths, to give
+proof-of-concept that the transformer could learn a generalizable
+solution.</p>
+<p>The construction is somewhat tricky to write down, so we’ll make
+several (non-essential) simplifying assumptions:</p>
+<ol type="1">
+<li>We only care about finding the distance between vertex 1 and vertex
+2.</li>
+<li>We are given a nice embedding of the graph.</li>
+</ol>
+<p>It will be obvious from the construction how to get rid of
+assumptions (1) and (2).</p>
+<p>The embedding of the graph that we will work with here will be
+different from the embeddings that we use in the rest of the post. In
+particular, in this section only the tokens will correspond to vertices.
+In the remainder of the post tokens will correspond to edges.</p>
+<p>More specifically, we will create the following embeddings for our
+input sequence:</p>
+<table>
+<colgroup>
+<col style="width: 25%" />
+<col style="width: 20%" />
+<col style="width: 22%" />
+<col style="width: 13%" />
+<col style="width: 18%" />
+</colgroup>
+<thead>
+<tr class="header">
+<th></th>
+<th>Answer token</th>
+<th><span class="math inline">v_{1}</span> token</th>
+<th><span class="math inline">\ldots</span></th>
+<th><span class="math inline">v_n</span> token</th>
+</tr>
+</thead>
+<tbody>
+<tr class="odd">
+<td>Answer flag</td>
+<td>1</td>
+<td>0</td>
+<td></td>
+<td>0</td>
+</tr>
+<tr class="even">
+<td>Neighbors</td>
+<td></td>
+<td>01101</td>
+<td></td>
+<td>10010</td>
+</tr>
+<tr class="odd">
+<td>Reach</td>
+<td></td>
+<td>01101</td>
+<td></td>
+<td>10010</td>
+</tr>
+<tr class="even">
+<td>Out</td>
+<td>00000</td>
+<td></td>
+<td></td>
+<td></td>
+</tr>
+<tr class="odd">
+<td>Self</td>
+<td></td>
+<td>10000</td>
+<td></td>
+<td>00001</td>
+</tr>
+<tr class="even">
+<td>Not-answer flag</td>
+<td>0</td>
+<td>1</td>
+<td></td>
+<td>1</td>
+</tr>
+</tbody>
+</table>
+<p>As you can see, the hidden dimension of the model will be <span
+class="math inline">4n+2</span>. The hidden dimension is broken up into
+5 sections, some of which are only used in the answer token, and some of
+which are only used in the vertex tokens. Here are what the different
+parts of the hidden dimension represent:</p>
+<ol type="1">
+<li>Flags: booleans indicating whether a vertex is the answer token or
+not. If desired, you can think of this as part of a “positional
+embedding”. The flags will not be updated as we move through the
+transformer.</li>
+<li>Neighbors: a boolean vector (of dimension <span
+class="math inline">n</span>) describing the neighbors of each vertex.
+(We just put zeros here for the answer token). This also won’t update
+through the course of running the transformer.</li>
+<li>Reach: a boolean vector representing the vertices that each vertex
+can reach. This will update over the course of a forward pass of the
+transformer. Specifically, at layer <span class="math inline">i</span>
+of the transformer we will extend the reach of each vertex from vertices
+which are distance <span class="math inline">i+1</span> away, to
+vertices at distance <span class="math inline">i+2</span> away.</li>
+<li>Out: this starts as zero for all tokens. Once vertex <span
+class="math inline">1</span> can reach vertex <span
+class="math inline">2</span> (i.e., we are in a layer deep enough that
+the reach vector for <span class="math inline">v_{1}</span>’s token
+indicates that it can reach vertex <span class="math inline">2</span>)
+we’ll start writing <span class="math inline">1</span> on the answer
+token’s out part of its embedding.</li>
+<li>Self: a 1-hot encoding specifying which vertex each token
+corresponds to. You should think of this as a positional embedding
+(along with the flags).</li>
+</ol>
+<p>The transformer will consist of <span class="math inline">n</span>
+layers, each of which has <span class="math inline">n+1</span> heads.
+We’ve included the code for the handcoded transformer in the <a
+href="#appendix">Appendix</a> – possibly it’s helpful to look at the
+code in conjunction with the written explanation below.</p>
+<p>In layer <span class="math inline">i</span>, our goal is to extend
+the reach of each vertex <span class="math inline">v</span> from the set
+of vertices at distance <span class="math inline">i+1</span> from <span
+class="math inline">v</span> to the set of vertices at distance <span
+class="math inline">i+2</span> from <span class="math inline">v</span>.
+The <span class="math inline">j</span>-th head will perform the
+following operation:</p>
+<pre><code>In head j at layer i:
+Vertex v&#39;s query vector will have high dot-product with vertex w&#39;s key vector
+iff vertex v can reach vertex j, and vertex w is neighbors with vertex j
+
+The value that vertex w would give to vertex v is simply a 1-hot vector, with a 1 in the w-th position</code></pre>
+<p>After running the heads, each vertex’s embedding is updated to
+integrate the new vertices that it can reach into its embedding. Note
+that we can think of embeddings as being updated because transformers
+have a residual stream of embeddings that we make modifications to.</p>
+<p>Finally, the last head will be in charge of noticing whether vertex 1
+has reached vertex 2 yet, and if so recording this fact. For the full
+details, see <a href="#appendix">appendix</a>.</p>
+<h3 id="robustness">Robustness</h3>
+<p>The field of Singular Learning Theory (SLT; see Liam Carroll’s
+Master’s thesis “Phase Transitions in Neural Networks” for an
+introduction) aims to understand model training and loss-landscape
+geometry. In efforts to better understand the loss landscape of the
+shortest paths loss function according to the tokens used in our hand
+coded implementation of the shortest paths transformers, we decided to
+start at a good setting of the parameters, and then perturb the weights,
+and see if the model can subsequently achieve low loss. The intuition
+for why this is a good approach at measuring “how attractive of a loss
+basin” we have is that this experiment is similar to the Local Learning
+Coefficient from SLT. (see Lau, Edmund, Zach Furman, George Wang, Daniel
+Murfet, and Susan Wei. “The Local Learning Coefficient: A
+Singularity-Aware Complexity Measure”). We found that, perturbing the
+weights led to high loss, but gradient descent was able to recover low
+loss, indicating that the solution is somewhat “findable” by gradient
+descent.</p>
+<p>TODO: Switch to fancy mpld3 plots <img src="perturb.png"
+alt="perturb.png" /> <img src="perturb-loss.png"
+alt="perturb-loss.png" /></p>
+<h2 id="training">Training</h2>
<h3 id="data">Data</h3>
<p>We’ll represent an <span class="math inline">n</span> vertex, <span
class="math inline">m</span> edge unweighted, undirected graph as
@@ -425,189 +642,76 @@ the input <span class="math inline">[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0,
<div style="text-align:center">
<p><img src="img/finetune.svg" /></p>
</div>
-<p>We have three datasets for each step.</p>
+<p>We have a dataset for each step.</p>
<ol type="1">
<li><strong>Pre-train data</strong>: For each <span
class="math inline">n \in [3,15]</span>, we generated an equal number of
graphs on <span class="math inline">n</span> vertices, with <span
class="math inline">t = 2</span>. Each graph was created by choosing
-<span class="math inline">n</span> random edges.</li>
-<li><strong>Fine-tune data</strong>: For each <span
+<span class="math inline">n</span> random edges. To balance the dataset,
+we removed some graphs with shorter paths to obtain the distribution
+shown below.</li>
+</ol>
+<p>TODO: Add axes titles, switch to bar plot</p>
+
+
+<style>
+
+</style>
+
+<div id="fig_el1200021399424053838726786165297"></div>
+<script>
+function mpld3_load_lib(url, callback){
+ var s = document.createElement('script');
+ s.src = url;
+ s.async = true;
+ s.onreadystatechange = s.onload = callback;
+ s.onerror = function(){console.warn("failed to load library " + url);};
+ document.getElementsByTagName("head")[0].appendChild(s);
+}
+
+if(typeof(mpld3) !== "undefined" && mpld3._mpld3IsLoaded){
+ // already loaded: just create the figure
+ !function(mpld3){
+
+ mpld3.draw_figure("fig_el1200021399424053838726786165297", {"width": 640.0, "height": 480.0, "axes": [{"bbox": [0.125, 0.10999999999999999, 0.775, 0.77], "xlim": [0.29999999999999993, 15.7], "ylim": [0.0, 7533.75], "xdomain": [0.29999999999999993, 15.7], "ydomain": [0.0, 7533.75], "xscale": "linear", "yscale": "linear", "axes": [{"position": "bottom", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}, {"position": "left", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}], "axesbg": "#FFFFFF", "axesbgalpha": null, "zoomable": true, "id": "el120002139942361034016", "lines": [], "paths": [{"data": "data01", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126032", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data02", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361031376", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data03", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942360791184", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data04", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361121472", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data05", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361120896", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data06", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361123824", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data07", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361123392", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data08", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361127904", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data09", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129296", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data10", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129056", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data11", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361128624", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361127376", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 2, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126752", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 3, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126320", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data13", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129872", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}], "markers": [], "texts": [], "collections": [], "images": [], "sharex": [], "sharey": []}], "data": {"data01": [[1.0, 0.0], [1.933333396911621, 0.0], [1.933333396911621, 7175.0], [1.0, 7175.0]], "data02": [[1.933333396911621, 0.0], [2.866666793823242, 0.0], [2.866666793823242, 6576.0], [1.933333396911621, 6576.0]], "data03": [[2.866666793823242, 0.0], [3.799999952316284, 0.0], [3.799999952316284, 4624.0], [2.866666793823242, 4624.0]], "data04": [[3.799999952316284, 0.0], [4.733333587646484, 0.0], [4.733333587646484, 3021.0], [3.799999952316284, 3021.0]], "data05": [[4.733333587646484, 0.0], [5.666666507720947, 0.0], [5.666666507720947, 2275.0], [4.733333587646484, 2275.0]], "data06": [[5.666666507720947, 0.0], [6.599999904632568, 0.0], [6.599999904632568, 1653.0], [5.666666507720947, 1653.0]], "data07": [[6.599999904632568, 0.0], [7.5333333015441895, 0.0], [7.5333333015441895, 492.0], [6.599999904632568, 492.0]], "data08": [[7.5333333015441895, 0.0], [8.466667175292969, 0.0], [8.466667175292969, 119.0], [7.5333333015441895, 119.0]], "data09": [[8.466667175292969, 0.0], [9.399999618530273, 0.0], [9.399999618530273, 28.0], [8.466667175292969, 28.0]], "data10": [[9.399999618530273, 0.0], [10.333333015441895, 0.0], [10.333333015441895, 5.0], [9.399999618530273, 5.0]], "data11": [[10.333333015441895, 0.0], [11.266666412353516, 0.0], [11.266666412353516, 1.0], [10.333333015441895, 1.0]], "data12": [[11.266666412353516, 0.0, 12.199999809265137, 13.133333206176758], [12.199999809265137, 0.0, 13.133333206176758, 14.066666603088379], [12.199999809265137, 0.0, 13.133333206176758, 14.066666603088379], [11.266666412353516, 0.0, 12.199999809265137, 13.133333206176758]], "data13": [[14.066666603088379, 0.0], [15.0, 0.0], [15.0, 6799.0], [14.066666603088379, 6799.0]]}, "id": "el120002139942405383872", "plugins": [{"type": "reset"}, {"type": "zoom", "button": true, "enabled": false}, {"type": "boxzoom", "button": true, "enabled": false}]});
+ }(mpld3);
+}else if(typeof define === "function" && define.amd){
+ // require.js is available: use it to load d3/mpld3
+ require.config({paths: {d3: "https://d3js.org/d3.v5"}});
+ require(["d3"], function(d3){
+ window.d3 = d3;
+ mpld3_load_lib("https://mpld3.github.io/js/mpld3.v0.5.10.js", function(){
+
+ mpld3.draw_figure("fig_el1200021399424053838726786165297", {"width": 640.0, "height": 480.0, "axes": [{"bbox": [0.125, 0.10999999999999999, 0.775, 0.77], "xlim": [0.29999999999999993, 15.7], "ylim": [0.0, 7533.75], "xdomain": [0.29999999999999993, 15.7], "ydomain": [0.0, 7533.75], "xscale": "linear", "yscale": "linear", "axes": [{"position": "bottom", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}, {"position": "left", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}], "axesbg": "#FFFFFF", "axesbgalpha": null, "zoomable": true, "id": "el120002139942361034016", "lines": [], "paths": [{"data": "data01", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126032", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data02", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361031376", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data03", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942360791184", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data04", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361121472", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data05", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361120896", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data06", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361123824", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data07", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361123392", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data08", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361127904", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data09", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129296", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data10", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129056", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data11", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361128624", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361127376", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 2, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126752", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 3, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126320", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data13", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129872", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}], "markers": [], "texts": [], "collections": [], "images": [], "sharex": [], "sharey": []}], "data": {"data01": [[1.0, 0.0], [1.933333396911621, 0.0], [1.933333396911621, 7175.0], [1.0, 7175.0]], "data02": [[1.933333396911621, 0.0], [2.866666793823242, 0.0], [2.866666793823242, 6576.0], [1.933333396911621, 6576.0]], "data03": [[2.866666793823242, 0.0], [3.799999952316284, 0.0], [3.799999952316284, 4624.0], [2.866666793823242, 4624.0]], "data04": [[3.799999952316284, 0.0], [4.733333587646484, 0.0], [4.733333587646484, 3021.0], [3.799999952316284, 3021.0]], "data05": [[4.733333587646484, 0.0], [5.666666507720947, 0.0], [5.666666507720947, 2275.0], [4.733333587646484, 2275.0]], "data06": [[5.666666507720947, 0.0], [6.599999904632568, 0.0], [6.599999904632568, 1653.0], [5.666666507720947, 1653.0]], "data07": [[6.599999904632568, 0.0], [7.5333333015441895, 0.0], [7.5333333015441895, 492.0], [6.599999904632568, 492.0]], "data08": [[7.5333333015441895, 0.0], [8.466667175292969, 0.0], [8.466667175292969, 119.0], [7.5333333015441895, 119.0]], "data09": [[8.466667175292969, 0.0], [9.399999618530273, 0.0], [9.399999618530273, 28.0], [8.466667175292969, 28.0]], "data10": [[9.399999618530273, 0.0], [10.333333015441895, 0.0], [10.333333015441895, 5.0], [9.399999618530273, 5.0]], "data11": [[10.333333015441895, 0.0], [11.266666412353516, 0.0], [11.266666412353516, 1.0], [10.333333015441895, 1.0]], "data12": [[11.266666412353516, 0.0, 12.199999809265137, 13.133333206176758], [12.199999809265137, 0.0, 13.133333206176758, 14.066666603088379], [12.199999809265137, 0.0, 13.133333206176758, 14.066666603088379], [11.266666412353516, 0.0, 12.199999809265137, 13.133333206176758]], "data13": [[14.066666603088379, 0.0], [15.0, 0.0], [15.0, 6799.0], [14.066666603088379, 6799.0]]}, "id": "el120002139942405383872", "plugins": [{"type": "reset"}, {"type": "zoom", "button": true, "enabled": false}, {"type": "boxzoom", "button": true, "enabled": false}]});
+ });
+ });
+}else{
+ // require.js not available: dynamically load d3 & mpld3
+ mpld3_load_lib("https://d3js.org/d3.v5.js", function(){
+ mpld3_load_lib("https://mpld3.github.io/js/mpld3.v0.5.10.js", function(){
+
+ mpld3.draw_figure("fig_el1200021399424053838726786165297", {"width": 640.0, "height": 480.0, "axes": [{"bbox": [0.125, 0.10999999999999999, 0.775, 0.77], "xlim": [0.29999999999999993, 15.7], "ylim": [0.0, 7533.75], "xdomain": [0.29999999999999993, 15.7], "ydomain": [0.0, 7533.75], "xscale": "linear", "yscale": "linear", "axes": [{"position": "bottom", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}, {"position": "left", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}], "axesbg": "#FFFFFF", "axesbgalpha": null, "zoomable": true, "id": "el120002139942361034016", "lines": [], "paths": [{"data": "data01", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126032", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data02", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361031376", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data03", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942360791184", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data04", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361121472", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data05", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361120896", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data06", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361123824", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data07", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361123392", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data08", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361127904", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data09", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129296", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data10", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129056", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data11", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361128624", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361127376", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 2, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126752", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data12", "xindex": 3, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361126320", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}, {"data": "data13", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el120002139942361129872", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}], "markers": [], "texts": [], "collections": [], "images": [], "sharex": [], "sharey": []}], "data": {"data01": [[1.0, 0.0], [1.933333396911621, 0.0], [1.933333396911621, 7175.0], [1.0, 7175.0]], "data02": [[1.933333396911621, 0.0], [2.866666793823242, 0.0], [2.866666793823242, 6576.0], [1.933333396911621, 6576.0]], "data03": [[2.866666793823242, 0.0], [3.799999952316284, 0.0], [3.799999952316284, 4624.0], [2.866666793823242, 4624.0]], "data04": [[3.799999952316284, 0.0], [4.733333587646484, 0.0], [4.733333587646484, 3021.0], [3.799999952316284, 3021.0]], "data05": [[4.733333587646484, 0.0], [5.666666507720947, 0.0], [5.666666507720947, 2275.0], [4.733333587646484, 2275.0]], "data06": [[5.666666507720947, 0.0], [6.599999904632568, 0.0], [6.599999904632568, 1653.0], [5.666666507720947, 1653.0]], "data07": [[6.599999904632568, 0.0], [7.5333333015441895, 0.0], [7.5333333015441895, 492.0], [6.599999904632568, 492.0]], "data08": [[7.5333333015441895, 0.0], [8.466667175292969, 0.0], [8.466667175292969, 119.0], [7.5333333015441895, 119.0]], "data09": [[8.466667175292969, 0.0], [9.399999618530273, 0.0], [9.399999618530273, 28.0], [8.466667175292969, 28.0]], "data10": [[9.399999618530273, 0.0], [10.333333015441895, 0.0], [10.333333015441895, 5.0], [9.399999618530273, 5.0]], "data11": [[10.333333015441895, 0.0], [11.266666412353516, 0.0], [11.266666412353516, 1.0], [10.333333015441895, 1.0]], "data12": [[11.266666412353516, 0.0, 12.199999809265137, 13.133333206176758], [12.199999809265137, 0.0, 13.133333206176758, 14.066666603088379], [12.199999809265137, 0.0, 13.133333206176758, 14.066666603088379], [11.266666412353516, 0.0, 12.199999809265137, 13.133333206176758]], "data13": [[14.066666603088379, 0.0], [15.0, 0.0], [15.0, 6799.0], [14.066666603088379, 6799.0]]}, "id": "el120002139942405383872", "plugins": [{"type": "reset"}, {"type": "zoom", "button": true, "enabled": false}, {"type": "boxzoom", "button": true, "enabled": false}]});
+ })
+ });
+}
+</script>
+<ol start="2" type="1">
+<li><p><strong>Fine-tune data</strong>: For each <span
class="math inline">n \in [3,7]</span>, we generated an equal number of
graphs on <span class="math inline">n</span> vertices each with a random
<span class="math inline">t</span> on the shortest path from <span
class="math inline">1</span> to <span class="math inline">2</span>.
Again, each graph was created by choosing <span
-class="math inline">n</span> random edges.</li>
-<li><strong>Generalization test data</strong>: The same as the fine-tune
-data, except we sample <span class="math inline">n \in [3,15]</span> and
-<span class="math inline">t \leq 7</span>.</li>
+class="math inline">n</span> random edges.</p></li>
+<li><p><strong>Generalization test data</strong>: The same as the
+fine-tune data, except we sample <span class="math inline">n \in
+[3,15]</span> and <span class="math inline">t \leq 7</span>.</p></li>
</ol>
<p>We wrote some Python code to generate the data during the training
loop, but Python is excruciatingly slow and data generation wasted a lot
of training time. Our solution was to pre-generate the data before
training using a multithreaded version of our Python code.</p>
-<h2
-id="complicated-explicit-transformer-formula-for-shortest-paths">Complicated
-explicit transformer formula for shortest paths</h2>
-<p>TODO: Kevin or Alek</p>
-<div class="sourceCode" id="cb1"><pre class="sourceCode py"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span>
-<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">16</span></span>
-<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
-<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>AVGDEG <span class="op">=</span> <span class="dv">2</span></span>
-<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>SEQLEN <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
-<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>HIDDENDIM <span class="op">=</span> <span class="dv">4</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">2</span></span>
-<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Start indices for different sections of the input data</span></span>
-<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>START_REACH <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
-<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>START_OUT <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
-<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>START_SELF <span class="op">=</span> <span class="dv">3</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
-<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>SRC_FLAG_IDX <span class="op">=</span> START_SELF</span>
-<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>ANS_FLAG_IDX <span class="op">=</span> <span class="dv">0</span></span>
-<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>NOTANS_FLAG_IDX <span class="op">=</span> <span class="op">-</span><span class="dv">1</span></span>
-<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>BIG <span class="op">=</span> <span class="dv">20</span></span>
-<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>SUPABIG <span class="op">=</span> <span class="dv">100</span></span>
-<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>MED <span class="op">=</span> <span class="dv">10</span></span>
-<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>CURSE <span class="op">=</span> <span class="dv">5</span></span>
-<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SillyTransformer(nn.Module):</span>
-<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, device):</span>
-<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a> <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
-<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.device <span class="op">=</span> device</span>
-<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a> <span class="cf">with</span> torch.no_grad():</span>
-<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a> <span class="co"># Initialize weight parameters with specific configurations</span></span>
-<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostKs <span class="op">=</span> nn.ParameterList()</span>
-<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostQs <span class="op">=</span> nn.ParameterList()</span>
-<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostVs <span class="op">=</span> nn.ParameterList()</span>
-<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> head <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">1</span>):</span>
-<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a> Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
-<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a> Q[<span class="dv">0</span>, START_REACH <span class="op">-</span> <span class="dv">1</span> <span class="op">+</span> head] <span class="op">=</span> SUPABIG</span>
-<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a> Q[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
-<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a> K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
-<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, head] <span class="op">=</span> <span class="dv">1</span></span>
-<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> BIG</span>
-<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a> V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span>
-<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
-<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a> V[i, START_SELF <span class="op">+</span> i] <span class="op">=</span> <span class="dv">1</span></span>
-<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostKs.append(K)</span>
-<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostQs.append(Q)</span>
-<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostVs.append(V)</span>
-<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdKs <span class="op">=</span> nn.ParameterList()</span>
-<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdQs <span class="op">=</span> nn.ParameterList()</span>
-<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdVs <span class="op">=</span> nn.ParameterList()</span>
-<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
-<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a> K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
-<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG</span>
-<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, SRC_FLAG_IDX] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span>
-<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>SUPABIG</span>
-<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">2</span>] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span>
-<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG<span class="op">-</span>SUPABIG</span>
-<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">2</span>, ANS_FLAG_IDX] <span class="op">=</span> MED</span>
-<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a> Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
-<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a> Q[:, ANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
-<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a> V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span>
-<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a> V[layer, SRC_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
-<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdKs.append(K)</span>
-<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdQs.append(Q)</span>
-<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdVs.append(V)</span>
-<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> forward(<span class="va">self</span>, src):</span>
-<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
-<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a> allKs <span class="op">=</span> [<span class="va">self</span>.weirdKs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostKs]</span>
-<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a> allQs <span class="op">=</span> [<span class="va">self</span>.weirdQs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostQs]</span>
-<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a> allVs <span class="op">=</span> [<span class="va">self</span>.weirdVs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostVs]</span>
-<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a> head_outputs <span class="op">=</span> []</span>
-<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a> </span>
-<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> (K, Q, V) <span class="kw">in</span> <span class="bu">zip</span>(allKs, allQs, allVs):</span>
-<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a> ksrc <span class="op">=</span> torch.matmul(src, K.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
-<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a> qsrc <span class="op">=</span> torch.matmul(src, Q.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
-<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a> vsrc <span class="op">=</span> torch.matmul(src, V.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
-<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a> scores <span class="op">=</span> torch.matmul(qsrc, ksrc.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
-<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a> attention_weights <span class="op">=</span> torch.softmax(scores, dim<span class="op">=-</span><span class="dv">1</span>)</span>
-<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a> head_output <span class="op">=</span> torch.matmul(attention_weights, vsrc)</span>
-<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a> head_outputs.append(head_output)</span>
-<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a> new_reaches <span class="op">=</span> <span class="bu">sum</span>(head_outputs[<span class="dv">1</span>:])</span>
-<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a> BSZ <span class="op">=</span> new_reaches.shape[<span class="dv">0</span>]</span>
-<span id="cb1-89"><a href="#cb1-89" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-90"><a href="#cb1-90" aria-hidden="true" tabindex="-1"></a> nodelta_nbrs <span class="op">=</span> torch.zeros((BSZ, SEQLEN, NVTXS <span class="op">+</span> <span class="dv">1</span>), device<span class="op">=</span><span class="va">self</span>.device)</span>
-<span id="cb1-91"><a href="#cb1-91" aria-hidden="true" tabindex="-1"></a> morepadlol <span class="op">=</span> torch.zeros((BSZ, SEQLEN, <span class="dv">1</span> <span class="op">+</span> NVTXS), device<span class="op">=</span><span class="va">self</span>.device)</span>
-<span id="cb1-92"><a href="#cb1-92" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-93"><a href="#cb1-93" aria-hidden="true" tabindex="-1"></a> src <span class="op">=</span> src <span class="op">+</span> torch.cat((nodelta_nbrs, new_reaches, head_outputs[<span class="dv">0</span>], morepadlol), dim<span class="op">=</span><span class="dv">2</span>)</span>
-<span id="cb1-94"><a href="#cb1-94" aria-hidden="true" tabindex="-1"></a> src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> torch.sigmoid(src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">*</span> CURSE) <span class="op">-</span> <span class="dv">1</span></span>
-<span id="cb1-95"><a href="#cb1-95" aria-hidden="true" tabindex="-1"></a></span>
-<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a> canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span>
-<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a> final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span>
-<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> final_output</span></code></pre></div>
-<p>After much deliberation, we decided the next step for us was to
-customize a transformer, writing it ourselves. We observed that we
-wished for the transformer to do similar work as a BFS. As a result, we
-decided to work with the following transformer, for a graph with <span
-class="math inline">n</span> vertices <span class="math inline">v_1,
-v_2, \cdots, v_n</span>:</p>
-<span class="math display">\begin{array}{|c|c|c|c|c|c}
-\text{ANS} &amp; v_{1} &amp; v_{2} &amp; \cdots &amp; v_{n} &amp; \\
-\hline
-1 &amp; 0 &amp; 0 &amp; \cdots &amp; 0 &amp; \text{ANS}\\ \hline
-\text{ANS} &amp; \text{NBR}_{1} &amp; \text{NBR}_{2} &amp; \cdots &amp;
-\text{NBR}_{n} &amp; \text{NBR}\\ \hline
-\text{ANS} &amp; \text{REACH}_{1} &amp; \text{REACH}_{2} &amp; \cdots
-&amp; \text{REACH}_{n} &amp; \text{REACH}\\ \hline
-\text{ANS} &amp; \text{SELF}_{1} &amp; \text{SELF}_{2} &amp; \cdots
-&amp; \text{SELF}_{n} &amp; \text{SELF}\\ \hline
-V_{\text{OUT}} &amp; NULL&amp; NULL&amp; NULL&amp; NULL&amp;
-\text{OUT}\\ \hline
-0 &amp; 1 &amp; 1 &amp; \cdots &amp;1 &amp; \text{NOT}\\ \hline
-\end{array}</span>
-<p>Specifically, we see that <span
-class="math inline">\text{NBR}_{i}</span> is a <span
-class="math inline">n \times 1</span> vector detailing which of the
-vertices are neighboring vertex <span class="math inline">v_i</span>, so
-the <span class="math inline">j</span>th element of <span
-class="math inline">v_i</span> is <span class="math inline">1</span> if
-<span class="math inline">v_i</span> and <span
-class="math inline">v_j</span> are neighboring vertices, and <span
-class="math inline">0</span> otherwise. Additionally, <span
-class="math inline">\text{SELF}_{i}</span> is just the <span
-class="math inline">n \times 1</span> vector with the <span
-class="math inline">i</span>th element <span
-class="math inline">1</span> and all other elements <span
-class="math inline">0</span> (e.g. the one-hot encoding of the vector).
-Now, at every step, the <span class="math inline">\text{REACH}_k</span>
-vector for all <span class="math inline">k</span> is updated based on
-the previous <span class="math inline">\text{REACH}_k</span> vector and
-<span class="math inline">\text{NBR}_{k}</span> (since all entries that
-are <span class="math inline">1</span> in <span
-class="math inline">\text{REACH}_k\text{NBR}_{k}^T</span> must be
-updated in the manner such that if the <span class="math inline">(i,
-j)</span>th element of <span
-class="math inline">\text{REACH}_k\text{NBR}_{k}^T</span> is <span
-class="math inline">1</span>, then <span
-class="math inline">\text{REACH}_i</span>’s <span
-class="math inline">j</span>th column is set to <span
-class="math inline">1</span>. This is equivalent to adding <span
-class="math inline">\text{REACH}_k</span> to each integer <span
-class="math inline">i</span> where <span
-class="math inline">\text{NBR}_{k}</span>’s <span
-class="math inline">i</span>th entry is nonzero.</p>
-<p>This iterates through all the vertices, and at the end, we may see
-what run we are on to update <span
-class="math inline">V_{\text{OUT}}</span>.</p>
-<h2 id="perturbing-the-weights">Perturbing the Weights</h2>
-<p>SLT folks like to think about geometry of loss landscape CITE So we
-did an experiment where we mess with the weights.</p>
-<p>Findings: XXX</p>
-<h2 id="our-model">Our Model</h2>
<h3 id="architecture">Architecture</h3>
<p>We used a standard transformer architecture. To ensure that it can in
theory learn BFS, we ensured that the number of layers in our
@@ -625,92 +729,315 @@ embedding of <span class="math inline">b</span>. For the target vertex
class="math inline">(t,TARGET)</span> in the same way.</p>
<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
<p><img src="img/embeddings.svg" /></p>
-<h3 id="training">Training</h3>
-<p>To match the BFS transformer as closely as possible, we used a model
-dimension of <span class="math inline">64</span>, <span
-class="math inline">11</span> layers, and <span
-class="math inline">2</span> heads per layer, for a total of 550433
-parameters. In 32-bit float precision, that corresponds to around <span
-class="math inline">1.76\cdot10^6</span> bits. The number of possible
-graphs on 15 vertices generated using our procedure is approximately</p>
-<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} =
-1.59\cdot10^{18}.</span></p>
-<p>This is because there are <span
+<h3 id="specifications">Specifications</h3>
+<p>For our training run, we used the following specifications:</p>
+<table>
+<thead>
+<tr class="header">
+<th>Hyperparameter</th>
+<th>Value</th>
+</tr>
+</thead>
+<tbody>
+<tr class="odd">
+<td>Model dimension</td>
+<td><span class="math inline">64</span></td>
+</tr>
+<tr class="even">
+<td>Layers</td>
+<td><span class="math inline">11</span></td>
+</tr>
+<tr class="odd">
+<td>Heads</td>
+<td><span class="math inline">2</span></td>
+</tr>
+<tr class="even">
+<td>Epochs</td>
+<td><span class="math inline">100</span></td>
+</tr>
+<tr class="odd">
+<td>Batch size</td>
+<td><span class="math inline">2^{15}</span></td>
+</tr>
+<tr class="even">
+<td>Dropout</td>
+<td><span class="math inline">0</span></td>
+</tr>
+<tr class="odd">
+<td>Training examples</td>
+<td><span class="math inline">10^8</span></td>
+</tr>
+<tr class="even">
+<td>Parameters</td>
+<td><span class="math inline">550433</span></td>
+</tr>
+<tr class="odd">
+<td>Learning rate</td>
+<td><span class="math inline">5\cdot10^{-4}</span></td>
+</tr>
+<tr class="even">
+<td>Loss</td>
+<td>MSE</td>
+</tr>
+<tr class="odd">
+<td>Optimizer</td>
+<td>Adam</td>
+</tr>
+</tbody>
+</table>
+<p>The number of bits required to store the model parameters in float32
+is around <span class="math inline">1.76\cdot10^6</span>. The number of
+possible graphs on 15 vertices generated using our procedure is
+approximately <span class="math display">\frac{\binom{15}{2}^{15}}{15!}
+\approx 1.59\cdot10^{18}.</span> This is because there are <span
class="math inline">\binom{15}{2}</span> choices for each of the 15
edges and we don’t care about the order of the edges. This is only an
approximation because some edges might be duplicated. Each graph has an
answer between 1 and 15 which requires around 4 bits, so memorizing all
-the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} =
-6.36\cdot10^{18}</span> bits, which is <span
-class="math inline">3.61\cdot10^{12}</span> times larger than our model
-size.</p>
-<p>To train the model, we used MSE loss, the Adam optimizer, a learning
-rate of <span class="math inline">3\cdot10^{-4}</span>, and a batch size
-of <span class="math inline">2^{15}</span> for one billion randomly
-generated graphs. A training run takes roughly eight hours to run on a
-Radeon 7900 XTX graphics card. Our final MSE loss was <span
-class="math inline">0.000555</span>.</p>
-<p>TODO: use https://mpld3.github.io/index.html to make interactive
-plots</p>
-<p><img src="training-loss.png" /></p>
-<p><img src="training-2d-histogram.png" /></p>
-<p>One pattern we consistently noticed during training is that the model
-often gets stuck and plateaus for many epochs before rapidly decreasing.
-For instance, this happened between epochs 100 and 300 in the graph
-above:</p>
-<p><img src="grokking.png" /></p>
-<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p>
-<p>TODO: cite Neel Nanda grokking modular addition</p>
-<p>TODO: CRAZY!!! training curves for 1, 2, 3 length paths</p>
+the answers requires requires <span
+class="math inline">4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}</span>
+bits, which is <span class="math inline">3.61\cdot10^{12}</span> times
+larger than our model size. This implies that in order to get really low
+loss, our model needs to do something other than brute memorization.</p>
+<p>A single training run takes roughly eight hours to run on a Radeon
+7900 XTX graphics card.</p>
<p>One pitfall we encountered during training is that we initially used
bfloat16 to save VRAM, but our huge batch size caused loss-of-precision
problems and made training very difficult. It took us two weeks to debug
this until we found that switching to float32 improved training
significantly.</p>
-<h2 id="fine-tuning-results">Fine tuning results</h2>
+<h3 id="training-results">Training Results</h3>
+<p>Here is our training loss curve:</p>
+<p>$include<code>raw="html"</code> plots/train-loss.html</p>
+<p>Here is a “confusion matrix”-like plot, showing for each true label,
+the probability distribution over classifications that the model
+gives.</p>
+<p>$include<code>raw="html"</code> plots/train-hist.html</p>
+<p>One pattern we consistently noticed during training is that the model
+often gets stuck and plateaus for many epochs before rapidly decreasing.
+For instance, this happened between epochs 100 and 300 in the graph
+above:</p>
+<p>$include<code>raw="html"</code> plots/grokking.html</p>
+<p>Such <strong>phase change</strong> phenomena are ubiquitous in deep
+learning and for simple algorithmic tasks we can sometimes interpret
+them <a href="#fn2" class="footnote-ref" id="fnref2"
+role="doc-noteref"><sup>2</sup></a>.</p>
+<p><strong>TODO ANT: training curves for 1, 2, 3 length paths. and add
+some text about why this is an interesting finding.</strong></p>
+<p>$include<code>raw="html"</code> plots/len-loss.html</p>
+<h2 id="generalization">Generalization</h2>
+<h3 id="fine-tuning-results">Fine Tuning Results</h3>
<p>After receiving our initial results, we fine-tuned with a learning
rate of 1e-5, also with MSE and the same batch size. Our final results
are shown in the images below.</p>
-<p><img src="fine-tuning-loss.png" /></p>
-<p><img src="fine-tuning-2d-histogram.png" /></p>
-<p><img src="test-2d-histogram.png" /></p>
-<p>TODO: get new graphs</p>
-<p>It’s pretty good!!!</p>
-<p>Can only generalize to target vertices from 2 to 7 since 8 through 15
-didn’t appear in the fine-tune data</p>
-<p>but this still means it</p>
+<p>$include<code>raw="html"</code> plots/tune-loss.html</p>
+<p>$include<code>raw="html"</code> plots/tune-hist.html</p>
+<p>$include<code>raw="html"</code>
+plots/test-onpath-largetarget.html</p>
+<p>$include<code>raw="html"</code>
+plots/test-onpath-smalltarget.html</p>
+<p>$include<code>raw="html"</code> plots/test-small-any.html</p>
+<p>$include<code>raw="html"</code> plots/test-large-any.html</p>
<h2 id="conclusion">Conclusion</h2>
-<p>however, a machine learning approach may do better in time through
-parallelism, although at the expense of using much more memory.
-<strong>TODO: ALEK: this is BS. If you want a parallel algorithm for
-BFS, here’s one
-https://en.wikipedia.org/wiki/Parallel_single-source_shortest_path_algorithm</strong></p>
-<p>just do bfs lol</p>
-<p><strong>Future Work</strong> There are a couple of other things that
-we could try to learn shortest paths better and maybe see more
-generalization. - Chain of thought - Train model to output a path, not
-just the distance. Give it partial points for outputting anything that
-is a legitimate path (i.e., consists of real edges) and more points for
-getting the distance correct.</p>
-<h2 class="unnumbered" id="references">References</h2>
-<div id="refs" class="references csl-bib-body hanging-indent"
-data-entry-spacing="0" role="list">
-<div id="ref-DBLP:journals/corr/abs-2102-09544" class="csl-entry"
-role="listitem">
-Cappart, Quentin, Didier Chételat, Elias B. Khalil, Andrea Lodi,
-Christopher Morris, and Petar Velickovic. 2021. <span>“Combinatorial
-Optimization and Reasoning with Graph Neural Networks.”</span>
-<em>CoRR</em> abs/2102.09544. <a
-href="https://arxiv.org/abs/2102.09544">https://arxiv.org/abs/2102.09544</a>.
-</div>
-<div id="ref-10.1109/TPAMI.2023.3256421" class="csl-entry"
-role="listitem">
-Tutsoy, Onder. 2023. <span>“Graph Theory Based Large-Scale Machine
-Learning with Multi-Dimensional Constrained Optimization Approaches for
-Exact Epidemiological Modeling of Pandemic Diseases.”</span> <em>IEEE
-Trans. Pattern Anal. Mach. Intell.</em> 45 (8): 9836–45. <a
-href="https://doi.org/10.1109/TPAMI.2023.3256421">https://doi.org/10.1109/TPAMI.2023.3256421</a>.
-</div>
-</div>
+<p>In this post we’ve investigated solving the shortest paths problem
+with a transformer. We’ve shown that this problem has a simple solution
+that can be implemented by a transformer. We’ve shown that a transformer
+can learn to solve this problem. We’ve shown that when the transformer
+learns to solve this problem it also implicitly builds good internal
+representations of the input data that allow for efficient fine tuning
+to find shortest paths between other vertices. Finally, we’ve seen that
+some off-distribution generalization does occur, and some does not! The
+main conceptual take-away from our work is that it’s hard to predict
+when models will and won’t generalize.</p>
+<h2 id="appendix">Appendix</h2>
+<div class="sourceCode" id="cb2"><pre
+class="sourceCode python"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
+<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
+<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
+<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> deque</span>
+<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span>
+<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">8</span></span>
+<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
+<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>AVGDEG <span class="op">=</span> <span class="dv">2</span></span>
+<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>SEQLEN <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
+<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>HIDDENDIM <span class="op">=</span> <span class="dv">4</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">2</span></span>
+<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>START_REACH <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
+<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>START_OUT <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
+<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>START_SELF <span class="op">=</span> <span class="dv">3</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
+<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>SRC_FLAG_IDX <span class="op">=</span> START_SELF</span>
+<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>ANS_FLAG_IDX <span class="op">=</span> <span class="dv">0</span></span>
+<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>NOTANS_FLAG_IDX <span class="op">=</span> <span class="op">-</span><span class="dv">1</span></span>
+<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Determine device</span></span>
+<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">&quot;cuda&quot;</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">&quot;cpu&quot;</span>)</span>
+<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> random_graph(device):</span>
+<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a> <span class="co">&quot;&quot;&quot;Generate a random graph tensor.&quot;&quot;&quot;</span></span>
+<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a> data <span class="op">=</span> torch.zeros((SEQLEN, HIDDENDIM), device<span class="op">=</span>device)</span>
+<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a> </span>
+<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a> <span class="co"># Mark self vertices</span></span>
+<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">1</span>):</span>
+<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a> data[i, START_SELF <span class="op">-</span> <span class="dv">1</span> <span class="op">+</span> i] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a> <span class="co"># Create adjacency list</span></span>
+<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a> adj_list <span class="op">=</span> [<span class="bu">set</span>() <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(SEQLEN)]</span>
+<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a> indices <span class="op">=</span> [random.randint(<span class="dv">1</span>, NVTXS) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(AVGDEG <span class="op">*</span> NVTXS)]</span>
+<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a> </span>
+<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(indices), <span class="dv">2</span>):</span>
+<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a> u <span class="op">=</span> indices[i]</span>
+<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a> v <span class="op">=</span> indices[i <span class="op">+</span> <span class="dv">1</span>]</span>
+<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a> <span class="cf">if</span> u <span class="op">!=</span> v:</span>
+<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a> <span class="co"># Bidirectional connections</span></span>
+<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a> data[v, u] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a> data[u, v] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a> data[v, NVTXS <span class="op">+</span> u] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a> data[u, NVTXS <span class="op">+</span> v] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a> adj_list[u].add(v)</span>
+<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a> adj_list[v].add(u)</span>
+<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a> <span class="co"># Set flags</span></span>
+<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a> data[<span class="dv">0</span>, ANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a> data[<span class="dv">1</span>:, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a> data[<span class="dv">0</span>, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> data, adj_list</span>
+<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> SSSP(G):</span>
+<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a> <span class="co">&quot;&quot;&quot;Single Source Shortest Path algorithm.&quot;&quot;&quot;</span></span>
+<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a> dist <span class="op">=</span> [MAXDIST <span class="cf">for</span> _ <span class="kw">in</span> G]</span>
+<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a> dist[<span class="dv">1</span>] <span class="op">=</span> <span class="dv">0</span></span>
+<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a> frontier <span class="op">=</span> deque([<span class="dv">1</span>])</span>
+<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a> <span class="cf">while</span> frontier:</span>
+<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a> vtx <span class="op">=</span> frontier.popleft()</span>
+<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> x <span class="kw">in</span> G[vtx]:</span>
+<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a> <span class="cf">if</span> dist[x] <span class="op">==</span> MAXDIST:</span>
+<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a> dist[x] <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> dist[vtx]</span>
+<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a> frontier.append(x)</span>
+<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a> <span class="cf">if</span> x <span class="op">==</span> <span class="dv">2</span>:</span>
+<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> dist[<span class="dv">2</span>]</span>
+<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> MAXDIST</span>
+<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> mkbatch(size):</span>
+<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a> <span class="co">&quot;&quot;&quot;Create a batch of graph data.&quot;&quot;&quot;</span></span>
+<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a> graphs <span class="op">=</span> []</span>
+<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a> distances <span class="op">=</span> []</span>
+<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size):</span>
+<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a> data, adj_list <span class="op">=</span> random_graph(device)</span>
+<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a> dist <span class="op">=</span> SSSP(adj_list)</span>
+<span id="cb2-75"><a href="#cb2-75" aria-hidden="true" tabindex="-1"></a> graphs.append(data)</span>
+<span id="cb2-76"><a href="#cb2-76" aria-hidden="true" tabindex="-1"></a> distances.append(dist)</span>
+<span id="cb2-77"><a href="#cb2-77" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-78"><a href="#cb2-78" aria-hidden="true" tabindex="-1"></a> data <span class="op">=</span> torch.stack(graphs)</span>
+<span id="cb2-79"><a href="#cb2-79" aria-hidden="true" tabindex="-1"></a> labels <span class="op">=</span> torch.tensor(distances, dtype<span class="op">=</span>torch.float32, device<span class="op">=</span>device)</span>
+<span id="cb2-80"><a href="#cb2-80" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> data, labels</span>
+<span id="cb2-81"><a href="#cb2-81" aria-hidden="true" tabindex="-1"></a> </span>
+<span id="cb2-82"><a href="#cb2-82" aria-hidden="true" tabindex="-1"></a>BIG,SUPABIG,MED,CURSE <span class="op">=</span> <span class="dv">12</span>,<span class="dv">30</span>,<span class="dv">7</span>,<span class="dv">5</span></span>
+<span id="cb2-83"><a href="#cb2-83" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-84"><a href="#cb2-84" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SillyTransformer(nn.Module):</span>
+<span id="cb2-85"><a href="#cb2-85" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, device):</span>
+<span id="cb2-86"><a href="#cb2-86" aria-hidden="true" tabindex="-1"></a> <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
+<span id="cb2-87"><a href="#cb2-87" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.device <span class="op">=</span> device</span>
+<span id="cb2-88"><a href="#cb2-88" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-89"><a href="#cb2-89" aria-hidden="true" tabindex="-1"></a> <span class="cf">with</span> torch.no_grad():</span>
+<span id="cb2-90"><a href="#cb2-90" aria-hidden="true" tabindex="-1"></a> <span class="co"># Initialize weight parameters with specific configurations</span></span>
+<span id="cb2-91"><a href="#cb2-91" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostKs <span class="op">=</span> nn.ParameterList()</span>
+<span id="cb2-92"><a href="#cb2-92" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostQs <span class="op">=</span> nn.ParameterList()</span>
+<span id="cb2-93"><a href="#cb2-93" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostVs <span class="op">=</span> nn.ParameterList()</span>
+<span id="cb2-94"><a href="#cb2-94" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> head <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">1</span>):</span>
+<span id="cb2-95"><a href="#cb2-95" aria-hidden="true" tabindex="-1"></a> Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
+<span id="cb2-96"><a href="#cb2-96" aria-hidden="true" tabindex="-1"></a> Q[<span class="dv">0</span>, START_REACH <span class="op">-</span> <span class="dv">1</span> <span class="op">+</span> head] <span class="op">=</span> SUPABIG</span>
+<span id="cb2-97"><a href="#cb2-97" aria-hidden="true" tabindex="-1"></a> Q[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-98"><a href="#cb2-98" aria-hidden="true" tabindex="-1"></a> K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
+<span id="cb2-99"><a href="#cb2-99" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, head] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-100"><a href="#cb2-100" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> BIG</span>
+<span id="cb2-101"><a href="#cb2-101" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-102"><a href="#cb2-102" aria-hidden="true" tabindex="-1"></a> V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span>
+<span id="cb2-103"><a href="#cb2-103" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
+<span id="cb2-104"><a href="#cb2-104" aria-hidden="true" tabindex="-1"></a> V[i, START_SELF <span class="op">+</span> i] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-105"><a href="#cb2-105" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-106"><a href="#cb2-106" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostKs.append(K)</span>
+<span id="cb2-107"><a href="#cb2-107" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostQs.append(Q)</span>
+<span id="cb2-108"><a href="#cb2-108" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostVs.append(V)</span>
+<span id="cb2-109"><a href="#cb2-109" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdKs <span class="op">=</span> nn.ParameterList()</span>
+<span id="cb2-110"><a href="#cb2-110" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdQs <span class="op">=</span> nn.ParameterList()</span>
+<span id="cb2-111"><a href="#cb2-111" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdVs <span class="op">=</span> nn.ParameterList()</span>
+<span id="cb2-112"><a href="#cb2-112" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
+<span id="cb2-113"><a href="#cb2-113" aria-hidden="true" tabindex="-1"></a> K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
+<span id="cb2-114"><a href="#cb2-114" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG</span>
+<span id="cb2-115"><a href="#cb2-115" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, SRC_FLAG_IDX] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span>
+<span id="cb2-116"><a href="#cb2-116" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>SUPABIG</span>
+<span id="cb2-117"><a href="#cb2-117" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">2</span>] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span>
+<span id="cb2-118"><a href="#cb2-118" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG<span class="op">-</span>SUPABIG</span>
+<span id="cb2-119"><a href="#cb2-119" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">2</span>, ANS_FLAG_IDX] <span class="op">=</span> MED</span>
+<span id="cb2-120"><a href="#cb2-120" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-121"><a href="#cb2-121" aria-hidden="true" tabindex="-1"></a> Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
+<span id="cb2-122"><a href="#cb2-122" aria-hidden="true" tabindex="-1"></a> Q[:, ANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-123"><a href="#cb2-123" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-124"><a href="#cb2-124" aria-hidden="true" tabindex="-1"></a> V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span>
+<span id="cb2-125"><a href="#cb2-125" aria-hidden="true" tabindex="-1"></a> V[layer, SRC_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
+<span id="cb2-126"><a href="#cb2-126" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-127"><a href="#cb2-127" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdKs.append(K)</span>
+<span id="cb2-128"><a href="#cb2-128" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdQs.append(Q)</span>
+<span id="cb2-129"><a href="#cb2-129" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdVs.append(V)</span>
+<span id="cb2-130"><a href="#cb2-130" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-131"><a href="#cb2-131" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> forward(<span class="va">self</span>, src):</span>
+<span id="cb2-132"><a href="#cb2-132" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
+<span id="cb2-133"><a href="#cb2-133" aria-hidden="true" tabindex="-1"></a> allKs <span class="op">=</span> [<span class="va">self</span>.weirdKs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostKs]</span>
+<span id="cb2-134"><a href="#cb2-134" aria-hidden="true" tabindex="-1"></a> allQs <span class="op">=</span> [<span class="va">self</span>.weirdQs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostQs]</span>
+<span id="cb2-135"><a href="#cb2-135" aria-hidden="true" tabindex="-1"></a> allVs <span class="op">=</span> [<span class="va">self</span>.weirdVs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostVs]</span>
+<span id="cb2-136"><a href="#cb2-136" aria-hidden="true" tabindex="-1"></a> head_outputs <span class="op">=</span> []</span>
+<span id="cb2-137"><a href="#cb2-137" aria-hidden="true" tabindex="-1"></a> </span>
+<span id="cb2-138"><a href="#cb2-138" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> (K, Q, V) <span class="kw">in</span> <span class="bu">zip</span>(allKs, allQs, allVs):</span>
+<span id="cb2-139"><a href="#cb2-139" aria-hidden="true" tabindex="-1"></a> ksrc <span class="op">=</span> torch.matmul(src, K.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
+<span id="cb2-140"><a href="#cb2-140" aria-hidden="true" tabindex="-1"></a> qsrc <span class="op">=</span> torch.matmul(src, Q.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
+<span id="cb2-141"><a href="#cb2-141" aria-hidden="true" tabindex="-1"></a> vsrc <span class="op">=</span> torch.matmul(src, V.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
+<span id="cb2-142"><a href="#cb2-142" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-143"><a href="#cb2-143" aria-hidden="true" tabindex="-1"></a> scores <span class="op">=</span> torch.matmul(qsrc, ksrc.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
+<span id="cb2-144"><a href="#cb2-144" aria-hidden="true" tabindex="-1"></a> attention_weights <span class="op">=</span> torch.softmax(scores, dim<span class="op">=-</span><span class="dv">1</span>)</span>
+<span id="cb2-145"><a href="#cb2-145" aria-hidden="true" tabindex="-1"></a> head_output <span class="op">=</span> torch.matmul(attention_weights, vsrc)</span>
+<span id="cb2-146"><a href="#cb2-146" aria-hidden="true" tabindex="-1"></a> head_outputs.append(head_output)</span>
+<span id="cb2-147"><a href="#cb2-147" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-148"><a href="#cb2-148" aria-hidden="true" tabindex="-1"></a> new_reaches <span class="op">=</span> <span class="bu">sum</span>(head_outputs[<span class="dv">1</span>:])</span>
+<span id="cb2-149"><a href="#cb2-149" aria-hidden="true" tabindex="-1"></a> BSZ <span class="op">=</span> new_reaches.shape[<span class="dv">0</span>]</span>
+<span id="cb2-150"><a href="#cb2-150" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-151"><a href="#cb2-151" aria-hidden="true" tabindex="-1"></a> nodelta_nbrs <span class="op">=</span> torch.zeros((BSZ, SEQLEN, NVTXS <span class="op">+</span> <span class="dv">1</span>), device<span class="op">=</span><span class="va">self</span>.device)</span>
+<span id="cb2-152"><a href="#cb2-152" aria-hidden="true" tabindex="-1"></a> morepadlol <span class="op">=</span> torch.zeros((BSZ, SEQLEN, <span class="dv">1</span> <span class="op">+</span> NVTXS), device<span class="op">=</span><span class="va">self</span>.device)</span>
+<span id="cb2-153"><a href="#cb2-153" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-154"><a href="#cb2-154" aria-hidden="true" tabindex="-1"></a> src <span class="op">=</span> src <span class="op">+</span> torch.cat((nodelta_nbrs, new_reaches, head_outputs[<span class="dv">0</span>], morepadlol), dim<span class="op">=</span><span class="dv">2</span>)</span>
+<span id="cb2-155"><a href="#cb2-155" aria-hidden="true" tabindex="-1"></a> src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> torch.sigmoid(src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">*</span> CURSE) <span class="op">-</span> <span class="dv">1</span></span>
+<span id="cb2-156"><a href="#cb2-156" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-157"><a href="#cb2-157" aria-hidden="true" tabindex="-1"></a> canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span>
+<span id="cb2-158"><a href="#cb2-158" aria-hidden="true" tabindex="-1"></a> final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span>
+<span id="cb2-159"><a href="#cb2-159" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> final_output</span>
+<span id="cb2-160"><a href="#cb2-160" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-161"><a href="#cb2-161" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SillyTransformer(device).to(device)</span>
+<span id="cb2-162"><a href="#cb2-162" aria-hidden="true" tabindex="-1"></a>params <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())</span>
+<span id="cb2-163"><a href="#cb2-163" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f&quot;Total number of parameters: </span><span class="sc">{</span>params<span class="sc">}</span><span class="ss">&quot;</span>)</span>
+<span id="cb2-164"><a href="#cb2-164" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-165"><a href="#cb2-165" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> destroy_rand_weights(model):</span>
+<span id="cb2-166"><a href="#cb2-166" aria-hidden="true" tabindex="-1"></a> weight_lists <span class="op">=</span> [model.mostKs, model.mostQs, model.mostVs, </span>
+<span id="cb2-167"><a href="#cb2-167" aria-hidden="true" tabindex="-1"></a> model.weirdKs, model.weirdQs, model.weirdVs]</span>
+<span id="cb2-168"><a href="#cb2-168" aria-hidden="true" tabindex="-1"></a> random_list <span class="op">=</span> random.choice(weight_lists)</span>
+<span id="cb2-169"><a href="#cb2-169" aria-hidden="true" tabindex="-1"></a> random_matrix <span class="op">=</span> random.choice(random_list)</span>
+<span id="cb2-170"><a href="#cb2-170" aria-hidden="true" tabindex="-1"></a> random_matrix.data <span class="op">=</span> torch.randn_like(random_matrix)</span>
+<span id="cb2-171"><a href="#cb2-171" aria-hidden="true" tabindex="-1"></a></span>
+<span id="cb2-172"><a href="#cb2-172" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-6</span>)</span>
+<span id="cb2-173"><a href="#cb2-173" aria-hidden="true" tabindex="-1"></a>loss_fn <span class="op">=</span> nn.MSELoss()</span></code></pre></div>
+<h2 id="references">References</h2>
+<section id="footnotes" class="footnotes footnotes-end-of-document"
+role="doc-endnotes">
+<hr />
+<ol>
+<li id="fn1"><p>Lin, Hilton, and Evans. “TruthfulQA: Measuring How
+Models Mimic Human Falsehoods”<a href="#fnref1" class="footnote-back"
+role="doc-backlink">↩︎</a></p></li>
+<li id="fn2"><p>Nanda, Neel, Lawrence Chan, Tom Lieberum, Jess Smith,
+and Jacob Steinhardt. “Progress measures for grokking via mechanistic
+interpretability.” arXiv preprint arXiv:2301.05217 (2023).
+https://arxiv.org/abs/2301.05217.<a href="#fnref2" class="footnote-back"
+role="doc-backlink">↩︎</a></p></li>
+</ol>
+</section>
</body>
</html>