diff options
Diffstat (limited to 'index.html')
-rw-r--r-- | index.html | 466 |
1 files changed, 326 insertions, 140 deletions
@@ -165,7 +165,70 @@ margin: 0 0.8em 0.2em -1.6em; vertical-align: middle; } - .display.math{display: block; text-align: center; margin: 0.5rem auto;} + /* CSS for syntax highlighting */ + pre > code.sourceCode { white-space: pre; position: relative; } + pre > code.sourceCode > span { line-height: 1.25; } + pre > code.sourceCode > span:empty { height: 1.2em; } + .sourceCode { overflow: visible; } + code.sourceCode > span { color: inherit; text-decoration: inherit; } + div.sourceCode { margin: 1em 0; } + pre.sourceCode { margin: 0; } + @media screen { + div.sourceCode { overflow: auto; } + } + @media print { + pre > code.sourceCode { white-space: pre-wrap; } + pre > code.sourceCode > span { display: inline-block; text-indent: -5em; padding-left: 5em; } + } + pre.numberSource code + { counter-reset: source-line 0; } + pre.numberSource code > span + { position: relative; left: -4em; counter-increment: source-line; } + pre.numberSource code > span > a:first-child::before + { content: counter(source-line); + position: relative; left: -1em; text-align: right; vertical-align: baseline; + border: none; display: inline-block; + -webkit-touch-callout: none; -webkit-user-select: none; + -khtml-user-select: none; -moz-user-select: none; + -ms-user-select: none; user-select: none; + padding: 0 4px; width: 4em; + color: #aaaaaa; + } + pre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa; padding-left: 4px; } + div.sourceCode + { } + @media screen { + pre > code.sourceCode > span > a:first-child::before { text-decoration: underline; } + } + code span.al { color: #ff0000; font-weight: bold; } /* Alert */ + code span.an { color: #60a0b0; font-weight: bold; font-style: italic; } /* Annotation */ + code span.at { color: #7d9029; } /* Attribute */ + code span.bn { color: #40a070; } /* BaseN */ + code span.bu { color: #008000; } /* BuiltIn */ + code span.cf { color: #007020; font-weight: bold; } /* ControlFlow */ + code span.ch { color: #4070a0; } /* Char */ + code span.cn { color: #880000; } /* Constant */ + code span.co { color: #60a0b0; font-style: italic; } /* Comment */ + code span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } /* CommentVar */ + code span.do { color: #ba2121; font-style: italic; } /* Documentation */ + code span.dt { color: #902000; } /* DataType */ + code span.dv { color: #40a070; } /* DecVal */ + code span.er { color: #ff0000; font-weight: bold; } /* Error */ + code span.ex { } /* Extension */ + code span.fl { color: #40a070; } /* Float */ + code span.fu { color: #06287e; } /* Function */ + code span.im { color: #008000; font-weight: bold; } /* Import */ + code span.in { color: #60a0b0; font-weight: bold; font-style: italic; } /* Information */ + code span.kw { color: #007020; font-weight: bold; } /* Keyword */ + code span.op { color: #666666; } /* Operator */ + code span.ot { color: #007020; } /* Other */ + code span.pp { color: #bc7a00; } /* Preprocessor */ + code span.sc { color: #4070a0; } /* SpecialChar */ + code span.ss { color: #bb6688; } /* SpecialString */ + code span.st { color: #4070a0; } /* String */ + code span.va { color: #19177c; } /* Variable */ + code span.vs { color: #4070a0; } /* VerbatimString */ + code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */ /* CSS for citations */ div.csl-bib-body { } div.csl-entry { @@ -187,6 +250,24 @@ div.csl-indent { margin-left: 2em; } </style> + <script defer="" + src="https://cdn.jsdelivr.net/npm/katex@0.15.1/dist/katex.min.js"></script> + <script>document.addEventListener("DOMContentLoaded", function () { + var mathElements = document.getElementsByClassName("math"); + var macros = []; + for (var i = 0; i < mathElements.length; i++) { + var texText = mathElements[i].firstChild; + if (mathElements[i].tagName == "SPAN") { + katex.render(texText.data, mathElements[i], { + displayMode: mathElements[i].classList.contains('display'), + throwOnError: false, + macros: macros, + fleqn: false + }); +}}}); + </script> + <link rel="stylesheet" + href="https://cdn.jsdelivr.net/npm/katex@0.15.1/dist/katex.min.css" /> </head> <body> <header id="title-block-header"> @@ -194,17 +275,11 @@ Generalization of Transformers</h1> </header> <!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 --> -<!-- <div style="display: flex; justify-content: space-between;"> - -<div style="flex: 1; margin: 5px; padding: 10px; border: 1px solid #ddd; text-align: center;"> --> <div style="text-align:center"> <p>Anthony Wang, Alek Westover, Kevin Zhao</p> <p>{xy,alekw,kevinmz}@mit.edu</p> </div> -<h2 id="abstract">Abstract</h2> -<p>TODO</p> -<h2 id="introduction">Introduction</h2> -<h3 id="overview">Overview</h3> +<h2 id="goals">Goals</h2> <p>Recently, LLMs have been developing very fast, and with that comes the concern of aligning the models to output true and productive statements. One common approach for ensuring this is to have a human in @@ -242,31 +317,28 @@ transform hand drawn cats into images of cats might be able to handle a generalizing truthfully is simple, thus promoted by “Occam’s Razor”, and aim to investigate that with this project.</p> <p>COMMENT FROM KEVIN – synthesize from intorduction</p> -<h2 id="task">Task</h2> +<h3 id="task">Task</h3> <p>We will use a synthetic task to test our hypothesis that models will generalize truthfully off-distribution. The synthetic task is computing the distance between various vertices in an input graph. Our experiment will have three parts:</p> <ol type="1"> <li>Pre-train a transformer to predict the distance between two fixed -vertices <span class="math inline"><em>s</em>, <em>t</em></span> on -graphs with <span class="math inline"><em>n</em> ∈ [8, 32)</span> -vertices.</li> +vertices <span class="math inline">s,t</span> on graphs with <span +class="math inline">n\in [8, 32)</span> vertices.</li> <li>Fine-tune a transformer to predict the distances between <span -class="math inline"><em>s</em>, <em>t</em>′</span> for any <span -class="math inline"><em>t</em>′</span> which is on the shortest path -from <span class="math inline"><em>s</em></span> to <span -class="math inline"><em>t</em></span>, but only do fine-tuning on graphs -with <span class="math inline"><em>n</em> ∈ [8, 16)</span> -vertices.</li> +class="math inline">s,t'</span> for any <span +class="math inline">t'</span> which is on the shortest path from +<span class="math inline">s</span> to <span +class="math inline">t</span>, but only do fine-tuning on graphs with +<span class="math inline">n\in [8,16)</span> vertices.</li> <li>Test whether the transformer can accurately predict the distances -between <span class="math inline"><em>s</em>, <em>t</em>′</span> for any -<span class="math inline"><em>t</em>′</span> on the shortest path from -<span class="math inline"><em>s</em></span> to <span -class="math inline"><em>t</em></span> for graphs with <span -class="math inline"><em>n</em> ∈ [16, 32)</span> vertices.</li> +between <span class="math inline">s,t'</span> for any <span +class="math inline">t'</span> on the shortest path from <span +class="math inline">s</span> to <span class="math inline">t</span> for +graphs with <span class="math inline">n\in [16,32)</span> vertices.</li> </ol> -<h2 id="related-work">Related Work</h2> +<h3 id="related-work">Related Work</h3> <p>COMMENT FROM ALEK – please remove all mentions of graph neural networks – that is BS: there is no actual reason why you’d ever use a Neural network to solve shortest paths, the point of choosing a @@ -295,111 +367,67 @@ href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy 2023</a>)</span>. We understand from his paper how GNN optimization may also be useful in researching novel diseases.</p></li> </ul> -<h3 id="theory">Theory</h3> +<h2 id="methods">Methods</h2> <h3 id="algorithm-for-shortest-paths">Algorithm for Shortest Paths</h3> <p>The standard algorithm to find the shortest path in a graph between a -source numbered as <span class="math inline"><em>u</em></span> and sink -numbered as <span class="math inline"><em>v</em></span> is -<strong>breadth-first search (BFS)</strong>. The BFS algorithm maintains -a mapping of visited vertices to their distances with respect to <span -class="math inline"><em>u</em></span>, and each run of the algorithm -goes through all the vertices newly visited in the previous run, and for -each vertex, visits any of its unvisited neighbors. The algorithm -terminates once either <span class="math inline"><em>v</em></span> is -visited or the set of newly visited vertices in a single run is -empty.</p> +source numbered as <span class="math inline">u</span> and sink numbered +as <span class="math inline">v</span> is <strong>breadth-first search +(BFS)</strong>. The BFS algorithm maintains a mapping of visited +vertices to their distances with respect to <span +class="math inline">u</span>, and each run of the algorithm goes through +all the vertices newly visited in the previous run, and for each vertex, +visits any of its unvisited neighbors. The algorithm terminates once +either <span class="math inline">v</span> is visited or the set of newly +visited vertices in a single run is empty.</p> <p>We will use this algorithm to verify the accuracy of our machine -learning approach. Given <span class="math inline"><em>V</em></span> -vertices and <span class="math inline"><em>E</em></span> edges, the -runtime of this algorithm is thus <span -class="math inline"><em>O</em>(<em>V</em> + <em>E</em>)</span>; however, -a machine learning approach may do better in time through parallelism, -although at the expense of using much more memory.</p> -<h3 id="potential-mathematical-approaches-to-shortest-paths">Potential -Mathematical Approaches to Shortest Paths</h3> -<p>Another way one can think of the shortest path of a graph is using a -<em>matrix</em> to record which vertices are connected. Given vertices -numbered <span class="math inline">1</span> to <span -class="math inline"><em>V</em></span>, we denote the <strong>adjacency -matrix</strong> <span class="math inline"><strong>M</strong></span> of -dimensions <span class="math inline"><em>V</em> × <em>V</em></span> as -the matrix with element <span -class="math inline"><strong>M</strong><sub><em>i</em>, <em>j</em></sub> = 1</span> -if vertices <span class="math inline"><em>i</em></span> and <span -class="math inline"><em>j</em></span> are connected by an edge and <span -class="math inline"><strong>M</strong><sub><em>i</em>, <em>j</em></sub> = 0</span> -if they are not. Now, we note that (1) For all <span -class="math inline"><em>k</em></span>, <span -class="math inline">(<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup> = 0</span> -if and only if there exists no path from the vertex numbered <span -class="math inline"><em>i</em></span> to the vertex numbered <span -class="math inline"><em>j</em></span> that is distance <span -class="math inline"><em>k</em></span> or less due to Markov matrix -processes. As a result, if the distance between vertices numbered <span -class="math inline"><em>i</em></span> and <span -class="math inline"><em>j</em></span> is <span -class="math inline"><em>d</em></span>, then <span -class="math inline">min((<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup>, 1) = 1</span> -if <span class="math inline"><em>k</em> ≥ <em>d</em></span> and <span -class="math inline">min((<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup>, 1) = 0</span> -if <span class="math inline"><em>k</em> < <em>d</em></span>.</p> -<p>With this information, because the distance between any two vertices -is at most <span class="math inline"><em>V</em> − 1</span> in a graph -with <span class="math inline"><em>V</em></span> vertices, we note that -the <em>distance</em> matrix turns out to be simply <span -class="math display"><strong>D</strong> = <strong>1</strong><sub><em>V</em> × <em>V</em></sub> ⋅ <em>V</em> − <em>Σ</em><sub><em>i</em> = 0</sub><sup><em>V</em> − 1</sup>min((<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup>, 1).</span> -The runtime to compute this is <span -class="math inline"><em>O</em>(<em>V</em>)</span>, although it will take -more space to compute all powers of <span -class="math inline"><strong>M</strong></span>.</p> -<h2 id="our-machine-learning-approach">Our Machine Learning -Approach</h2> +learning approach. Given <span class="math inline">V</span> vertices and +<span class="math inline">E</span> edges, the runtime of this algorithm +is thus <span class="math inline">O(V + E)</span>; however, a machine +learning approach may do better in time through parallelism, although at +the expense of using much more memory.</p> <h3 id="data">Data</h3> -<p>We will represent an <span class="math inline"><em>n</em></span> -vertex, <span class="math inline"><em>m</em></span> edge unweighted, -undirected graph as sequence of the endpoints of the <span -class="math inline"><em>m</em></span> edges, so <span -class="math inline">[<em>a</em><sub>1</sub>, <em>b</em><sub>1</sub>, <em>a</em><sub>2</sub>, <em>b</em><sub>2</sub>, …, <em>a</em><sub><em>m</em></sub>, <em>b</em><sub><em>m</em></sub>]</span> -represents a graph with the edges <span -class="math inline">{(<em>a</em><sub><em>i</em></sub>, <em>b</em><sub><em>i</em></sub>)}</span> -for <span class="math inline">1 ≤ <em>i</em> ≤ <em>m</em></span>. We -will pad all sequences to be the same length using the padding token -0.</p> +<p>We will represent an <span class="math inline">n</span> vertex, <span +class="math inline">m</span> edge unweighted, undirected graph as +sequence of the endpoints of the <span class="math inline">m</span> +edges, so <span +class="math inline">[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]</span> represents a +graph with the edges <span class="math inline">\{(a_i,b_i)\}</span> for +<span class="math inline">1 \leq i \leq m</span>. We will pad all +sequences to be the same length using the padding token 0.</p> <p>The full input to our model will additionally add the target vertex after the padding tokens. The model is tasked with predicting the length of the shortest path between vertex 1 and the target vertex <span -class="math inline"><em>t</em></span>. If no such path exists, we define -the length to be <span class="math inline"><em>n</em> + 1</span> which -represents infinity. For example, an input-output pair for our model -could look like <span -class="math inline">[1, 3, 3, 2, 0, 0, 0, 0, 2]</span> and <span -class="math inline">2</span> respectively.</p> +class="math inline">t</span>. If no such path exists, we define the +length to be <span class="math inline">n+1</span> which represents +infinity. For example, an input-output pair for our model could look +like <span class="math inline">[1, 3, 3, 2, 0, 0, 0, 0, 2]</span> and +<span class="math inline">2</span> respectively.</p> <p>We have three separate datasets.</p> <ul> <li><strong>Pre-train data</strong>: For each <span -class="math inline"><em>n</em> ∈ [8, 32)</span>, we will generate -several graphs on <span class="math inline"><em>n</em></span> vertices. -We generate these graphs by inserting <span -class="math inline">2<em>n</em></span> random edges into the graph. We -always set the target vertex to be <span class="math inline">2</span> -here.</li> +class="math inline">n \in [8,32)</span>, we will generate several graphs +on <span class="math inline">n</span> vertices. We generate these graphs +by inserting <span class="math inline">2n</span> random edges into the +graph. We always set the target vertex to be <span +class="math inline">2</span> here.</li> <li><strong>Fine-tune data</strong>: For each <span -class="math inline"><em>n</em> ∈ [8, 16)</span>, we will generate -several graphs on <span class="math inline"><em>n</em></span> vertices. -We generate these graphs by inserting <span -class="math inline">2<em>n</em></span> random edges into the graph. We -select the target vertex to be a random vertex on the shortest path from -<span class="math inline">1</span> to <span +class="math inline">n \in [8,16)</span>, we will generate several graphs +on <span class="math inline">n</span> vertices. We generate these graphs +by inserting <span class="math inline">2n</span> random edges into the +graph. We select the target vertex to be a random vertex on the shortest +path from <span class="math inline">1</span> to <span class="math inline">2</span>.</li> <li><strong>Generalization testing data</strong>: The same as the -fine-tune data, except we sample <span -class="math inline"><em>n</em> ∈ [16, 32)</span> instead.</li> +fine-tune data, except we sample <span class="math inline">n \in +[16,32)</span> instead.</li> </ul> -<p>As a side note, we are also curious whether the transformer learns to -generalize to different distributions of graphs, such as denser graphs -or graphs with different properties. Time permitting, we will also -investigate this.</p> +<p>We wrote some Python code to generate the data during the training +loop, but Python is slow and the data generation wasted a lot of time +during training. To get around this, we pre-generated the data before +training and made our Python code multithreaded to speed it up.</p> <h3 id="architecture">Architecture</h3> +<p>TODO: honestly not much to say here since it’s a pretty typical +arch</p> <p>We plan to use a standard transformer architecture. We will ensure that the number of layers in our transformer is at least the diameter of the graph. By doing this, we ensure that there is an extremely simple @@ -413,39 +441,197 @@ model should be computing these other distances as intermediate values in its computation to find the distance to vertex <span class="math inline">2</span>.</p> <h3 id="embeddings">Embeddings</h3> -<p>TODO: fix this</p> -<p>In order to facilitate performing this task with limited -computational resources, we plan to use custom-made positional encodings -that tell the model extra information about the structure of the -problem, rather than the traditional sine/cosine positional encodings. -(TODO: THIS IS OUTDATED) Specifically, our positional encodings are -<span -class="math inline"><em>v</em><sub>1</sub>, <em>v</em><sub>1</sub>, <em>v</em><sub>2</sub>, <em>v</em><sub>2</sub>, …, <em>v</em><sub><em>m</em></sub>, <em>v</em><sub><em>m</em></sub>, <em>v</em><sub><em>m</em> + 1</sub></span> -where each <span -class="math inline"><em>v</em><sub><em>i</em></sub></span> is a random -vector so each <span -class="math inline"><em>v</em><sub><em>i</em></sub>, <em>v</em><sub><em>j</em></sub></span> -pair is nearly orthogonal with high probability. We will concatenate -these with the token encodings rather than adding them. This should let -the model easily have large attention scores between vertices -corresponding to a single edge.</p> -<h3 id="explicit-transformer-formula-for-shortest-paths">Explicit -transformer formula for shortest paths</h3> -<h2 id="results">Results</h2> -<h3 id="initial-results">Initial Results</h3> -<p>We used a model dimension of 64 64, four layers, and two heads per -layer. We used MSE loss, the Adam optimizer, a learning rate of 8e-4, -and a batch size of 131,072 for 8000 unique randomly generated batches. -Our final MSE loss was 0.35546875.</p> +<p>Since the order of the edges in the input does not matter, we did not +use positional encodings. Each edge <span +class="math inline">(u,v)</span> where <span class="math inline">u < +v</span> is embedded to a dimension of <span +class="math inline">d</span> where the first <span +class="math inline">\frac{d}{2}</span> elements are the learned +embedding of <span class="math inline">u</span> and the last <span +class="math inline">\frac{d}{2}</span> elements are the learned +embedding of <span class="math inline">v</span>. For the target vertex +<span class="math inline">t</span>, we also embedded to dimension <span +class="math inline">d</span>, where the first <span +class="math inline">\frac{d}{2}</span> elements are the learned +embedding of <span class="math inline">t</span> and the last <span +class="math inline">\frac{d}{2}</span> are a learned embedding of a +special token.</p> +<h2 id="training">Training</h2> +<p>For our model, we used a model dimension of 64, four layers, and two +heads per layer, for a total of 200545 parameters in bfloat16 which +corresponds to around 3.2e6 bits. The number of possible graphs on 15 +vertices generated using our procedure is approximately</p> +<p><span class="math display">\frac{\binom{15}{2}^{15}}{15!} = +1.59\cdot10^{18}.</span></p> +<p>This is because there are <span +class="math inline">\binom{15}{2}</span> choices for each of the 15 +edges and we don’t care about the order of the edges. This is only an +approximation because some edges might be duplicated. Each graph has an +answer between 1 and 15 which requires around 4 bits, so memorizing all +the answers requires <span class="math inline">4\cdot1.59\cdot10^{18} = +6.36\cdot10^{18}</span>, which is <span +class="math inline">2\cdot10^{12}</span> times larger than our model +size.</p> +<p>We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a +batch size of 131072 for 8000 unique randomly generated batches. Our +final MSE loss was approximately 0.3555.</p> <p><img src="training-loss.png" /></p> <p><img src="training-2d-histogram.png" /></p> -<h3 id="fine-tuning">Fine Tuning</h3> +<p>One pattern we consistently noticed during training is that the model +often gets stuck and plateaus for many epochs before rapidly decreasing. +For instance, this happened between epochs 100 and 300 in the graph +above:</p> +<p><img src="grokking.png" /></p> +<p>“grokking” hypothesis: it’s memorizing all length 2 paths?</p> +<p>TODO: training curves for 1, 2, 3 length paths</p> +<h3 +id="potential-mathematical-approaches-to-shortest-paths-delete-this">Potential +Mathematical Approaches to Shortest Paths? Delete this?</h3> +<p>Another way one can think of the shortest path of a graph is using a +<em>matrix</em> to record which vertices are connected. Given vertices +numbered <span class="math inline">1</span> to <span +class="math inline">V</span>, we denote the <strong>adjacency +matrix</strong> <span class="math inline">\textbf{M}</span> of +dimensions <span class="math inline">V \times V</span> as the matrix +with element <span class="math inline">\textbf{M}_{i, j} = 1</span> if +vertices <span class="math inline">i</span> and <span +class="math inline">j</span> are connected by an edge and <span +class="math inline">\textbf{M}_{i, j} = 0</span> if they are not. Now, +we note that (1) For all <span class="math inline">k</span>, <span +class="math inline">(\textbf{M}+I)^k_{i, j} = 0</span> if and only if +there exists no path from the vertex numbered <span +class="math inline">i</span> to the vertex numbered <span +class="math inline">j</span> that is distance <span +class="math inline">k</span> or less due to Markov matrix processes. As +a result, if the distance between vertices numbered <span +class="math inline">i</span> and <span class="math inline">j</span> is +<span class="math inline">d</span>, then <span +class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = +1</span> if <span class="math inline">k \ge d</span> and <span +class="math inline">\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = +0</span> if <span class="math inline">k < d</span>.</p> +<p>With this information, because the distance between any two vertices +is at most <span class="math inline">V-1</span> in a graph with <span +class="math inline">V</span> vertices, we note that the +<em>distance</em> matrix turns out to be simply <span +class="math display">\textbf{D} = \textbf{1}_{V \times V} \cdot V - +\Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j}, +1\right).</span> The runtime to compute this is <span +class="math inline">O(V)</span>, although it will take more space to +compute all powers of <span class="math inline">\textbf{M}</span>.</p> +<h2 id="fine-tuning-results">Fine tuning results</h2> <p>After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results -are shown below.</p> +are shown in the images below.</p> <p><img src="fine-tuning-loss.png" /></p> <p><img src="fine-tuning-2d-histogram.png" /></p> <p><img src="test-2d-histogram.png" /></p> +<p>Memorization? Do some math here to compute how many bits required to +memorize 1, 2, 3</p> +<h2 +id="complicated-explicit-transformer-formula-for-shortest-paths">Complicated +explicit transformer formula for shortest paths</h2> +<div class="sourceCode" id="cb1"><pre class="sourceCode py"><code class="sourceCode python"><span id="cb1-1"><a href="#cb1-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span> +<span id="cb1-2"><a href="#cb1-2" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">16</span></span> +<span id="cb1-3"><a href="#cb1-3" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span> +<span id="cb1-4"><a href="#cb1-4" aria-hidden="true" tabindex="-1"></a>AVGDEG <span class="op">=</span> <span class="dv">2</span></span> +<span id="cb1-5"><a href="#cb1-5" aria-hidden="true" tabindex="-1"></a>SEQLEN <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span> +<span id="cb1-6"><a href="#cb1-6" aria-hidden="true" tabindex="-1"></a>HIDDENDIM <span class="op">=</span> <span class="dv">4</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">2</span></span> +<span id="cb1-7"><a href="#cb1-7" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-8"><a href="#cb1-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Start indices for different sections of the input data</span></span> +<span id="cb1-9"><a href="#cb1-9" aria-hidden="true" tabindex="-1"></a>START_REACH <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span> +<span id="cb1-10"><a href="#cb1-10" aria-hidden="true" tabindex="-1"></a>START_OUT <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span> +<span id="cb1-11"><a href="#cb1-11" aria-hidden="true" tabindex="-1"></a>START_SELF <span class="op">=</span> <span class="dv">3</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span> +<span id="cb1-12"><a href="#cb1-12" aria-hidden="true" tabindex="-1"></a>SRC_FLAG_IDX <span class="op">=</span> START_SELF</span> +<span id="cb1-13"><a href="#cb1-13" aria-hidden="true" tabindex="-1"></a>ANS_FLAG_IDX <span class="op">=</span> <span class="dv">0</span></span> +<span id="cb1-14"><a href="#cb1-14" aria-hidden="true" tabindex="-1"></a>NOTANS_FLAG_IDX <span class="op">=</span> <span class="op">-</span><span class="dv">1</span></span> +<span id="cb1-15"><a href="#cb1-15" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-16"><a href="#cb1-16" aria-hidden="true" tabindex="-1"></a>BIG <span class="op">=</span> <span class="dv">20</span></span> +<span id="cb1-17"><a href="#cb1-17" aria-hidden="true" tabindex="-1"></a>SUPABIG <span class="op">=</span> <span class="dv">100</span></span> +<span id="cb1-18"><a href="#cb1-18" aria-hidden="true" tabindex="-1"></a>MED <span class="op">=</span> <span class="dv">10</span></span> +<span id="cb1-19"><a href="#cb1-19" aria-hidden="true" tabindex="-1"></a>CURSE <span class="op">=</span> <span class="dv">5</span></span> +<span id="cb1-20"><a href="#cb1-20" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-21"><a href="#cb1-21" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SillyTransformer(nn.Module):</span> +<span id="cb1-22"><a href="#cb1-22" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, device):</span> +<span id="cb1-23"><a href="#cb1-23" aria-hidden="true" tabindex="-1"></a> <span class="bu">super</span>().<span class="fu">__init__</span>()</span> +<span id="cb1-24"><a href="#cb1-24" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.device <span class="op">=</span> device</span> +<span id="cb1-25"><a href="#cb1-25" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-26"><a href="#cb1-26" aria-hidden="true" tabindex="-1"></a> <span class="cf">with</span> torch.no_grad():</span> +<span id="cb1-27"><a href="#cb1-27" aria-hidden="true" tabindex="-1"></a> <span class="co"># Initialize weight parameters with specific configurations</span></span> +<span id="cb1-28"><a href="#cb1-28" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostKs <span class="op">=</span> nn.ParameterList()</span> +<span id="cb1-29"><a href="#cb1-29" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostQs <span class="op">=</span> nn.ParameterList()</span> +<span id="cb1-30"><a href="#cb1-30" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostVs <span class="op">=</span> nn.ParameterList()</span> +<span id="cb1-31"><a href="#cb1-31" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> head <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">1</span>):</span> +<span id="cb1-32"><a href="#cb1-32" aria-hidden="true" tabindex="-1"></a> Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span> +<span id="cb1-33"><a href="#cb1-33" aria-hidden="true" tabindex="-1"></a> Q[<span class="dv">0</span>, START_REACH <span class="op">-</span> <span class="dv">1</span> <span class="op">+</span> head] <span class="op">=</span> SUPABIG</span> +<span id="cb1-34"><a href="#cb1-34" aria-hidden="true" tabindex="-1"></a> Q[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span> +<span id="cb1-35"><a href="#cb1-35" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-36"><a href="#cb1-36" aria-hidden="true" tabindex="-1"></a> K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span> +<span id="cb1-37"><a href="#cb1-37" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, head] <span class="op">=</span> <span class="dv">1</span></span> +<span id="cb1-38"><a href="#cb1-38" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> BIG</span> +<span id="cb1-39"><a href="#cb1-39" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-40"><a href="#cb1-40" aria-hidden="true" tabindex="-1"></a> V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span> +<span id="cb1-41"><a href="#cb1-41" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span> +<span id="cb1-42"><a href="#cb1-42" aria-hidden="true" tabindex="-1"></a> V[i, START_SELF <span class="op">+</span> i] <span class="op">=</span> <span class="dv">1</span></span> +<span id="cb1-43"><a href="#cb1-43" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-44"><a href="#cb1-44" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostKs.append(K)</span> +<span id="cb1-45"><a href="#cb1-45" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostQs.append(Q)</span> +<span id="cb1-46"><a href="#cb1-46" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.mostVs.append(V)</span> +<span id="cb1-47"><a href="#cb1-47" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-48"><a href="#cb1-48" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdKs <span class="op">=</span> nn.ParameterList()</span> +<span id="cb1-49"><a href="#cb1-49" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdQs <span class="op">=</span> nn.ParameterList()</span> +<span id="cb1-50"><a href="#cb1-50" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdVs <span class="op">=</span> nn.ParameterList()</span> +<span id="cb1-51"><a href="#cb1-51" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span> +<span id="cb1-52"><a href="#cb1-52" aria-hidden="true" tabindex="-1"></a> K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span> +<span id="cb1-53"><a href="#cb1-53" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG</span> +<span id="cb1-54"><a href="#cb1-54" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">0</span>, SRC_FLAG_IDX] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span> +<span id="cb1-55"><a href="#cb1-55" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>SUPABIG</span> +<span id="cb1-56"><a href="#cb1-56" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">2</span>] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span> +<span id="cb1-57"><a href="#cb1-57" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG<span class="op">-</span>SUPABIG</span> +<span id="cb1-58"><a href="#cb1-58" aria-hidden="true" tabindex="-1"></a> K[<span class="dv">2</span>, ANS_FLAG_IDX] <span class="op">=</span> MED</span> +<span id="cb1-59"><a href="#cb1-59" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-60"><a href="#cb1-60" aria-hidden="true" tabindex="-1"></a> Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span> +<span id="cb1-61"><a href="#cb1-61" aria-hidden="true" tabindex="-1"></a> Q[:, ANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span> +<span id="cb1-62"><a href="#cb1-62" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-63"><a href="#cb1-63" aria-hidden="true" tabindex="-1"></a> V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span> +<span id="cb1-64"><a href="#cb1-64" aria-hidden="true" tabindex="-1"></a> V[layer, SRC_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span> +<span id="cb1-65"><a href="#cb1-65" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-66"><a href="#cb1-66" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdKs.append(K)</span> +<span id="cb1-67"><a href="#cb1-67" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdQs.append(Q)</span> +<span id="cb1-68"><a href="#cb1-68" aria-hidden="true" tabindex="-1"></a> <span class="va">self</span>.weirdVs.append(V)</span> +<span id="cb1-69"><a href="#cb1-69" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-70"><a href="#cb1-70" aria-hidden="true" tabindex="-1"></a> <span class="kw">def</span> forward(<span class="va">self</span>, src):</span> +<span id="cb1-71"><a href="#cb1-71" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span> +<span id="cb1-72"><a href="#cb1-72" aria-hidden="true" tabindex="-1"></a> allKs <span class="op">=</span> [<span class="va">self</span>.weirdKs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostKs]</span> +<span id="cb1-73"><a href="#cb1-73" aria-hidden="true" tabindex="-1"></a> allQs <span class="op">=</span> [<span class="va">self</span>.weirdQs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostQs]</span> +<span id="cb1-74"><a href="#cb1-74" aria-hidden="true" tabindex="-1"></a> allVs <span class="op">=</span> [<span class="va">self</span>.weirdVs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostVs]</span> +<span id="cb1-75"><a href="#cb1-75" aria-hidden="true" tabindex="-1"></a> head_outputs <span class="op">=</span> []</span> +<span id="cb1-76"><a href="#cb1-76" aria-hidden="true" tabindex="-1"></a> </span> +<span id="cb1-77"><a href="#cb1-77" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> (K, Q, V) <span class="kw">in</span> <span class="bu">zip</span>(allKs, allQs, allVs):</span> +<span id="cb1-78"><a href="#cb1-78" aria-hidden="true" tabindex="-1"></a> ksrc <span class="op">=</span> torch.matmul(src, K.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span> +<span id="cb1-79"><a href="#cb1-79" aria-hidden="true" tabindex="-1"></a> qsrc <span class="op">=</span> torch.matmul(src, Q.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span> +<span id="cb1-80"><a href="#cb1-80" aria-hidden="true" tabindex="-1"></a> vsrc <span class="op">=</span> torch.matmul(src, V.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span> +<span id="cb1-81"><a href="#cb1-81" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-82"><a href="#cb1-82" aria-hidden="true" tabindex="-1"></a> scores <span class="op">=</span> torch.matmul(qsrc, ksrc.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span> +<span id="cb1-83"><a href="#cb1-83" aria-hidden="true" tabindex="-1"></a> attention_weights <span class="op">=</span> torch.softmax(scores, dim<span class="op">=-</span><span class="dv">1</span>)</span> +<span id="cb1-84"><a href="#cb1-84" aria-hidden="true" tabindex="-1"></a> head_output <span class="op">=</span> torch.matmul(attention_weights, vsrc)</span> +<span id="cb1-85"><a href="#cb1-85" aria-hidden="true" tabindex="-1"></a> head_outputs.append(head_output)</span> +<span id="cb1-86"><a href="#cb1-86" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-87"><a href="#cb1-87" aria-hidden="true" tabindex="-1"></a> new_reaches <span class="op">=</span> <span class="bu">sum</span>(head_outputs[<span class="dv">1</span>:])</span> +<span id="cb1-88"><a href="#cb1-88" aria-hidden="true" tabindex="-1"></a> BSZ <span class="op">=</span> new_reaches.shape[<span class="dv">0</span>]</span> +<span id="cb1-89"><a href="#cb1-89" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-90"><a href="#cb1-90" aria-hidden="true" tabindex="-1"></a> nodelta_nbrs <span class="op">=</span> torch.zeros((BSZ, SEQLEN, NVTXS <span class="op">+</span> <span class="dv">1</span>), device<span class="op">=</span><span class="va">self</span>.device)</span> +<span id="cb1-91"><a href="#cb1-91" aria-hidden="true" tabindex="-1"></a> morepadlol <span class="op">=</span> torch.zeros((BSZ, SEQLEN, <span class="dv">1</span> <span class="op">+</span> NVTXS), device<span class="op">=</span><span class="va">self</span>.device)</span> +<span id="cb1-92"><a href="#cb1-92" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-93"><a href="#cb1-93" aria-hidden="true" tabindex="-1"></a> src <span class="op">=</span> src <span class="op">+</span> torch.cat((nodelta_nbrs, new_reaches, head_outputs[<span class="dv">0</span>], morepadlol), dim<span class="op">=</span><span class="dv">2</span>)</span> +<span id="cb1-94"><a href="#cb1-94" aria-hidden="true" tabindex="-1"></a> src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> torch.sigmoid(src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">*</span> CURSE) <span class="op">-</span> <span class="dv">1</span></span> +<span id="cb1-95"><a href="#cb1-95" aria-hidden="true" tabindex="-1"></a></span> +<span id="cb1-96"><a href="#cb1-96" aria-hidden="true" tabindex="-1"></a> canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span> +<span id="cb1-97"><a href="#cb1-97" aria-hidden="true" tabindex="-1"></a> final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span> +<span id="cb1-98"><a href="#cb1-98" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> final_output</span></code></pre></div> +<h2 id="alek-perturbed-experiment">Alek perturbed experiment</h2> +<h2 id="conclusion">Conclusion</h2> +<p>just do bfs lol</p> <h2 class="unnumbered" id="references">References</h2> <div id="refs" class="references csl-bib-body hanging-indent" data-entry-spacing="0" role="list"> |