aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--blog.md81
-rw-r--r--fine-tuning-2d-histogram.pngbin0 -> 8484 bytes
-rw-r--r--fine-tuning-loss.pngbin0 -> 15999 bytes
-rw-r--r--index.html166
-rw-r--r--insane-shortest-paths.ipynb263
-rw-r--r--test-2d-histogram.pngbin0 -> 10783 bytes
-rw-r--r--training-2d-histogram.pngbin0 -> 8764 bytes
-rw-r--r--training-loss.pngbin0 -> 16108 bytes
8 files changed, 457 insertions, 53 deletions
diff --git a/blog.md b/blog.md
index 68de189..02a425e 100644
--- a/blog.md
+++ b/blog.md
@@ -1,5 +1,6 @@
---
build: pandoc blog.md --citeproc -s -o index.html
+mkzip: zip project.zip index.html *.png
title: "6.7960 Project: Investigating Off-Distribution Generalization of Transformers"
bibliography: blog.bib
link-citations: true
@@ -7,10 +8,6 @@ link-citations: true
<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 -->
-<!-- <div style="display: flex; justify-content: space-between;">
-
-<div style="flex: 1; margin: 5px; padding: 10px; border: 1px solid #ddd; text-align: center;"> -->
-
<div style="text-align:center">
Anthony Wang, Alek Westover, Kevin Zhao
@@ -19,35 +16,57 @@ Anthony Wang, Alek Westover, Kevin Zhao
## Abstract
-TODO: Probably should move this to the introduction instead of abstract?
+TODO
+
+## Introduction
+
+### Overview
+
+Recently, LLMs have been developing very fast, and with that comes the concern of aligning the models to output true and productive statements. One common approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF), but one drawback to this problem is that humans can be poor judges of truthfulness. As LLMs become more capable, there might not even exist experts that are good judges of whether the model's outputs, such as difficult mathematical proofs, are truthful. So, we'd like to propose a potential solution to this issue via **off-distribution generalization** - applying human-like intuition to solve problems not in the dataset. Paul Christiano [proposed an experiment](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX) about shortest paths in a graph; our project is essentially to implement Christiano's proposed experiment. To the best of our knowledge, although there has been research in applying machine learning for different variations of graph searches [@10.5555/3666122.3666260], no one has done our exact experiment yet.
-Recently, LLMs have been developing very fast, and with that comes the concern of aligning the models to output true and productive statements. One common approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF), but one drawback to this problem is that humans can be poor judges of truthfulness. As LLMs become more capable, there might not even exist experts that are good judges of whether the model's outputs, such as difficult mathematical proofs, are truthful. So, we'd like to propose a potential solution to this issue via **off-distribution generalization** - applying human-like intuition to solve problems not in the dataset. To do so involves Graph Neural Networks (GNNs), which are important in many artificial intelligence applications, including the design of practical features. Paul Christiano [proposed an experiment](https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX) about shortest paths in a graph; our project is essentially to implement Christiano's proposed experiment. To the best of our knowledge, although there has been research in applying machine learning for different variations of graph searches [@10.5555/3666122.3666260], no one has done our exact experiment yet.
+It is generally desirable for LLMs to output true statements. A current approach for ensuring this is to have a human in the loop rewarding the model for true outputs (e.g. RLHF); however, humans can be poor judges of truthfulness. We enjoy many cognitive biases and might employ superficial heuristics when judging truthfulness. A further challenge is that as LLMs develop further, there might not even exist experts that can correctly judge the accuracy and truthfulness of sophisticated outputs such as difficult mathematical proofs.
-# Introduction
+One approach to solving this problem is to reward an LLM for truthful behavior on simple inputs, and then hoping that the LLM generalizes its truthful behavior for more complex inputs where humans cannot provide helpful labels. Deep learning models often perform remarkable feats of off-distribution generalization -- for instance, a model trained to transform hand drawn cats into images of cats might be able to handle a "cat" with three eyes in an intuitive way. We might hope that generalizing truthfully is simple, thus promoted by "Occam's Razor", and aim to investigate that with this project.
-## Overview
+COMMENT FROM KEVIN -- synthesize from intorduction
+
+## Task
+
+We will use a synthetic task to test our hypothesis that models will generalize truthfully off-distribution. The synthetic task is computing the distance between various vertices in an input graph. Our experiment will have three parts:
+
+1. Pre-train a transformer to predict the distance between two fixed vertices $s,t$ on graphs with $n\in [8, 32)$ vertices.
+2. Fine-tune a transformer to predict the distances between $s,t'$ for any $t'$ which is on the shortest path from $s$ to $t$, but only do fine-tuning on graphs with $n\in [8,16)$ vertices.
+3. Test whether the transformer can accurately predict the distances between $s,t'$ for any $t'$ on the shortest path from $s$ to $t$ for graphs with $n\in [16,32)$ vertices.
## Related Work
+COMMENT FROM ALEK
+-- please remove all mentions of graph neural networks -- that is BS: there is no actual reason why you'd ever use a Neural network to solve shortest paths, the point of choosing a synthetic task is because there is a **simple ground truth** which makes it easy to evaluate whether or not our model is performing correctly. We'd also hoped that the simplicity of the task would make it more feasible to do with a limited compute budget, but apparently this task was too hard for our architecture.
+
+
There has been some research into the algorithmic optimization of GNNs and how they may solve real-world issues; however, none of the related work targets using generic machine learning methods to solve graph problems.
- Cappart et al. has researched more into the Combinatorial Optimization of GNNs and developed algorithms for related tasks, thus facilitating machine learning [@DBLP:journals/corr/abs-2102-09544]. Their results are mostly algorithmic so we develop further by trading a bit of accuracy for much faster computation in such tasks.
- Tutsoy uses a graph-theory-based approach to model the epidemiological characteristics of infectious diseases, such as COVID-19 [@10.1109/TPAMI.2023.3256421]. We understand from his paper how GNN optimization may also be useful in researching novel diseases.
-## Sectioning
+### Theory
-# Our Machine Learning Approach
+### Algorithm for Shortest Paths
-## Task
+The standard algorithm to find the shortest path in a graph between a source numbered as $u$ and sink numbered as $v$ is **breadth-first search (BFS)**. The BFS algorithm maintains a mapping of visited vertices to their distances with respect to $u$, and each run of the algorithm goes through all the vertices newly visited in the previous run, and for each vertex, visits any of its unvisited neighbors. The algorithm terminates once either $v$ is visited or the set of newly visited vertices in a single run is empty.
-We will use a synthetic task to test our hypothesis that models will generalize truthfully off-distribution. The synthetic task is computing the distance between various vertices in an input graph. Our experiment will have three parts:
+We will use this algorithm to verify the accuracy of our machine learning approach. Given $V$ vertices and $E$ edges, the runtime of this algorithm is thus $O(V + E)$; however, a machine learning approach may do better in time through parallelism, although at the expense of using much more memory.
-1. Pre-train a transformer to predict the distance between two fixed vertices $s,t$ on graphs with $n\in [8, 32)$ vertices.
-2. Fine-tune a transformer to predict the distances between $s,t'$ for any $t'$ which is on the shortest path from $s$ to $t$, but only do fine-tuning on graphs with $n\in [8,16)$ vertices.
-3. Test whether the transformer can accurately predict the distances between $s,t'$ for any $t'$ on the shortest path from $s$ to $t$ for graphs with $n\in [16,32)$ vertices.
+### Potential Mathematical Approaches to Shortest Paths
+
+Another way one can think of the shortest path of a graph is using a *matrix* to record which vertices are connected. Given vertices numbered $1$ to $V$, we denote the **adjacency matrix** $\textbf{M}$ of dimensions $V \times V$ as the matrix with element $\textbf{M}_{i, j} = 1$ if vertices $i$ and $j$ are connected by an edge and $\textbf{M}_{i, j} = 0$ if they are not. Now, we note that (1) For all $k$, $(\textbf{M}+I)^k_{i, j} = 0$ if and only if there exists no path from the vertex numbered $i$ to the vertex numbered $j$ that is distance $k$ or less due to Markov matrix processes. As a result, if the distance between vertices numbered $i$ and $j$ is $d$, then $\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = 1$ if $k \ge d$ and $\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right) = 0$ if $k < d$.
+
+With this information, because the distance between any two vertices is at most $V-1$ in a graph with $V$ vertices, we note that the *distance* matrix turns out to be simply $$\textbf{D} = \textbf{1}_{V \times V} \cdot V - \Sigma_{i=0}^{V-1}\text{min}\left((\textbf{M}+I)^k_{i, j}, 1\right).$$ The runtime to compute this is $O(V)$, although it will take more space to compute all powers of $\textbf{M}$.
-## Data
+## Our Machine Learning Approach
+
+### Data
We will represent an $n$ vertex, $m$ edge unweighted, undirected graph as sequence of the endpoints of the $m$ edges, so $[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]$ represents a graph with the edges $\{(a_i,b_i)\}$ for $1 \leq i \leq m$. We will pad all sequences to be the same length using the padding token 0.
@@ -61,13 +80,39 @@ We have three separate datasets.
As a side note, we are also curious whether the transformer learns to generalize to different distributions of graphs, such as denser graphs or graphs with different properties. Time permitting, we will also investigate this.
-## Architecture
+### Architecture
We plan to use a standard transformer architecture. We will ensure that the number of layers in our transformer is at least the diameter of the graph. By doing this, we ensure that there is an extremely simple circuit --- namely BFS --- that the transformer could in theory learn to perform the task. Note that if the transformer actually learns a simple circuit to perform this task, then it seems more likely to generalize well. This is also our intuition for why it should be possible to fine tune on a small amount of data for finding shortest paths to other vertices besides $2$ -- it seems like the model should be computing these other distances as intermediate values in its computation to find the distance to vertex $2$.
-## Positional Encodings
+### Embeddings
+
+TODO: fix this
In order to facilitate performing this task with limited computational resources, we plan to use custom-made positional encodings that tell the model extra information about the
structure of the problem, rather than the traditional sine/cosine positional encodings. (TODO: THIS IS OUTDATED) Specifically, our positional encodings are $v_1,v_1,v_2,v_2,\ldots,v_m,v_m,v_{m+1}$ where each $v_i$ is a random vector so each $v_i,v_j$ pair is nearly orthogonal with high probability. We will concatenate these with the token encodings rather than adding them. This should let the model easily have large attention scores between vertices corresponding to a single edge.
+### Explicit transformer formula for shortest paths
+
+
+
+## Results
+
+### Initial Results
+
+We used a model dimension of 64, four layers, and two heads per layer. We used MSE loss, the Adam optimizer, a learning rate of 8e-4, and a batch size of 131,072 for 8000 unique randomly generated batches. Our final MSE loss was 0.35546875.
+
+![](training-loss.png)
+
+![](training-2d-histogram.png)
+
+### Fine Tuning
+
+After receiving our initial results, we fine-tuned with a learning rate of 1e-5, also with MSE and the same batch size. Our final results are shown below.
+
+![](fine-tuning-loss.png)
+
+![](fine-tuning-2d-histogram.png)
+
+![](test-2d-histogram.png)
+
## References
diff --git a/fine-tuning-2d-histogram.png b/fine-tuning-2d-histogram.png
new file mode 100644
index 0000000..df45973
--- /dev/null
+++ b/fine-tuning-2d-histogram.png
Binary files differ
diff --git a/fine-tuning-loss.png b/fine-tuning-loss.png
new file mode 100644
index 0000000..7e95a06
--- /dev/null
+++ b/fine-tuning-loss.png
Binary files differ
diff --git a/index.html b/index.html
index 0126a3b..95d23ff 100644
--- a/index.html
+++ b/index.html
@@ -202,8 +202,9 @@ Generalization of Transformers</h1>
<p>{xy,alekw,kevinmz}@mit.edu</p>
</div>
<h2 id="abstract">Abstract</h2>
-<p>TODO: Probably should move this to the introduction instead of
-abstract?</p>
+<p>TODO</p>
+<h2 id="introduction">Introduction</h2>
+<h3 id="overview">Overview</h3>
<p>Recently, LLMs have been developing very fast, and with that comes
the concern of aligning the models to output true and productive
statements. One common approach for ensuring this is to have a human in
@@ -214,10 +215,7 @@ experts that are good judges of whether the model’s outputs, such as
difficult mathematical proofs, are truthful. So, we’d like to propose a
potential solution to this issue via <strong>off-distribution
generalization</strong> - applying human-like intuition to solve
-problems not in the dataset. To do so involves Graph Neural Networks
-(GNNs), which are important in many artificial intelligence
-applications, including the design of practical features. Paul
-Christiano <a
+problems not in the dataset. Paul Christiano <a
href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">proposed
an experiment</a> about shortest paths in a graph; our project is
essentially to implement Christiano’s proposed experiment. To the best
@@ -226,32 +224,24 @@ learning for different variations of graph searches <span
class="citation" data-cites="10.5555/3666122.3666260">(<a
href="#ref-10.5555/3666122.3666260" role="doc-biblioref">Zang et al.
2024</a>)</span>, no one has done our exact experiment yet.</p>
-<h1 id="introduction">Introduction</h1>
-<h2 id="overview">Overview</h2>
-<h2 id="related-work">Related Work</h2>
-<p>There has been some research into the algorithmic optimization of
-GNNs and how they may solve real-world issues; however, none of the
-related work targets using generic machine learning methods to solve
-graph problems.</p>
-<ul>
-<li><p>Cappart et al. has researched more into the Combinatorial
-Optimization of GNNs and developed algorithms for related tasks, thus
-facilitating machine learning <span class="citation"
-data-cites="DBLP:journals/corr/abs-2102-09544">(<a
-href="#ref-DBLP:journals/corr/abs-2102-09544"
-role="doc-biblioref">Cappart et al. 2021</a>)</span>. Their results are
-mostly algorithmic so we develop further by trading a bit of accuracy
-for much faster computation in such tasks.</p></li>
-<li><p>Tutsoy uses a graph-theory-based approach to model the
-epidemiological characteristics of infectious diseases, such as COVID-19
-<span class="citation" data-cites="10.1109/TPAMI.2023.3256421">(<a
-href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy
-2023</a>)</span>. We understand from his paper how GNN optimization may
-also be useful in researching novel diseases.</p></li>
-</ul>
-<h2 id="sectioning">Sectioning</h2>
-<h1 id="our-machine-learning-approach">Our Machine Learning
-Approach</h1>
+<p>It is generally desirable for LLMs to output true statements. A
+current approach for ensuring this is to have a human in the loop
+rewarding the model for true outputs (e.g. RLHF); however, humans can be
+poor judges of truthfulness. We enjoy many cognitive biases and might
+employ superficial heuristics when judging truthfulness. A further
+challenge is that as LLMs develop further, there might not even exist
+experts that can correctly judge the accuracy and truthfulness of
+sophisticated outputs such as difficult mathematical proofs.</p>
+<p>One approach to solving this problem is to reward an LLM for truthful
+behavior on simple inputs, and then hoping that the LLM generalizes its
+truthful behavior for more complex inputs where humans cannot provide
+helpful labels. Deep learning models often perform remarkable feats of
+off-distribution generalization – for instance, a model trained to
+transform hand drawn cats into images of cats might be able to handle a
+“cat” with three eyes in an intuitive way. We might hope that
+generalizing truthfully is simple, thus promoted by “Occam’s Razor”, and
+aim to investigate that with this project.</p>
+<p>COMMENT FROM KEVIN – synthesize from intorduction</p>
<h2 id="task">Task</h2>
<p>We will use a synthetic task to test our hypothesis that models will
generalize truthfully off-distribution. The synthetic task is computing
@@ -276,7 +266,95 @@ between <span class="math inline"><em>s</em>, <em>t</em>′</span> for any
class="math inline"><em>t</em></span> for graphs with <span
class="math inline"><em>n</em> ∈ [16, 32)</span> vertices.</li>
</ol>
-<h2 id="data">Data</h2>
+<h2 id="related-work">Related Work</h2>
+<p>COMMENT FROM ALEK – please remove all mentions of graph neural
+networks – that is BS: there is no actual reason why you’d ever use a
+Neural network to solve shortest paths, the point of choosing a
+synthetic task is because there is a <strong>simple ground
+truth</strong> which makes it easy to evaluate whether or not our model
+is performing correctly. We’d also hoped that the simplicity of the task
+would make it more feasible to do with a limited compute budget, but
+apparently this task was too hard for our architecture.</p>
+<p>There has been some research into the algorithmic optimization of
+GNNs and how they may solve real-world issues; however, none of the
+related work targets using generic machine learning methods to solve
+graph problems.</p>
+<ul>
+<li><p>Cappart et al. has researched more into the Combinatorial
+Optimization of GNNs and developed algorithms for related tasks, thus
+facilitating machine learning <span class="citation"
+data-cites="DBLP:journals/corr/abs-2102-09544">(<a
+href="#ref-DBLP:journals/corr/abs-2102-09544"
+role="doc-biblioref">Cappart et al. 2021</a>)</span>. Their results are
+mostly algorithmic so we develop further by trading a bit of accuracy
+for much faster computation in such tasks.</p></li>
+<li><p>Tutsoy uses a graph-theory-based approach to model the
+epidemiological characteristics of infectious diseases, such as COVID-19
+<span class="citation" data-cites="10.1109/TPAMI.2023.3256421">(<a
+href="#ref-10.1109/TPAMI.2023.3256421" role="doc-biblioref">Tutsoy
+2023</a>)</span>. We understand from his paper how GNN optimization may
+also be useful in researching novel diseases.</p></li>
+</ul>
+<h3 id="theory">Theory</h3>
+<h3 id="algorithm-for-shortest-paths">Algorithm for Shortest Paths</h3>
+<p>The standard algorithm to find the shortest path in a graph between a
+source numbered as <span class="math inline"><em>u</em></span> and sink
+numbered as <span class="math inline"><em>v</em></span> is
+<strong>breadth-first search (BFS)</strong>. The BFS algorithm maintains
+a mapping of visited vertices to their distances with respect to <span
+class="math inline"><em>u</em></span>, and each run of the algorithm
+goes through all the vertices newly visited in the previous run, and for
+each vertex, visits any of its unvisited neighbors. The algorithm
+terminates once either <span class="math inline"><em>v</em></span> is
+visited or the set of newly visited vertices in a single run is
+empty.</p>
+<p>We will use this algorithm to verify the accuracy of our machine
+learning approach. Given <span class="math inline"><em>V</em></span>
+vertices and <span class="math inline"><em>E</em></span> edges, the
+runtime of this algorithm is thus <span
+class="math inline"><em>O</em>(<em>V</em> + <em>E</em>)</span>; however,
+a machine learning approach may do better in time through parallelism,
+although at the expense of using much more memory.</p>
+<h3 id="potential-mathematical-approaches-to-shortest-paths">Potential
+Mathematical Approaches to Shortest Paths</h3>
+<p>Another way one can think of the shortest path of a graph is using a
+<em>matrix</em> to record which vertices are connected. Given vertices
+numbered <span class="math inline">1</span> to <span
+class="math inline"><em>V</em></span>, we denote the <strong>adjacency
+matrix</strong> <span class="math inline"><strong>M</strong></span> of
+dimensions <span class="math inline"><em>V</em> × <em>V</em></span> as
+the matrix with element <span
+class="math inline"><strong>M</strong><sub><em>i</em>, <em>j</em></sub> = 1</span>
+if vertices <span class="math inline"><em>i</em></span> and <span
+class="math inline"><em>j</em></span> are connected by an edge and <span
+class="math inline"><strong>M</strong><sub><em>i</em>, <em>j</em></sub> = 0</span>
+if they are not. Now, we note that (1) For all <span
+class="math inline"><em>k</em></span>, <span
+class="math inline">(<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup> = 0</span>
+if and only if there exists no path from the vertex numbered <span
+class="math inline"><em>i</em></span> to the vertex numbered <span
+class="math inline"><em>j</em></span> that is distance <span
+class="math inline"><em>k</em></span> or less due to Markov matrix
+processes. As a result, if the distance between vertices numbered <span
+class="math inline"><em>i</em></span> and <span
+class="math inline"><em>j</em></span> is <span
+class="math inline"><em>d</em></span>, then <span
+class="math inline">min((<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup>, 1) = 1</span>
+if <span class="math inline"><em>k</em> ≥ <em>d</em></span> and <span
+class="math inline">min((<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup>, 1) = 0</span>
+if <span class="math inline"><em>k</em> &lt; <em>d</em></span>.</p>
+<p>With this information, because the distance between any two vertices
+is at most <span class="math inline"><em>V</em> − 1</span> in a graph
+with <span class="math inline"><em>V</em></span> vertices, we note that
+the <em>distance</em> matrix turns out to be simply <span
+class="math display"><strong>D</strong> = <strong>1</strong><sub><em>V</em> × <em>V</em></sub> ⋅ <em>V</em> − <em>Σ</em><sub><em>i</em> = 0</sub><sup><em>V</em> − 1</sup>min((<strong>M</strong> + <em>I</em>)<sub><em>i</em>, <em>j</em></sub><sup><em>k</em></sup>, 1).</span>
+The runtime to compute this is <span
+class="math inline"><em>O</em>(<em>V</em>)</span>, although it will take
+more space to compute all powers of <span
+class="math inline"><strong>M</strong></span>.</p>
+<h2 id="our-machine-learning-approach">Our Machine Learning
+Approach</h2>
+<h3 id="data">Data</h3>
<p>We will represent an <span class="math inline"><em>n</em></span>
vertex, <span class="math inline"><em>m</em></span> edge unweighted,
undirected graph as sequence of the endpoints of the <span
@@ -321,7 +399,7 @@ class="math inline"><em>n</em> ∈ [16, 32)</span> instead.</li>
generalize to different distributions of graphs, such as denser graphs
or graphs with different properties. Time permitting, we will also
investigate this.</p>
-<h2 id="architecture">Architecture</h2>
+<h3 id="architecture">Architecture</h3>
<p>We plan to use a standard transformer architecture. We will ensure
that the number of layers in our transformer is at least the diameter of
the graph. By doing this, we ensure that there is an extremely simple
@@ -334,7 +412,8 @@ vertices besides <span class="math inline">2</span> – it seems like the
model should be computing these other distances as intermediate values
in its computation to find the distance to vertex <span
class="math inline">2</span>.</p>
-<h2 id="positional-encodings">Positional Encodings</h2>
+<h3 id="embeddings">Embeddings</h3>
+<p>TODO: fix this</p>
<p>In order to facilitate performing this task with limited
computational resources, we plan to use custom-made positional encodings
that tell the model extra information about the structure of the
@@ -350,6 +429,23 @@ pair is nearly orthogonal with high probability. We will concatenate
these with the token encodings rather than adding them. This should let
the model easily have large attention scores between vertices
corresponding to a single edge.</p>
+<h3 id="explicit-transformer-formula-for-shortest-paths">Explicit
+transformer formula for shortest paths</h3>
+<h2 id="results">Results</h2>
+<h3 id="initial-results">Initial Results</h3>
+<p>We used a model dimension of 64 64, four layers, and two heads per
+layer. We used MSE loss, the Adam optimizer, a learning rate of 8e-4,
+and a batch size of 131,072 for 8000 unique randomly generated batches.
+Our final MSE loss was 0.35546875.</p>
+<p><img src="training-loss.png" /></p>
+<p><img src="training-2d-histogram.png" /></p>
+<h3 id="fine-tuning">Fine Tuning</h3>
+<p>After receiving our initial results, we fine-tuned with a learning
+rate of 1e-5, also with MSE and the same batch size. Our final results
+are shown below.</p>
+<p><img src="fine-tuning-loss.png" /></p>
+<p><img src="fine-tuning-2d-histogram.png" /></p>
+<p><img src="test-2d-histogram.png" /></p>
<h2 class="unnumbered" id="references">References</h2>
<div id="refs" class="references csl-bib-body hanging-indent"
data-entry-spacing="0" role="list">
diff --git a/insane-shortest-paths.ipynb b/insane-shortest-paths.ipynb
new file mode 100644
index 0000000..72846c2
--- /dev/null
+++ b/insane-shortest-paths.ipynb
@@ -0,0 +1,263 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "execution_state": "idle",
+ "id": "86ce5f44-94f6-43b0-a0d1-091b8134ffb6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[set(), set(), {5, 6}, {4}, {3}, {2, 6}, {2, 5}]\n",
+ "[set(), {6}, set(), {4, 5, 6}, {3}, {3, 6}, {1, 3, 5}]\n",
+ "[set(), {4}, set(), {4, 5}, {1, 3}, {3, 6}, {5}]\n",
+ "[set(), {2, 6}, {1, 6}, {6}, set(), set(), {1, 2, 3}]\n",
+ "[set(), {3}, {3}, {1, 2, 5, 6}, {5}, {3, 4}, {3}]\n",
+ "[set(), {3, 6}, {4}, {1}, {2}, {6}, {1, 5}]\n",
+ "[set(), {2, 3}, {1, 3, 6}, {1, 2, 4}, {3}, set(), {2}]\n",
+ "[set(), {4}, set(), {4}, {1, 3, 5, 6}, {4}, {4}]\n",
+ "[set(), {3, 4, 5}, {6}, {1}, {1, 6}, {1}, {2, 4}]\n",
+ "[set(), {5, 6}, {6}, {6}, {5, 6}, {1, 4}, {1, 2, 3, 4}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# -*- coding: utf-8 -*-\n",
+ "\"\"\"how-tsp-should-be.ipynb\n",
+ "\n",
+ "Automatically generated by Colab.\n",
+ "\n",
+ "Original file is located at\n",
+ " https://colab.research.google.com/drive/1InE1iW8ARzndPpvqH_9y22s81sOiHxPs\n",
+ "\"\"\"\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import matplotlib as mpl\n",
+ "import matplotlib.pyplot as plt\n",
+ "from torch.utils.data import DataLoader, TensorDataset\n",
+ "\n",
+ "from math import sqrt\n",
+ "from collections import deque\n",
+ "import os\n",
+ "import random\n",
+ "import pickle\n",
+ "import ipdb\n",
+ "\n",
+ "# torch.manual_seed(30)\n",
+ "# random.seed(30)\n",
+ "torch.manual_seed(33)\n",
+ "random.seed(33)\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "# assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\"\n",
+ "\n",
+ "NVTXS = 6\n",
+ "MAXDIST = NVTXS+1\n",
+ "AVGDEG = 2\n",
+ "SEQLEN = NVTXS + 1\n",
+ "HIDDENDIM = 4*NVTXS+2\n",
+ "\n",
+ "# 0: ANSFLAG\n",
+ "# 1:NVTXS+1 NBRS\n",
+ "# NVTXS+1: 2*NVTXS+1 REACH\n",
+ "# 2*NVTXS+1: 3*NVTXS+1 SELF\n",
+ "# -1 NOTANSFLAG\n",
+ "\n",
+ "START_REACH = NVTXS+1\n",
+ "START_OUT = 2*NVTXS+1\n",
+ "START_SELF = 3*NVTXS+1\n",
+ "SRC_FLAG_IDX = START_SELF\n",
+ "SOURCE = 1\n",
+ "TARGET = 2\n",
+ "ANS_FLAG_IDX = 0\n",
+ "NOTANS_FLAG_IDX = -1\n",
+ "\n",
+ "def print_everything(data):\n",
+ " print(\"NBRS\")\n",
+ " print(data[0, 1:, 1:1+NVTXS])\n",
+ " print(\"REACH\")\n",
+ " print(data[0, 1:, START_REACH:START_REACH+NVTXS])\n",
+ " print(\"ANSFLAG\")\n",
+ " print(data[0, :, 0])\n",
+ " print(\"MORE FLAGS\")\n",
+ " print(data[0, :, -1])\n",
+ " print(\"SELF\")\n",
+ " print(data[0, 1:, START_SELF:START_SELF+NVTXS])\n",
+ " print(\"OUT\")\n",
+ " print(data[0, 0, START_OUT:START_OUT+NVTXS])\n",
+ "\n",
+ "\n",
+ "def random_graph():\n",
+ " data = torch.zeros((SEQLEN, HIDDENDIM))\n",
+ "\n",
+ " for i in range(1,NVTXS+1):\n",
+ " data[i, START_SELF-1+i] = 1\n",
+ "\n",
+ " adj_list = [set() for _ in range(SEQLEN)]\n",
+ " indices = [random.randint(1, NVTXS) for _ in range(AVGDEG * NVTXS)]\n",
+ " for i in range(0, len(indices), 2):\n",
+ " u = indices[i]\n",
+ " v = indices[i + 1]\n",
+ " if u != v:\n",
+ " data[v,u] = 1\n",
+ " data[u,v] = 1\n",
+ " data[v,NVTXS+u] = 1\n",
+ " data[u,NVTXS+v] = 1\n",
+ " adj_list[u].add(v)\n",
+ " adj_list[v].add(u)\n",
+ "\n",
+ " data[0, ANS_FLAG_IDX] = 1\n",
+ " data[1:, NOTANS_FLAG_IDX] = 1\n",
+ "\n",
+ " # TODO: this is kind of a hack\n",
+ " data[0, START_REACH:START_REACH+NVTXS] = 1\n",
+ " return data, adj_list\n",
+ "\n",
+ "\"\"\"\n",
+ "input: G, represented as an adjacency list\n",
+ "output: distance from SOURCE to TARGET\n",
+ "\"\"\"\n",
+ "def SSSP(G):\n",
+ " dist = [MAXDIST for _ in G]\n",
+ " dist[SOURCE] = 0\n",
+ " frontier = deque()\n",
+ " frontier.append(SOURCE)\n",
+ " while len(frontier) > 0:\n",
+ " vtx = frontier.popleft()\n",
+ " for x in G[vtx]:\n",
+ " if dist[x] == MAXDIST:\n",
+ " dist[x] = 1 + dist[vtx]\n",
+ " frontier.append(x)\n",
+ " if x == TARGET:\n",
+ " return dist[TARGET]\n",
+ " return MAXDIST\n",
+ "\n",
+ "def mkbatch(size):\n",
+ " graphs1 = []\n",
+ " distance1 = []\n",
+ "\n",
+ " for i in range(size):\n",
+ " data, adj_list = random_graph()\n",
+ " dist = SSSP(adj_list)\n",
+ " graphs1.append(data)\n",
+ " distance1.append(dist)\n",
+ "\n",
+ " print(adj_list)\n",
+ "\n",
+ " data = torch.stack(graphs1)\n",
+ " labels = torch.tensor(distance1, dtype=torch.float16)\n",
+ " return data, labels\n",
+ "\n",
+ "\"\"\"\n",
+ "TODO: WRAP EVERYTHING in nn.Parameter(torch.zeros((1, HIDDENDIM)))\n",
+ "and then do my perturbing parameters experiment\n",
+ "\n",
+ "TODO:\n",
+ " USE activation magic to bring everything back to the 0/1 realm instead of possibly being 0/2 valued\n",
+ "\"\"\"\n",
+ "\n",
+ "class SillyTransformer(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.most_KQVs = []\n",
+ " for head in range(1,NVTXS+1):\n",
+ " Q = torch.zeros((2, HIDDENDIM))\n",
+ " Q[0, START_REACH-1+head] = 1000\n",
+ " Q[1, NOTANS_FLAG_IDX] = 1\n",
+ "\n",
+ " K = torch.zeros((2, HIDDENDIM))\n",
+ " K[0, head] = 1\n",
+ " K[1, ANS_FLAG_IDX] = 200\n",
+ "\n",
+ " V = torch.zeros((NVTXS,HIDDENDIM))\n",
+ " for i in range(NVTXS):\n",
+ " V[i, START_SELF+i] = 1\n",
+ "\n",
+ " self.most_KQVs.append((K, Q, V))\n",
+ "\n",
+ " self.weird_KQVs = []\n",
+ " for layer in range(NVTXS):\n",
+ " K = torch.zeros((3, HIDDENDIM))\n",
+ " K[0, NOTANS_FLAG_IDX] = -1000\n",
+ " K[0, SRC_FLAG_IDX] = +1100\n",
+ " K[1, NOTANS_FLAG_IDX] = -1000\n",
+ " K[1, NVTXS+TARGET] = +1100\n",
+ " K[1, ANS_FLAG_IDX] = -1100\n",
+ " K[2, ANS_FLAG_IDX] = 10\n",
+ "\n",
+ " Q = torch.zeros((3, HIDDENDIM))\n",
+ " Q[:, ANS_FLAG_IDX] = 1\n",
+ "\n",
+ " V = torch.zeros((NVTXS, HIDDENDIM))\n",
+ " V[layer, SRC_FLAG_IDX] = 1\n",
+ "\n",
+ " self.weird_KQVs.append((K, Q, V))\n",
+ "\n",
+ " def forward(self, src):\n",
+ " for layer in range(NVTXS):\n",
+ " allKQVs = [self.weird_KQVs[layer]] + self.most_KQVs\n",
+ " head_outputs = []\n",
+ " for (K, Q, V) in allKQVs:\n",
+ " ksrc = torch.matmul(src, K.unsqueeze(0).transpose(-2, -1))\n",
+ " qsrc = torch.matmul(src, Q.unsqueeze(0).transpose(-2, -1))\n",
+ " vsrc = torch.matmul(src, V.unsqueeze(0).transpose(-2, -1))\n",
+ "\n",
+ " scores = torch.matmul(qsrc, ksrc.transpose(-2, -1))\n",
+ " attention_weights = torch.softmax(scores, dim=-1)\n",
+ " head_output = torch.matmul(attention_weights, vsrc)\n",
+ " head_outputs.append(head_output)\n",
+ "\n",
+ " new_reaches = sum(head_outputs[1:])\n",
+ " BSZ = new_reaches.shape[0]\n",
+ "\n",
+ " nodelta_nbrs = torch.zeros((BSZ, SEQLEN, NVTXS+1))\n",
+ " morepadlol = torch.zeros((BSZ, SEQLEN, 1+NVTXS))\n",
+ "\n",
+ " DIFF = torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)\n",
+ " src += torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)\n",
+ " src[:, :, START_REACH:START_REACH+NVTXS] = 2*torch.sigmoid(src[:,:, START_REACH:START_REACH+NVTXS]*1000)-1\n",
+ "\n",
+ " # print(\"SRC\")\n",
+ " # print_everything(src)\n",
+ "\n",
+ " canreach = src[:,0,START_OUT:START_OUT+NVTXS]\n",
+ " # __import__('ipdb').set_trace()\n",
+ " final_output = 1+torch.sum(1-canreach,dim=1)\n",
+ " return final_output\n",
+ "\n",
+ "model = SillyTransformer()\n",
+ "model.to(device)\n",
+ "\n",
+ "data, labels = mkbatch(10)\n",
+ "assert torch.all(model(data) == labels)\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/test-2d-histogram.png b/test-2d-histogram.png
new file mode 100644
index 0000000..75f2f6e
--- /dev/null
+++ b/test-2d-histogram.png
Binary files differ
diff --git a/training-2d-histogram.png b/training-2d-histogram.png
new file mode 100644
index 0000000..45d53a2
--- /dev/null
+++ b/training-2d-histogram.png
Binary files differ
diff --git a/training-loss.png b/training-loss.png
new file mode 100644
index 0000000..98a2a20
--- /dev/null
+++ b/training-loss.png
Binary files differ