aboutsummaryrefslogtreecommitdiff
path: root/index.html
blob: 4d7b64d039d5cb9aed81e5d13bb2ae7b402bd98e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml" lang="en" xml:lang="en">
<head>
  <meta charset="utf-8" />
  <meta name="generator" content="pandoc" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes" />
  <title>Off-Distribution Generalization of Transformers for Shortest Paths</title>
  <style>
    html {
      color: #1a1a1a;
      background-color: #fdfdfd;
    }
    body {
      margin: 0 auto;
      max-width: 36em;
      padding-left: 50px;
      padding-right: 50px;
      padding-top: 50px;
      padding-bottom: 50px;
      hyphens: auto;
      overflow-wrap: break-word;
      text-rendering: optimizeLegibility;
      font-kerning: normal;
    }
    @media (max-width: 600px) {
      body {
        font-size: 0.9em;
        padding: 12px;
      }
      h1 {
        font-size: 1.8em;
      }
    }
    @media print {
      html {
        background-color: white;
      }
      body {
        background-color: transparent;
        color: black;
        font-size: 12pt;
      }
      p, h2, h3 {
        orphans: 3;
        widows: 3;
      }
      h2, h3, h4 {
        page-break-after: avoid;
      }
    }
    p {
      margin: 1em 0;
    }
    a {
      color: #1a1a1a;
    }
    a:visited {
      color: #1a1a1a;
    }
    img {
      max-width: 100%;
    }
    svg {
      height: auto;
      max-width: 100%;
    }
    h1, h2, h3, h4, h5, h6 {
      margin-top: 1.4em;
    }
    h5, h6 {
      font-size: 1em;
      font-style: italic;
    }
    h6 {
      font-weight: normal;
    }
    ol, ul {
      padding-left: 1.7em;
      margin-top: 1em;
    }
    li > ol, li > ul {
      margin-top: 0;
    }
    blockquote {
      margin: 1em 0 1em 1.7em;
      padding-left: 1em;
      border-left: 2px solid #e6e6e6;
      color: #606060;
    }
    code {
      font-family: Menlo, Monaco, Consolas, 'Lucida Console', monospace;
      font-size: 85%;
      margin: 0;
      hyphens: manual;
    }
    pre {
      margin: 1em 0;
      overflow: auto;
    }
    pre code {
      padding: 0;
      overflow: visible;
      overflow-wrap: normal;
    }
    .sourceCode {
     background-color: transparent;
     overflow: visible;
    }
    hr {
      background-color: #1a1a1a;
      border: none;
      height: 1px;
      margin: 1em 0;
    }
    table {
      margin: 1em 0;
      border-collapse: collapse;
      width: 100%;
      overflow-x: auto;
      display: block;
      font-variant-numeric: lining-nums tabular-nums;
    }
    table caption {
      margin-bottom: 0.75em;
    }
    tbody {
      margin-top: 0.5em;
      border-top: 1px solid #1a1a1a;
      border-bottom: 1px solid #1a1a1a;
    }
    th {
      border-top: 1px solid #1a1a1a;
      padding: 0.25em 0.5em 0.25em 0.5em;
    }
    td {
      padding: 0.125em 0.5em 0.25em 0.5em;
    }
    header {
      margin-bottom: 4em;
      text-align: center;
    }
    #TOC li {
      list-style: none;
    }
    #TOC ul {
      padding-left: 1.3em;
    }
    #TOC > ul {
      padding-left: 0;
    }
    #TOC a:not(:hover) {
      text-decoration: none;
    }
    code{white-space: pre-wrap;}
    span.smallcaps{font-variant: small-caps;}
    div.columns{display: flex; gap: min(4vw, 1.5em);}
    div.column{flex: auto; overflow-x: auto;}
    div.hanging-indent{margin-left: 1.5em; text-indent: -1.5em;}
    /* The extra [class] is a hack that increases specificity enough to
       override a similar rule in reveal.js */
    ul.task-list[class]{list-style: none;}
    ul.task-list li input[type="checkbox"] {
      font-size: inherit;
      width: 0.8em;
      margin: 0 0.8em 0.2em -1.6em;
      vertical-align: middle;
    }
    /* CSS for syntax highlighting */
    pre > code.sourceCode { white-space: pre; position: relative; }
    pre > code.sourceCode > span { line-height: 1.25; }
    pre > code.sourceCode > span:empty { height: 1.2em; }
    .sourceCode { overflow: visible; }
    code.sourceCode > span { color: inherit; text-decoration: inherit; }
    div.sourceCode { margin: 1em 0; }
    pre.sourceCode { margin: 0; }
    @media screen {
    div.sourceCode { overflow: auto; }
    }
    @media print {
    pre > code.sourceCode { white-space: pre-wrap; }
    pre > code.sourceCode > span { display: inline-block; text-indent: -5em; padding-left: 5em; }
    }
    pre.numberSource code
      { counter-reset: source-line 0; }
    pre.numberSource code > span
      { position: relative; left: -4em; counter-increment: source-line; }
    pre.numberSource code > span > a:first-child::before
      { content: counter(source-line);
        position: relative; left: -1em; text-align: right; vertical-align: baseline;
        border: none; display: inline-block;
        -webkit-touch-callout: none; -webkit-user-select: none;
        -khtml-user-select: none; -moz-user-select: none;
        -ms-user-select: none; user-select: none;
        padding: 0 4px; width: 4em;
        color: #aaaaaa;
      }
    pre.numberSource { margin-left: 3em; border-left: 1px solid #aaaaaa;  padding-left: 4px; }
    div.sourceCode
      {   }
    @media screen {
    pre > code.sourceCode > span > a:first-child::before { text-decoration: underline; }
    }
    code span.al { color: #ff0000; font-weight: bold; } /* Alert */
    code span.an { color: #60a0b0; font-weight: bold; font-style: italic; } /* Annotation */
    code span.at { color: #7d9029; } /* Attribute */
    code span.bn { color: #40a070; } /* BaseN */
    code span.bu { color: #008000; } /* BuiltIn */
    code span.cf { color: #007020; font-weight: bold; } /* ControlFlow */
    code span.ch { color: #4070a0; } /* Char */
    code span.cn { color: #880000; } /* Constant */
    code span.co { color: #60a0b0; font-style: italic; } /* Comment */
    code span.cv { color: #60a0b0; font-weight: bold; font-style: italic; } /* CommentVar */
    code span.do { color: #ba2121; font-style: italic; } /* Documentation */
    code span.dt { color: #902000; } /* DataType */
    code span.dv { color: #40a070; } /* DecVal */
    code span.er { color: #ff0000; font-weight: bold; } /* Error */
    code span.ex { } /* Extension */
    code span.fl { color: #40a070; } /* Float */
    code span.fu { color: #06287e; } /* Function */
    code span.im { color: #008000; font-weight: bold; } /* Import */
    code span.in { color: #60a0b0; font-weight: bold; font-style: italic; } /* Information */
    code span.kw { color: #007020; font-weight: bold; } /* Keyword */
    code span.op { color: #666666; } /* Operator */
    code span.ot { color: #007020; } /* Other */
    code span.pp { color: #bc7a00; } /* Preprocessor */
    code span.sc { color: #4070a0; } /* SpecialChar */
    code span.ss { color: #bb6688; } /* SpecialString */
    code span.st { color: #4070a0; } /* String */
    code span.va { color: #19177c; } /* Variable */
    code span.vs { color: #4070a0; } /* VerbatimString */
    code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warning */
  </style>
  <script defer=""
  src="https://cdn.jsdelivr.net/npm/katex@0.15.1/dist/katex.min.js"></script>
  <script>
  // Typeset every pandoc math element with KaTeX once the DOM has been parsed.
  // katex.min.js is requested with `defer`, so it has either executed or
  // failed to load by the time DOMContentLoaded fires.
  document.addEventListener("DOMContentLoaded", function () {
    // If the CDN request failed there is no global `katex`; leave the raw TeX
    // visible instead of throwing a ReferenceError.
    if (typeof katex === "undefined") {
      console.warn("KaTeX is not available; math is left as raw TeX");
      return;
    }
    // One macro table shared by all render calls. KaTeX documents `macros`
    // as a plain object, so use {} rather than an array.
    var macros = {};
    // Copy the live HTMLCollection so the DOM changes made by katex.render
    // cannot shift elements underneath the loop.
    var mathElements = Array.prototype.slice.call(
      document.getElementsByClassName("math"));
    for (var i = 0; i < mathElements.length; i++) {
      var element = mathElements[i];
      var texText = element.firstChild;
      // Only <span> elements are rendered; an empty span has no child node
      // whose `.data` could be read, so skip it.
      if (element.tagName !== "SPAN" || texText === null) {
        continue;
      }
      katex.render(texText.data, element, {
        // Elements carrying the "display" class are typeset in display mode.
        displayMode: element.classList.contains('display'),
        throwOnError: false,
        macros: macros,
        fleqn: false
      });
    }
  });
  </script>
  <link rel="stylesheet"
  href="https://cdn.jsdelivr.net/npm/katex@0.15.1/dist/katex.min.css" />
</head>
<body>
<header id="title-block-header">
<h1 class="title">Off-Distribution Generalization of Transformers for
Shortest Paths</h1>
</header>
<!-- Guidelines: https://www.dropbox.com/scl/fi/bet8enscln8ue36kd8t17/final_project_guidelines.pdf?rlkey=knd19cnumk51ho1y9crno56ib&e=2&dl=0 -->
<style>
/* Page-specific overrides; these rules come after the stylesheet in <head>. */
body {
    /* Shrink margins: this max-width replaces the 36em set for body in
       <head>, giving the text column more room. */
    max-width: 45em;
    /* font-size: 110%; */
}

/* Make sure toolbar plots aren't cut off */
/* NOTE(review): .mpld3-figure is presumably the class mpld3 puts on its
   generated figure element; confirm against the mpld3 output. */
.mpld3-figure {
    /* Fixed height; as a block with auto margins the figure is centered
       horizontally whenever it is narrower than the text column. */
    height: 480px;
    display: block;
    margin: auto;
}
</style>
<div style="text-align:center">
<p>Anthony Wang, Alek Westover, Kevin Zhao</p>
<p>{xy,alekw,kevinmz}@mit.edu</p>
</div>
<h2 id="introduction">Introduction</h2>
<h3 id="motivation">Motivation</h3>
<p>Neural networks are capable of impressive feats of off-distribution
generalization. For instance, a program trained to convert sketches of
cats into realistic looking pictures of cats can draw a cat with three
eyes, if the sketch has three eyes, even if such an image never occurred
in its training data. However, neural networks also often learn
non-robust features that cause them to perform very poorly
off-distribution (e.g., adversarial examples for an image classifier).
In this project, we will investigate the question of when transformers
generalize off-distribution via a case study on a simple
synthetic task. More specifically, the goal of our project is to make
progress towards answering the following question:</p>
<blockquote>
<p>Suppose you train a model to be truthful on task A, and to produce
coherent answers on a related task B.</p>
<p>Will the model learn to be truthful on task B?</p>
</blockquote>
<p>Here’s a more specific version of this question, inspired by a <a
href="https://www.alignmentforum.org/posts/BxersHYN2qcFoonwg/experimentally-evaluating-whether-honesty-generalizes?commentId=dsDA2BWpHPdgLvaXX">blog
post written by Paul Christiano</a> (who was influential in developing
RLHF).</p>
<blockquote>
<p>Suppose we trained a language model on next token prediction in both
English and French, and then fine-tuned the model to be truthful in
English. Would the model then learn to answer questions truthfully in
French? It has been shown <a href="#fn1" class="footnote-ref"
id="fnref1" role="doc-noteref"><sup>1</sup></a> that this doesn’t happen
by default, but you could hope that the truthfulness training in English
would generalize to French.</p>
</blockquote>
<p>In the future, humans are likely to trust AIs with performing
important and complex tasks. For some of these tasks it will be
infeasible for humans to audit the answers – either because human labor
is too expensive, or because the AIs are too competent, so humans aren’t
skilled enough to perform an audit (for example, an amateur chess player
might have a hard time telling apart a good chess move and a bad chess
move). However, there will be some simpler tasks where we can audit the
AIs and reward truthful answers. So, it’s plausible that models trained
to be truthful in some simple domains will learn truthfulness as a
general lesson.</p>
<h3 id="task">Task</h3>
<p>In order to investigate generalization of this type in transformers,
we decided to focus on a simple synthetic task – shortest paths in a
graph. More specifically, we ran the following experiment (see <a
href="#data">Data</a> for more details):</p>
<ol type="1">
<li><strong>Pre-train</strong> a transformer to predict the distance
between vertices <span class="math inline">1</span> and <span
class="math inline">2</span> in graphs with <span class="math inline">n
\in [3,15]</span> vertices.</li>
</ol>
<div style="text-align:center">
<p><img src="img/train.svg" alt="Diagram of the pre-training task: predicting the distance between vertices 1 and 2" /></p>
</div>
<ol start="2" type="1">
<li><strong>Fine-tune</strong> a transformer to predict the distance
from vertex <span class="math inline">1</span> to <span
class="math inline">t</span> for any <span class="math inline">t</span>
on the shortest path from <span class="math inline">1</span> to <span
class="math inline">2</span>, but only do fine-tuning on small graphs
with <span class="math inline">n \in [3,7]</span> vertices. Our
intuition was that, in training a model to find distances between
vertices <span class="math inline">1</span> and <span
class="math inline">2</span>, the model would also implicitly need to
learn about distances from vertex <span class="math inline">1</span> to
other vertices on this shortest path.</li>
</ol>
<div style="text-align:center">
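<p><img src="img/finetune.svg" alt="Diagram of the fine-tuning task: predicting the distance from vertex 1 to a vertex t on the shortest path from 1 to 2 in a small graph" /></p>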
</div>
<ol start="3" type="1">
<li><strong>Test</strong> whether the transformer can accurately
generalize to predicting the distance from <span
class="math inline">1</span> to <span class="math inline">t</span> for
any <span class="math inline">t \in [3,15]</span> on the shortest path
from <span class="math inline">1</span> to <span
class="math inline">2</span> for graphs with <span class="math inline">n
\in [3,15]</span> vertices.</li>
</ol>
<div style="text-align:center">
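<p><img src="img/test.svg" alt="Diagram of the test task: predicting the distance from vertex 1 to a vertex t on the shortest path from 1 to 2 in graphs with up to 15 vertices" /></p>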
</div>
<p>Obviously, if you actually want to compute shortest paths in a graph,
you should use the standard graph algorithm BFS (breadth-first search).
Our motivation for using this problem to study generalization was
two-fold:</p>
<ol type="1">
<li><p>It’s very easy to generate data and ground truth labels for this
task.</p></li>
<li><p>There is a simple circuit which computes shortest paths
(exactly), by doing something akin to BFS. In fact, we implemented this
simple circuit by hand in the section <a
href="#handcrafted-bfs-transformer">Handcrafted BFS Transformer</a>. Our
hope was that our model would also find a simple circuit that generalizes
well.</p></li>
</ol>
<h3 id="findings">Findings</h3>
<p>Our main findings were as follows:</p>
<ol type="1">
<li><p>There exists a simple set of weights that you can find by hand to
achieve arbitrarily small MSE loss on shortest paths. Moreover, this set
of weights is “robust” in the sense that, if you perturb some of these
weights and run gradient descent, the model will recover low
loss.</p></li>
<li><p>Transformers can learn on their own to perform shortest paths.
The training process consists of a sequence of <strong>phase
changes</strong> in which the loss barely changes for a large number
of epochs and then, over a very small number of epochs, drops
dramatically. (TODO: Is this just because we’re bad at DL? now
it shouldn’t have as many phase changes. Or maybe just in the
beginning?) In particular, these phase changes corresponded to times
when the model decided to, e.g., go from dramatically over-classifying
paths as being length 1 (resulting in great performance on tasks of
length 1 but poor performance on other tasks) to spreading these length
1 guesses out to other distances.</p></li>
<li><p>A model trained to compute the distance between two specific
vertices can be fine-tuned efficiently to compute the shortest distance
to other vertices.</p></li>
<li><p>A model trained to compute the distance between vertices <span
class="math inline">1</span> and <span class="math inline">2</span> on
large graphs, and fine-tuned to compute the distance between vertex 1
and any vertex (on a shortest <span class="math inline">1</span>-<span
class="math inline">2</span> path) in small graphs, has the following
generalization properties:</p>
<ul>
<li>In larger graphs, the model can still compute the shortest distance
between <span class="math inline">1</span> and any <span
class="math inline">t</span> whose label existed in the smaller graphs,
namely <span class="math inline">t \in [3,7]</span>.</li>
<li>However, the model fails when the target vertex doesn’t exist in
smaller graphs.</li>
</ul></li>
</ol>
<p>The main conceptual implication of this work is that, while
off-distribution generalization sometimes does occur, sometimes it also
doesn’t, and it seems tricky to predict when off-distribution
generalization will happen.</p>
<h3 id="post-outline">Post Outline</h3>
<p>The remainder of our post is structured as follows:</p>
<ol type="1">
<li><p><a href="#handcrafted-bfs-transformer">Handcrafted BFS
Transformer</a>: To start, we explicitly construct by hand weights for a
transformer to solve the shortest paths problem. We also investigate the
“robustness” of our construction by showing that, after perturbing the
weights of our constructed solution, gradient descent can recover low
loss. This gives a proof-of-concept that a
transformer could encode a simple generalizable solution to the problem,
and gives some reason to believe that a transformer could learn such a
solution.</p></li>
<li><p><a href="#training">Training</a>: Next, we discuss the training.
First we describe the exact data generation process and model
architecture. Then we discuss hyperparameter settings and optimization
choices that were important to get the model to train efficiently.
Finally, we investigate some interesting phenomena that occurred during
training, such as phase transitions in loss for different path
lengths.</p></li>
<li><p><a href="#generalization">Generalization</a>: Finally, we discuss
the phenomenon that we wanted to study – generalization! We show that
some generalization does occur, but other generalization doesn’t. We
explain why we think this happens.</p></li>
</ol>
<h2 id="handcrafted-bfs-transformer">Handcrafted BFS Transformer</h2>
<p>In this section we describe an explicit set of weights, made (with
love) by hand, that result in a transformer that achieves MSE less than
<span class="math inline">10^{-6}</span> for the shortest paths problem.
We’ll start by describing the (somewhat involved) construction, and
then investigate the geometry of the loss landscape near our explicit
solution. The main reason for doing this is to demonstrate the existence
of a simple transformer computing shortest paths, to give
proof-of-concept that the transformer could learn a generalizable
solution.</p>
<p>The construction is somewhat tricky to write down, so we’ll make
several (non-essential) simplifying assumptions:</p>
<ol type="1">
<li>We only care about finding the distance between vertex 1 and vertex
2.</li>
<li>We are given a nice embedding of the graph.</li>
</ol>
<p>It will be obvious from the construction how to get rid of
assumptions (1) and (2).</p>
<p>The embedding of the graph that we will work with here will be
different from the embeddings that we use in the rest of the post. In
particular, in this section only, the tokens will correspond to vertices.
In the remainder of the post, tokens will correspond to edges.</p>
<p>More specifically, we will create the following embeddings for our
input sequence:</p>
<table>
<colgroup>
<col style="width: 25%" />
<col style="width: 20%" />
<col style="width: 22%" />
<col style="width: 13%" />
<col style="width: 18%" />
</colgroup>
<thead>
<tr class="header">
<th></th>
<th>Answer token</th>
<th><span class="math inline">v_{1}</span> token</th>
<th><span class="math inline">\ldots</span></th>
<th><span class="math inline">v_n</span> token</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Answer flag</td>
<td>1</td>
<td>0</td>
<td></td>
<td>0</td>
</tr>
<tr class="even">
<td>Neighbors</td>
<td></td>
<td>01101</td>
<td></td>
<td>10010</td>
</tr>
<tr class="odd">
<td>Reach</td>
<td></td>
<td>01101</td>
<td></td>
<td>10010</td>
</tr>
<tr class="even">
<td>Out</td>
<td>00000</td>
<td></td>
<td></td>
<td></td>
</tr>
<tr class="odd">
<td>Self</td>
<td></td>
<td>10000</td>
<td></td>
<td>00001</td>
</tr>
<tr class="even">
<td>Not-answer flag</td>
<td>0</td>
<td>1</td>
<td></td>
<td>1</td>
</tr>
</tbody>
</table>
<p>As you can see, the hidden dimension of the model will be <span
class="math inline">4n+2</span>. The hidden dimension is broken up into
5 sections, some of which are only used in the answer token, and some of
which are only used in the vertex tokens. Here is what the different
parts of the hidden dimension represent:</p>
<ol type="1">
<li>Flags: booleans indicating whether a token is the answer token or
not. If desired, you can think of this as part of a “positional
embedding”. The flags will not be updated as we move through the
transformer.</li>
<li>Neighbors: a boolean vector (of dimension <span
class="math inline">n</span>) describing the neighbors of each vertex.
(We just put zeros here for the answer token). This also won’t update
through the course of running the transformer.</li>
<li>Reach: a boolean vector representing the vertices that each vertex
can reach. This will update over the course of a forward pass of the
transformer. Specifically, at layer <span class="math inline">i</span>
of the transformer we will extend the reach of each vertex from vertices
which are distance <span class="math inline">i+1</span> away, to
vertices at distance <span class="math inline">i+2</span> away.</li>
<li>Out: this starts as zero for all tokens. Once vertex <span
class="math inline">1</span> can reach vertex <span
class="math inline">2</span> (i.e., we are in a layer deep enough that
the reach vector for <span class="math inline">v_{1}</span>’s token
indicates that it can reach vertex <span class="math inline">2</span>)
we’ll start writing <span class="math inline">1</span> on the answer
token’s out part of its embedding.</li>
<li>Self: a 1-hot encoding specifying which vertex each token
corresponds to. You should think of this as a positional embedding
(along with the flags).</li>
</ol>
<p>The transformer will consist of <span class="math inline">n</span>
layers, each of which has <span class="math inline">n+1</span> heads.
We’ve included the code for the handcoded transformer in the <a
href="#appendix">Appendix</a> – possibly it’s helpful to look at the
code in conjunction with the written explanation below.</p>
<p>In layer <span class="math inline">i</span>, our goal is to extend
the reach of each vertex <span class="math inline">v</span> from the set
of vertices at distance <span class="math inline">i+1</span> from <span
class="math inline">v</span> to the set of vertices at distance <span
class="math inline">i+2</span> from <span class="math inline">v</span>.
The <span class="math inline">j</span>-th head will perform the
following operation:</p>
<pre><code>In head j at layer i:
Vertex v&#39;s query vector will have high dot-product with vertex w&#39;s key vector
iff vertex v can reach vertex j, and vertex w is neighbors with vertex j 

The value that vertex w would give to vertex v is simply a 1-hot vector, with a 1 in the w-th position</code></pre>
<p>After running the heads, each vertex’s embedding is updated to
integrate the new vertices that it can reach into its embedding. Note
that we can think of embeddings as being updated because transformers
have a residual stream of embeddings that we make modifications to.</p>
<p>Finally, the last head will be in charge of noticing whether vertex 1
has reached vertex 2 yet, and if so recording this fact. For the full
details, see <a href="#appendix">appendix</a>.</p>
<h3 id="robustness">Robustness</h3>
<p>The field of Singular Learning Theory (SLT; see Liam Carroll’s
Master’s thesis “Phase Transitions in Neural Networks” for an
introduction) aims to understand model training and loss-landscape
geometry. In an effort to better understand the loss landscape of the
shortest paths problem under the token embedding used in our hand-coded
shortest paths transformer, we decided to start at a good setting of
the parameters, perturb the weights, and see whether the model can
subsequently recover low loss. The intuition for why this is a good way
of measuring “how attractive of a loss basin” we have is that this
experiment is similar to the Local Learning Coefficient from SLT (see
Lau, Edmund, Zach Furman, George Wang, Daniel Murfet, and Susan Wei.
“The Local Learning Coefficient: A Singularity-Aware Complexity
Measure”). We found that perturbing the weights led to high loss, but
gradient descent was able to recover low
loss, indicating that the solution is somewhat “findable” by gradient
descent.</p>
<p>TODO: Switch to fancy mpld3 plots <img src="perturb.png"
alt="perturb.png" /> <img src="perturb-loss.png"
alt="perturb-loss.png" /></p>
<h2 id="training">Training</h2>
<h3 id="data">Data</h3>
<p>We’ll represent an <span class="math inline">n</span>-vertex, <span
class="math inline">m</span>-edge unweighted, undirected graph as a
sequence of the endpoints of the <span class="math inline">m</span>
edges, so <span
class="math inline">[a_1,b_1,a_2,b_2,\ldots,a_m,b_m]</span> represents a
graph with the edges <span class="math inline">\{(a_i,b_i)\}</span>
where <span class="math inline">a_i &lt; b_i</span> for <span
class="math inline">1 \leq i \leq m</span>. All sequences are padded to
the same length using the padding token <span
class="math inline">0</span>.</p>
<p>The full input to our model additionally includes the target vertex
<span class="math inline">t</span> after the padding tokens. The label
to an input is the length of the shortest path from <span
class="math inline">1</span> to <span class="math inline">t</span>. If
no such path exists, we define the length to be <span
class="math inline">n</span> which represents infinity. For example, the
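input <span class="math inline">[1, 3, 3, 4, 2, 4, 2, 3, 0, 0, 0, 0,
3]</span> has the label <span class="math inline">1</span>, since the
edge <span class="math inline">(1,3)</span> directly connects vertex
<span class="math inline">1</span> to the target <span
class="math inline">t = 3</span>.</p>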
<div style="text-align:center">
<p><img src="img/finetune.svg" /></p>
</div>
<p>We have a dataset for each step.</p>
<ol type="1">
<li><strong>Pre-train data</strong>: For each <span
class="math inline">n \in [3,15]</span>, we generated an equal number of
graphs on <span class="math inline">n</span> vertices, with <span
class="math inline">t = 2</span>. Each graph was created by choosing
<span class="math inline">n</span> random edges. To balance the dataset,
we removed some graphs with shorter paths to obtain the distribution
shown below.</li>
</ol>


<style>

</style>

<div id="fig_el3734471397923808578566893788369"></div>
<script>
/**
 * Load an external script by appending a <script> element to the first
 * <head> element of the document.
 *
 * @param {string} url - Address of the script; assigned to the element's `src`.
 * @param {Function} callback - Installed as both the `onload` and the
 *     `onreadystatechange` handler of the new element.
 *
 * If the element reports an error, a warning naming `url` is written to the
 * console; `callback` is not wired to the error event.
 */
function mpld3_load_lib(url, callback){
  var head = document.getElementsByTagName("head")[0];
  var script = document.createElement('script');
  var warnOnFailure = function(){
    console.warn("failed to load library " + url);
  };

  script.src = url;
  script.async = true;
  // The same callback is registered on both handler properties.
  script.onload = callback;
  script.onreadystatechange = callback;
  script.onerror = warnOnFailure;

  head.appendChild(script);
}

// Draw the mpld3 figure (a bar chart whose axes are labelled "Label" and
// "Count") into the placeholder <div> that precedes this script.
//
// The figure description used to be pasted verbatim into each of the three
// loading branches below; it now lives in a single builder so the copies
// cannot drift apart. Everything is wrapped in an IIFE so that the helper
// names do not become page-level globals.
(function(){
  // id of the placeholder element the figure is rendered into.
  var FIGURE_ID = "fig_el3734471397923808578566893788369";

  // Returns a freshly built figure description on every call, so each
  // draw_figure invocation receives its own object, as it did when the
  // literal was written inline at every call site.
  function buildFigureSpec(){
    return {
      "width": 640.0, "height": 480.0,
      "axes": [{
        "bbox": [0.125, 0.10999999999999999, 0.775, 0.77],
        "xlim": [-1.19, 16.19], "ylim": [0.0, 7642.95],
        "xdomain": [-1.19, 16.19], "ydomain": [0.0, 7642.95],
        "xscale": "linear", "yscale": "linear",
        "axes": [
          {"position": "bottom", "nticks": 11, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true},
          {"position": "left", "nticks": 9, "tickvalues": null, "tickformat_formatter": "", "tickformat": null, "scale": "linear", "fontsize": 10.0, "grid": {"gridOn": false}, "visible": true}
        ],
        "axesbg": "#FFFFFF", "axesbgalpha": null, "zoomable": true,
        "id": "el373447139797053788464",
        "lines": [],
        "paths": [
          {"data": "data01", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381535360", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data02", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792380856416", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data03", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139797054022832", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data04", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381530944", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data05", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381539632", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data06", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139797053780624", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data07", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381532192", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data08", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381536896", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data09", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381539872", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data10", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381532480", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data11", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139797053797824", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data01", "xindex": 2, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381533200", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data01", "xindex": 3, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381532240", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data01", "xindex": 4, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139797054021056", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data01", "xindex": 5, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381532576", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1},
          {"data": "data12", "xindex": 0, "yindex": 1, "coordinates": "data", "pathcodes": ["M", "L", "L", "L", "Z"], "id": "el373447139792381527536", "dasharray": "none", "alpha": 1, "facecolor": "#1F77B4", "edgecolor": "none", "edgewidth": 1.0, "zorder": 1}
        ],
        "markers": [],
        "texts": [
          {"text": "Label", "position": [0.5, -0.07921476671476674], "coordinates": "axes", "h_anchor": "middle", "v_baseline": "hanging", "rotation": -0.0, "fontsize": 10.0, "color": "#000000", "alpha": 1, "zorder": 3, "id": "el373447139792380852176"},
          {"text": "Count", "position": [-0.10161850358422941, 0.5], "coordinates": "axes", "h_anchor": "middle", "v_baseline": "auto", "rotation": -90.0, "fontsize": 10.0, "color": "#000000", "alpha": 1, "zorder": 3, "id": "el373447139792380853616"}
        ],
        "collections": [], "images": [], "sharex": [], "sharey": []
      }],
      "data": {
        "data01": [[-0.4, 0.0, 10.6, 11.6, 12.6, 13.6], [0.4, 0.0, 11.4, 12.4, 13.4, 14.4], [0.4, 0.0, 11.4, 12.4, 13.4, 14.4], [-0.4, 0.0, 10.6, 11.6, 12.6, 13.6]],
        "data02": [[0.6, 0.0], [1.4, 0.0], [1.4, 7279.0], [0.6, 7279.0]],
        "data03": [[1.6, 0.0], [2.4000000000000004, 0.0], [2.4000000000000004, 6614.0], [1.6, 6614.0]],
        "data04": [[2.6, 0.0], [3.4000000000000004, 0.0], [3.4000000000000004, 4434.0], [2.6, 4434.0]],
        "data05": [[3.6, 0.0], [4.4, 0.0], [4.4, 2907.0], [3.6, 2907.0]],
        "data06": [[4.6, 0.0], [5.3999999999999995, 0.0], [5.3999999999999995, 2324.0], [4.6, 2324.0]],
        "data07": [[5.6, 0.0], [6.3999999999999995, 0.0], [6.3999999999999995, 1578.0], [5.6, 1578.0]],
        "data08": [[6.6, 0.0], [7.3999999999999995, 0.0], [7.3999999999999995, 469.0], [6.6, 469.0]],
        "data09": [[7.6, 0.0], [8.4, 0.0], [8.4, 133.0], [7.6, 133.0]],
        "data10": [[8.6, 0.0], [9.4, 0.0], [9.4, 27.0], [8.6, 27.0]],
        "data11": [[9.6, 0.0], [10.4, 0.0], [10.4, 4.0], [9.6, 4.0]],
        "data12": [[14.6, 0.0], [15.4, 0.0], [15.4, 6999.0], [14.6, 6999.0]]
      },
      "id": "el373447139792380857856",
      "plugins": [{"type": "reset"}, {"type": "zoom", "button": true, "enabled": false}, {"type": "boxzoom", "button": true, "enabled": false}]
    };
  }

  // Renders the figure using whatever `mpld3` global is present at call
  // time. Any arguments (e.g. a load event) are ignored.
  function drawFigure(){
    mpld3.draw_figure(FIGURE_ID, buildFigureSpec());
  }

  if(typeof(mpld3) !== "undefined" && mpld3._mpld3IsLoaded){
    // already loaded: just create the figure
    drawFigure();
  }else if(typeof define === "function" && define.amd){
    // require.js is available: use it to load d3/mpld3
    require.config({paths: {d3: "https://d3js.org/d3.v5"}});
    require(["d3"], function(d3){
      // expose d3 as a global before the mpld3 script is loaded
      window.d3 = d3;
      mpld3_load_lib("https://mpld3.github.io/js/mpld3.v0.5.10.js", drawFigure);
    });
  }else{
    // require.js not available: dynamically load d3 & mpld3
    mpld3_load_lib("https://d3js.org/d3.v5.js", function(){
      mpld3_load_lib("https://mpld3.github.io/js/mpld3.v0.5.10.js", drawFigure);
    });
  }
})();
</script>
<ol start="2" type="1">
<li><p><strong>Fine-tune data</strong>: For each <span
class="math inline">n \in [3,7]</span>, we generated an equal number of
graphs on <span class="math inline">n</span> vertices each with a random
<span class="math inline">t</span> on the shortest path from <span
class="math inline">1</span> to <span class="math inline">2</span>.
Again, each graph was created by choosing <span
class="math inline">n</span> random edges.</p></li>
<li><p><strong>Generalization test data</strong>: The same as the
fine-tune data, except we sample <span class="math inline">n \in
[3,15]</span> and <span class="math inline">t \leq 7</span>.</p></li>
</ol>
<p>We wrote some Python code to generate the data during the training
loop, but Python is excruciatingly slow and data generation wasted a lot
of training time. Our solution was to pre-generate the data before
training using a multithreaded version of our Python code.</p>
<h3 id="architecture">Architecture</h3>
<p>We used a standard transformer architecture. So that it can in
theory learn BFS, we made sure that the number of layers in our
transformer exceeds the diameter of the input graphs.</p>
<p>Since the order of the edges in the input doesn’t matter, we did not
use positional encodings. Each edge <span
class="math inline">(a,b)</span> is embedded to dimension <span
class="math inline">d</span> where the first <span
class="math inline">\frac{d}{2}</span> elements are the learned
embedding of <span class="math inline">a</span> and the last <span
class="math inline">\frac{d}{2}</span> elements are the learned
embedding of <span class="math inline">b</span>. For the target vertex
<span class="math inline">t</span>, we pair it with the special token
<span class="math inline">TARGET</span> and embed <span
class="math inline">(t,TARGET)</span> in the same way.</p>
<!-- https://cocreate.csail.mit.edu/r/sxArTEXiAgJshznmm -->
<p><img src="img/embeddings.svg" /></p>
<h3 id="specifications">Specifications</h3>
<p>For our training run, we used the following specifications:</p>
<table>
<thead>
<tr class="header">
<th>Hyperparameter</th>
<th>Value</th>
</tr>
</thead>
<tbody>
<tr class="odd">
<td>Model dimension</td>
<td><span class="math inline">64</span></td>
</tr>
<tr class="even">
<td>Layers</td>
<td><span class="math inline">11</span></td>
</tr>
<tr class="odd">
<td>Heads</td>
<td><span class="math inline">2</span></td>
</tr>
<tr class="even">
<td>Epochs</td>
<td><span class="math inline">200</span></td>
</tr>
<tr class="odd">
<td>Batch size</td>
<td><span class="math inline">2^{15}</span></td>
</tr>
<tr class="even">
<td>Dropout</td>
<td><span class="math inline">0</span></td>
</tr>
<tr class="odd">
<td>Total training examples</td>
<td><span class="math inline">2\cdot10^8</span></td>
</tr>
<tr class="even">
<td>Parameters</td>
<td><span class="math inline">550433</span></td>
</tr>
<tr class="odd">
<td>Learning rate</td>
<td><span class="math inline">5\cdot10^{-4}</span></td>
</tr>
<tr class="even">
<td>Loss</td>
<td>MSE</td>
</tr>
<tr class="odd">
<td>Optimizer</td>
<td>Adam</td>
</tr>
</tbody>
</table>
<p>The number of bits required to store the model parameters in float32
is around <span class="math inline">1.76\cdot10^7</span>. The number of
possible graphs on 15 vertices generated using our procedure is
approximately <span class="math display">\frac{\binom{15}{2}^{15}}{15!}
\approx 1.59\cdot10^{18}.</span> This is because there are <span
class="math inline">\binom{15}{2}</span> choices for each of the 15
edges and we don’t care about the order of the edges. This is only an
approximation because some edges might be duplicated. Each graph has an
answer between 1 and 15 which requires around 4 bits, so memorizing all
the answers requires <span
class="math inline">4\cdot1.59\cdot10^{18} = 6.36\cdot10^{18}</span>
bits, which is <span class="math inline">3.61\cdot10^{11}</span> times
larger than our model size. This implies that in order to get really low
loss, our model needs to do something other than brute memorization.</p>
<p>A single training run takes roughly two hours to run on a Radeon 7900
XTX graphics card.</p>
<p>One pitfall we encountered during training is that we initially used
bfloat16 to save VRAM, but our huge batch size caused loss-of-precision
problems and made training very difficult. It took us two weeks to debug
this until we found that switching to float32 improved training
significantly.</p>
<h3 id="training-results">Training Results</h3>
<p>Here is our training loss curve:</p>
<p>$include<code>raw="html"</code> plots/train-loss.html</p>
<p>Here is a “confusion matrix”-like plot, showing for each true label,
the probability distribution over classifications that the model
gives:</p>
<p>$include<code>raw="html"</code> plots/train-hist.html</p>
<p>One pattern we consistently noticed during training is that the model
often gets stuck, with the loss plateauing for many epochs before
rapidly decreasing. For instance, this happened between epochs 100 and
300 in the graph below:</p>
<p>$include<code>raw="html"</code> plots/train-lens.html</p>
<p>Such <strong>phase change</strong> phenomena are ubiquitous in deep
learning and for simple algorithmic tasks we can sometimes interpret
them <a href="#fn2" class="footnote-ref" id="fnref2"
role="doc-noteref"><sup>2</sup></a>.</p>
<p><strong>TODO ANT: training curves for 1, 2, 3 length paths. and add
some text about why this is an interesting finding.</strong></p>
<h2 id="generalization">Generalization</h2>
<h3 id="fine-tuning-results">Fine Tuning Results</h3>
<p>After receiving our initial results, we fine-tuned with a learning
rate of 1e-5, also with MSE and the same batch size. Our final results
are shown in the images below.</p>
<p>$include<code>raw="html"</code> plots/tune-loss.html</p>
<p>$include<code>raw="html"</code> plots/tune-hist.html</p>
<p>$include<code>raw="html"</code>
plots/test-onpath-smalltarget.html</p>
<p>$include<code>raw="html"</code>
plots/test-onpath-largetarget.html</p>
<p>$include<code>raw="html"</code> plots/test-small-any.html</p>
<p>$include<code>raw="html"</code>
plots/test-large-any-smalltarget.html</p>
<p>$include<code>raw="html"</code>
plots/test-large-any-largetarget.html</p>
<!--
ANT:
RE t<7
The model could've decided to generalize for t > 7 too, but it didn't
so I'd say
we tested generalization for all t in 3, 15
it succeeded for t in 3,7
failed for t in 8 to 15
reason is because it could do "if statements" instead of realizing that it should just be looking at the dot product
mmm
i feel like the fact that it works for 3,7 and not 8,15 is kind of compelling
anyways this is just a comment explaining how im going to / think you should pitch the generalization section
(it could've generalized all the way, but didn't, so what we learned is generalization sometimes happens but sometimes doesn't)

yeah that's the intuition
can't really do interpretability to figure out if that's actually what's going on
-->
<h2 id="conclusion">Conclusion</h2>
<p>In this post we’ve investigated solving the shortest paths problem
with a transformer. We’ve shown that this problem has a simple solution
that can be implemented by a transformer. We’ve shown that a transformer
can learn to solve this problem. We’ve shown that when the transformer
learns to solve this problem it also implicitly builds good internal
representations of the input data that allow for efficient fine tuning
to find shortest paths between other vertices. Finally, we’ve seen that
some off-distribution generalization does occur, and some does not! The
main conceptual take-away from our work is that it’s hard to predict
when models will and won’t generalize.</p>
<h2 id="appendix">Appendix</h2>
<div class="sourceCode" id="cb2"><pre
class="sourceCode python"><code class="sourceCode python"><span id="cb2-1"><a href="#cb2-1" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch</span>
<span id="cb2-2"><a href="#cb2-2" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> torch.nn <span class="im">as</span> nn</span>
<span id="cb2-3"><a href="#cb2-3" aria-hidden="true" tabindex="-1"></a><span class="im">import</span> random</span>
<span id="cb2-4"><a href="#cb2-4" aria-hidden="true" tabindex="-1"></a><span class="im">from</span> collections <span class="im">import</span> deque</span>
<span id="cb2-5"><a href="#cb2-5" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-6"><a href="#cb2-6" aria-hidden="true" tabindex="-1"></a><span class="co"># Configuration</span></span>
<span id="cb2-7"><a href="#cb2-7" aria-hidden="true" tabindex="-1"></a>NVTXS <span class="op">=</span> <span class="dv">8</span></span>
<span id="cb2-8"><a href="#cb2-8" aria-hidden="true" tabindex="-1"></a>MAXDIST <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb2-9"><a href="#cb2-9" aria-hidden="true" tabindex="-1"></a>AVGDEG <span class="op">=</span> <span class="dv">2</span></span>
<span id="cb2-10"><a href="#cb2-10" aria-hidden="true" tabindex="-1"></a>SEQLEN <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb2-11"><a href="#cb2-11" aria-hidden="true" tabindex="-1"></a>HIDDENDIM <span class="op">=</span> <span class="dv">4</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">2</span></span>
<span id="cb2-12"><a href="#cb2-12" aria-hidden="true" tabindex="-1"></a>START_REACH <span class="op">=</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb2-13"><a href="#cb2-13" aria-hidden="true" tabindex="-1"></a>START_OUT <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb2-14"><a href="#cb2-14" aria-hidden="true" tabindex="-1"></a>START_SELF <span class="op">=</span> <span class="dv">3</span> <span class="op">*</span> NVTXS <span class="op">+</span> <span class="dv">1</span></span>
<span id="cb2-15"><a href="#cb2-15" aria-hidden="true" tabindex="-1"></a>SRC_FLAG_IDX <span class="op">=</span> START_SELF</span>
<span id="cb2-16"><a href="#cb2-16" aria-hidden="true" tabindex="-1"></a>ANS_FLAG_IDX <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb2-17"><a href="#cb2-17" aria-hidden="true" tabindex="-1"></a>NOTANS_FLAG_IDX <span class="op">=</span> <span class="op">-</span><span class="dv">1</span></span>
<span id="cb2-18"><a href="#cb2-18" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-19"><a href="#cb2-19" aria-hidden="true" tabindex="-1"></a><span class="co"># Determine device</span></span>
<span id="cb2-20"><a href="#cb2-20" aria-hidden="true" tabindex="-1"></a>device <span class="op">=</span> torch.device(<span class="st">&quot;cuda&quot;</span> <span class="cf">if</span> torch.cuda.is_available() <span class="cf">else</span> <span class="st">&quot;cpu&quot;</span>)</span>
<span id="cb2-21"><a href="#cb2-21" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-22"><a href="#cb2-22" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> random_graph(device):</span>
<span id="cb2-23"><a href="#cb2-23" aria-hidden="true" tabindex="-1"></a>    <span class="co">&quot;&quot;&quot;Generate a random graph tensor.&quot;&quot;&quot;</span></span>
<span id="cb2-24"><a href="#cb2-24" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> torch.zeros((SEQLEN, HIDDENDIM), device<span class="op">=</span>device)</span>
<span id="cb2-25"><a href="#cb2-25" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-26"><a href="#cb2-26" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Mark self vertices</span></span>
<span id="cb2-27"><a href="#cb2-27" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">1</span>):</span>
<span id="cb2-28"><a href="#cb2-28" aria-hidden="true" tabindex="-1"></a>        data[i, START_SELF <span class="op">-</span> <span class="dv">1</span> <span class="op">+</span> i] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-29"><a href="#cb2-29" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-30"><a href="#cb2-30" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Create adjacency list</span></span>
<span id="cb2-31"><a href="#cb2-31" aria-hidden="true" tabindex="-1"></a>    adj_list <span class="op">=</span> [<span class="bu">set</span>() <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(SEQLEN)]</span>
<span id="cb2-32"><a href="#cb2-32" aria-hidden="true" tabindex="-1"></a>    indices <span class="op">=</span> [random.randint(<span class="dv">1</span>, NVTXS) <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(AVGDEG <span class="op">*</span> NVTXS)]</span>
<span id="cb2-33"><a href="#cb2-33" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-34"><a href="#cb2-34" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">0</span>, <span class="bu">len</span>(indices), <span class="dv">2</span>):</span>
<span id="cb2-35"><a href="#cb2-35" aria-hidden="true" tabindex="-1"></a>        u <span class="op">=</span> indices[i]</span>
<span id="cb2-36"><a href="#cb2-36" aria-hidden="true" tabindex="-1"></a>        v <span class="op">=</span> indices[i <span class="op">+</span> <span class="dv">1</span>]</span>
<span id="cb2-37"><a href="#cb2-37" aria-hidden="true" tabindex="-1"></a>        <span class="cf">if</span> u <span class="op">!=</span> v:</span>
<span id="cb2-38"><a href="#cb2-38" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Bidirectional connections</span></span>
<span id="cb2-39"><a href="#cb2-39" aria-hidden="true" tabindex="-1"></a>            data[v, u] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-40"><a href="#cb2-40" aria-hidden="true" tabindex="-1"></a>            data[u, v] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-41"><a href="#cb2-41" aria-hidden="true" tabindex="-1"></a>            data[v, NVTXS <span class="op">+</span> u] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-42"><a href="#cb2-42" aria-hidden="true" tabindex="-1"></a>            data[u, NVTXS <span class="op">+</span> v] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-43"><a href="#cb2-43" aria-hidden="true" tabindex="-1"></a>            adj_list[u].add(v)</span>
<span id="cb2-44"><a href="#cb2-44" aria-hidden="true" tabindex="-1"></a>            adj_list[v].add(u)</span>
<span id="cb2-45"><a href="#cb2-45" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-46"><a href="#cb2-46" aria-hidden="true" tabindex="-1"></a>    <span class="co"># Set flags</span></span>
<span id="cb2-47"><a href="#cb2-47" aria-hidden="true" tabindex="-1"></a>    data[<span class="dv">0</span>, ANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-48"><a href="#cb2-48" aria-hidden="true" tabindex="-1"></a>    data[<span class="dv">1</span>:, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-49"><a href="#cb2-49" aria-hidden="true" tabindex="-1"></a>    data[<span class="dv">0</span>, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-50"><a href="#cb2-50" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data, adj_list</span>
<span id="cb2-51"><a href="#cb2-51" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-52"><a href="#cb2-52" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> SSSP(G):</span>
<span id="cb2-53"><a href="#cb2-53" aria-hidden="true" tabindex="-1"></a>    <span class="co">&quot;&quot;&quot;Single Source Shortest Path algorithm.&quot;&quot;&quot;</span></span>
<span id="cb2-54"><a href="#cb2-54" aria-hidden="true" tabindex="-1"></a>    dist <span class="op">=</span> [MAXDIST <span class="cf">for</span> _ <span class="kw">in</span> G]</span>
<span id="cb2-55"><a href="#cb2-55" aria-hidden="true" tabindex="-1"></a>    dist[<span class="dv">1</span>] <span class="op">=</span> <span class="dv">0</span></span>
<span id="cb2-56"><a href="#cb2-56" aria-hidden="true" tabindex="-1"></a>    frontier <span class="op">=</span> deque([<span class="dv">1</span>])</span>
<span id="cb2-57"><a href="#cb2-57" aria-hidden="true" tabindex="-1"></a>    <span class="cf">while</span> frontier:</span>
<span id="cb2-58"><a href="#cb2-58" aria-hidden="true" tabindex="-1"></a>        vtx <span class="op">=</span> frontier.popleft()</span>
<span id="cb2-59"><a href="#cb2-59" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> x <span class="kw">in</span> G[vtx]:</span>
<span id="cb2-60"><a href="#cb2-60" aria-hidden="true" tabindex="-1"></a>            <span class="cf">if</span> dist[x] <span class="op">==</span> MAXDIST:</span>
<span id="cb2-61"><a href="#cb2-61" aria-hidden="true" tabindex="-1"></a>                dist[x] <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> dist[vtx]</span>
<span id="cb2-62"><a href="#cb2-62" aria-hidden="true" tabindex="-1"></a>                frontier.append(x)</span>
<span id="cb2-63"><a href="#cb2-63" aria-hidden="true" tabindex="-1"></a>                <span class="cf">if</span> x <span class="op">==</span> <span class="dv">2</span>:</span>
<span id="cb2-64"><a href="#cb2-64" aria-hidden="true" tabindex="-1"></a>                    <span class="cf">return</span> dist[<span class="dv">2</span>]</span>
<span id="cb2-65"><a href="#cb2-65" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> MAXDIST</span>
<span id="cb2-66"><a href="#cb2-66" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-67"><a href="#cb2-67" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> mkbatch(size):</span>
<span id="cb2-68"><a href="#cb2-68" aria-hidden="true" tabindex="-1"></a>    <span class="co">&quot;&quot;&quot;Create a batch of graph data.&quot;&quot;&quot;</span></span>
<span id="cb2-69"><a href="#cb2-69" aria-hidden="true" tabindex="-1"></a>    graphs <span class="op">=</span> []</span>
<span id="cb2-70"><a href="#cb2-70" aria-hidden="true" tabindex="-1"></a>    distances <span class="op">=</span> []</span>
<span id="cb2-71"><a href="#cb2-71" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-72"><a href="#cb2-72" aria-hidden="true" tabindex="-1"></a>    <span class="cf">for</span> _ <span class="kw">in</span> <span class="bu">range</span>(size):</span>
<span id="cb2-73"><a href="#cb2-73" aria-hidden="true" tabindex="-1"></a>        data, adj_list <span class="op">=</span> random_graph(device)</span>
<span id="cb2-74"><a href="#cb2-74" aria-hidden="true" tabindex="-1"></a>        dist <span class="op">=</span> SSSP(adj_list)</span>
<span id="cb2-75"><a href="#cb2-75" aria-hidden="true" tabindex="-1"></a>        graphs.append(data)</span>
<span id="cb2-76"><a href="#cb2-76" aria-hidden="true" tabindex="-1"></a>        distances.append(dist)</span>
<span id="cb2-77"><a href="#cb2-77" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-78"><a href="#cb2-78" aria-hidden="true" tabindex="-1"></a>    data <span class="op">=</span> torch.stack(graphs)</span>
<span id="cb2-79"><a href="#cb2-79" aria-hidden="true" tabindex="-1"></a>    labels <span class="op">=</span> torch.tensor(distances, dtype<span class="op">=</span>torch.float32, device<span class="op">=</span>device)</span>
<span id="cb2-80"><a href="#cb2-80" aria-hidden="true" tabindex="-1"></a>    <span class="cf">return</span> data, labels</span>
<span id="cb2-81"><a href="#cb2-81" aria-hidden="true" tabindex="-1"></a>    </span>
<span id="cb2-82"><a href="#cb2-82" aria-hidden="true" tabindex="-1"></a>BIG,SUPABIG,MED,CURSE <span class="op">=</span> <span class="dv">12</span>,<span class="dv">30</span>,<span class="dv">7</span>,<span class="dv">5</span></span>
<span id="cb2-83"><a href="#cb2-83" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-84"><a href="#cb2-84" aria-hidden="true" tabindex="-1"></a><span class="kw">class</span> SillyTransformer(nn.Module):</span>
<span id="cb2-85"><a href="#cb2-85" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> <span class="fu">__init__</span>(<span class="va">self</span>, device):</span>
<span id="cb2-86"><a href="#cb2-86" aria-hidden="true" tabindex="-1"></a>        <span class="bu">super</span>().<span class="fu">__init__</span>()</span>
<span id="cb2-87"><a href="#cb2-87" aria-hidden="true" tabindex="-1"></a>        <span class="va">self</span>.device <span class="op">=</span> device</span>
<span id="cb2-88"><a href="#cb2-88" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-89"><a href="#cb2-89" aria-hidden="true" tabindex="-1"></a>        <span class="cf">with</span> torch.no_grad():</span>
<span id="cb2-90"><a href="#cb2-90" aria-hidden="true" tabindex="-1"></a>            <span class="co"># Initialize weight parameters with specific configurations</span></span>
<span id="cb2-91"><a href="#cb2-91" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mostKs <span class="op">=</span> nn.ParameterList()</span>
<span id="cb2-92"><a href="#cb2-92" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mostQs <span class="op">=</span> nn.ParameterList()</span>
<span id="cb2-93"><a href="#cb2-93" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.mostVs <span class="op">=</span> nn.ParameterList()</span>
<span id="cb2-94"><a href="#cb2-94" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> head <span class="kw">in</span> <span class="bu">range</span>(<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">1</span>):</span>
<span id="cb2-95"><a href="#cb2-95" aria-hidden="true" tabindex="-1"></a>                Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
<span id="cb2-96"><a href="#cb2-96" aria-hidden="true" tabindex="-1"></a>                Q[<span class="dv">0</span>, START_REACH <span class="op">-</span> <span class="dv">1</span> <span class="op">+</span> head] <span class="op">=</span> SUPABIG</span>
<span id="cb2-97"><a href="#cb2-97" aria-hidden="true" tabindex="-1"></a>                Q[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-98"><a href="#cb2-98" aria-hidden="true" tabindex="-1"></a>                K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">2</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
<span id="cb2-99"><a href="#cb2-99" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">0</span>, head] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-100"><a href="#cb2-100" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> BIG</span>
<span id="cb2-101"><a href="#cb2-101" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-102"><a href="#cb2-102" aria-hidden="true" tabindex="-1"></a>                V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span>
<span id="cb2-103"><a href="#cb2-103" aria-hidden="true" tabindex="-1"></a>                <span class="cf">for</span> i <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
<span id="cb2-104"><a href="#cb2-104" aria-hidden="true" tabindex="-1"></a>                    V[i, START_SELF <span class="op">+</span> i] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-105"><a href="#cb2-105" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-106"><a href="#cb2-106" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.mostKs.append(K)</span>
<span id="cb2-107"><a href="#cb2-107" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.mostQs.append(Q)</span>
<span id="cb2-108"><a href="#cb2-108" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.mostVs.append(V)</span>
<span id="cb2-109"><a href="#cb2-109" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.weirdKs <span class="op">=</span> nn.ParameterList()</span>
<span id="cb2-110"><a href="#cb2-110" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.weirdQs <span class="op">=</span> nn.ParameterList()</span>
<span id="cb2-111"><a href="#cb2-111" aria-hidden="true" tabindex="-1"></a>            <span class="va">self</span>.weirdVs <span class="op">=</span> nn.ParameterList()</span>
<span id="cb2-112"><a href="#cb2-112" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
<span id="cb2-113"><a href="#cb2-113" aria-hidden="true" tabindex="-1"></a>                K <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
<span id="cb2-114"><a href="#cb2-114" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">0</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG</span>
<span id="cb2-115"><a href="#cb2-115" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">0</span>, SRC_FLAG_IDX] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span>
<span id="cb2-116"><a href="#cb2-116" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">1</span>, NOTANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>SUPABIG</span>
<span id="cb2-117"><a href="#cb2-117" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">1</span>, NVTXS <span class="op">+</span> <span class="dv">2</span>] <span class="op">=</span> BIG<span class="op">+</span>SUPABIG</span>
<span id="cb2-118"><a href="#cb2-118" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">1</span>, ANS_FLAG_IDX] <span class="op">=</span> <span class="op">-</span>BIG<span class="op">-</span>SUPABIG</span>
<span id="cb2-119"><a href="#cb2-119" aria-hidden="true" tabindex="-1"></a>                K[<span class="dv">2</span>, ANS_FLAG_IDX] <span class="op">=</span> MED</span>
<span id="cb2-120"><a href="#cb2-120" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-121"><a href="#cb2-121" aria-hidden="true" tabindex="-1"></a>                Q <span class="op">=</span> nn.Parameter(torch.zeros((<span class="dv">3</span>, HIDDENDIM), device<span class="op">=</span>device))</span>
<span id="cb2-122"><a href="#cb2-122" aria-hidden="true" tabindex="-1"></a>                Q[:, ANS_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-123"><a href="#cb2-123" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-124"><a href="#cb2-124" aria-hidden="true" tabindex="-1"></a>                V <span class="op">=</span> nn.Parameter(torch.zeros((NVTXS, HIDDENDIM), device<span class="op">=</span>device))</span>
<span id="cb2-125"><a href="#cb2-125" aria-hidden="true" tabindex="-1"></a>                V[layer, SRC_FLAG_IDX] <span class="op">=</span> <span class="dv">1</span></span>
<span id="cb2-126"><a href="#cb2-126" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-127"><a href="#cb2-127" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.weirdKs.append(K)</span>
<span id="cb2-128"><a href="#cb2-128" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.weirdQs.append(Q)</span>
<span id="cb2-129"><a href="#cb2-129" aria-hidden="true" tabindex="-1"></a>                <span class="va">self</span>.weirdVs.append(V)</span>
<span id="cb2-130"><a href="#cb2-130" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-131"><a href="#cb2-131" aria-hidden="true" tabindex="-1"></a>    <span class="kw">def</span> forward(<span class="va">self</span>, src):</span>
<span id="cb2-132"><a href="#cb2-132" aria-hidden="true" tabindex="-1"></a>        <span class="cf">for</span> layer <span class="kw">in</span> <span class="bu">range</span>(NVTXS):</span>
<span id="cb2-133"><a href="#cb2-133" aria-hidden="true" tabindex="-1"></a>            allKs <span class="op">=</span> [<span class="va">self</span>.weirdKs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostKs]</span>
<span id="cb2-134"><a href="#cb2-134" aria-hidden="true" tabindex="-1"></a>            allQs <span class="op">=</span> [<span class="va">self</span>.weirdQs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostQs]</span>
<span id="cb2-135"><a href="#cb2-135" aria-hidden="true" tabindex="-1"></a>            allVs <span class="op">=</span> [<span class="va">self</span>.weirdVs[layer]] <span class="op">+</span> [x <span class="cf">for</span> x <span class="kw">in</span> <span class="va">self</span>.mostVs]</span>
<span id="cb2-136"><a href="#cb2-136" aria-hidden="true" tabindex="-1"></a>            head_outputs <span class="op">=</span> []</span>
<span id="cb2-137"><a href="#cb2-137" aria-hidden="true" tabindex="-1"></a>            </span>
<span id="cb2-138"><a href="#cb2-138" aria-hidden="true" tabindex="-1"></a>            <span class="cf">for</span> (K, Q, V) <span class="kw">in</span> <span class="bu">zip</span>(allKs, allQs, allVs):</span>
<span id="cb2-139"><a href="#cb2-139" aria-hidden="true" tabindex="-1"></a>                ksrc <span class="op">=</span> torch.matmul(src, K.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb2-140"><a href="#cb2-140" aria-hidden="true" tabindex="-1"></a>                qsrc <span class="op">=</span> torch.matmul(src, Q.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb2-141"><a href="#cb2-141" aria-hidden="true" tabindex="-1"></a>                vsrc <span class="op">=</span> torch.matmul(src, V.unsqueeze(<span class="dv">0</span>).transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb2-142"><a href="#cb2-142" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-143"><a href="#cb2-143" aria-hidden="true" tabindex="-1"></a>                scores <span class="op">=</span> torch.matmul(qsrc, ksrc.transpose(<span class="op">-</span><span class="dv">2</span>, <span class="op">-</span><span class="dv">1</span>))</span>
<span id="cb2-144"><a href="#cb2-144" aria-hidden="true" tabindex="-1"></a>                attention_weights <span class="op">=</span> torch.softmax(scores, dim<span class="op">=-</span><span class="dv">1</span>)</span>
<span id="cb2-145"><a href="#cb2-145" aria-hidden="true" tabindex="-1"></a>                head_output <span class="op">=</span> torch.matmul(attention_weights, vsrc)</span>
<span id="cb2-146"><a href="#cb2-146" aria-hidden="true" tabindex="-1"></a>                head_outputs.append(head_output)</span>
<span id="cb2-147"><a href="#cb2-147" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-148"><a href="#cb2-148" aria-hidden="true" tabindex="-1"></a>            new_reaches <span class="op">=</span> <span class="bu">sum</span>(head_outputs[<span class="dv">1</span>:])</span>
<span id="cb2-149"><a href="#cb2-149" aria-hidden="true" tabindex="-1"></a>            BSZ <span class="op">=</span> new_reaches.shape[<span class="dv">0</span>]</span>
<span id="cb2-150"><a href="#cb2-150" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-151"><a href="#cb2-151" aria-hidden="true" tabindex="-1"></a>            nodelta_nbrs <span class="op">=</span> torch.zeros((BSZ, SEQLEN, NVTXS <span class="op">+</span> <span class="dv">1</span>), device<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb2-152"><a href="#cb2-152" aria-hidden="true" tabindex="-1"></a>            morepadlol <span class="op">=</span> torch.zeros((BSZ, SEQLEN, <span class="dv">1</span> <span class="op">+</span> NVTXS), device<span class="op">=</span><span class="va">self</span>.device)</span>
<span id="cb2-153"><a href="#cb2-153" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-154"><a href="#cb2-154" aria-hidden="true" tabindex="-1"></a>            src <span class="op">=</span> src <span class="op">+</span> torch.cat((nodelta_nbrs, new_reaches, head_outputs[<span class="dv">0</span>], morepadlol), dim<span class="op">=</span><span class="dv">2</span>)</span>
<span id="cb2-155"><a href="#cb2-155" aria-hidden="true" tabindex="-1"></a>            src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">=</span> <span class="dv">2</span> <span class="op">*</span> torch.sigmoid(src[:, :, START_REACH:START_REACH <span class="op">+</span> NVTXS] <span class="op">*</span> CURSE) <span class="op">-</span> <span class="dv">1</span></span>
<span id="cb2-156"><a href="#cb2-156" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-157"><a href="#cb2-157" aria-hidden="true" tabindex="-1"></a>        canreach <span class="op">=</span> src[:, <span class="dv">0</span>, START_OUT:START_OUT <span class="op">+</span> NVTXS]</span>
<span id="cb2-158"><a href="#cb2-158" aria-hidden="true" tabindex="-1"></a>        final_output <span class="op">=</span> <span class="dv">1</span> <span class="op">+</span> torch.<span class="bu">sum</span>(<span class="dv">1</span> <span class="op">-</span> canreach, dim<span class="op">=</span><span class="dv">1</span>)</span>
<span id="cb2-159"><a href="#cb2-159" aria-hidden="true" tabindex="-1"></a>        <span class="cf">return</span> final_output</span>
<span id="cb2-160"><a href="#cb2-160" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-161"><a href="#cb2-161" aria-hidden="true" tabindex="-1"></a>model <span class="op">=</span> SillyTransformer(device).to(device)</span>
<span id="cb2-162"><a href="#cb2-162" aria-hidden="true" tabindex="-1"></a>params <span class="op">=</span> <span class="bu">sum</span>(p.numel() <span class="cf">for</span> p <span class="kw">in</span> model.parameters())</span>
<span id="cb2-163"><a href="#cb2-163" aria-hidden="true" tabindex="-1"></a><span class="bu">print</span>(<span class="ss">f&quot;Total number of parameters: </span><span class="sc">{</span>params<span class="sc">}</span><span class="ss">&quot;</span>)</span>
<span id="cb2-164"><a href="#cb2-164" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-165"><a href="#cb2-165" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> destroy_rand_weights(model):</span>
<span id="cb2-166"><a href="#cb2-166" aria-hidden="true" tabindex="-1"></a>    weight_lists <span class="op">=</span> [model.mostKs, model.mostQs, model.mostVs, </span>
<span id="cb2-167"><a href="#cb2-167" aria-hidden="true" tabindex="-1"></a>                    model.weirdKs, model.weirdQs, model.weirdVs]</span>
<span id="cb2-168"><a href="#cb2-168" aria-hidden="true" tabindex="-1"></a>    random_list <span class="op">=</span> random.choice(weight_lists)</span>
<span id="cb2-169"><a href="#cb2-169" aria-hidden="true" tabindex="-1"></a>    random_matrix <span class="op">=</span> random.choice(random_list)</span>
<span id="cb2-170"><a href="#cb2-170" aria-hidden="true" tabindex="-1"></a>    random_matrix.data <span class="op">=</span> torch.randn_like(random_matrix)</span>
<span id="cb2-171"><a href="#cb2-171" aria-hidden="true" tabindex="-1"></a></span>
<span id="cb2-172"><a href="#cb2-172" aria-hidden="true" tabindex="-1"></a>optimizer <span class="op">=</span> torch.optim.Adam(model.parameters(), lr<span class="op">=</span><span class="fl">1e-6</span>)</span>
<span id="cb2-173"><a href="#cb2-173" aria-hidden="true" tabindex="-1"></a>loss_fn <span class="op">=</span> nn.MSELoss()</span></code></pre></div>
<h2 id="references">References</h2>
<section id="footnotes" class="footnotes footnotes-end-of-document"
role="doc-endnotes">
<hr />
<ol>
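<li id="fn1"><p>Lin, Stephanie, Jacob Hilton, and Owain Evans.
“TruthfulQA: Measuring How Models Mimic Human Falsehoods.” arXiv
preprint arXiv:2109.07958 (2021).
<a href="https://arxiv.org/abs/2109.07958">https://arxiv.org/abs/2109.07958</a>.<a href="#fnref1" class="footnote-back"
role="doc-backlink">↩︎</a></p></li>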
<li id="fn2"><p>Nanda, Neel, Lawrence Chan, Tom Lieberum, Jess Smith,
and Jacob Steinhardt. “Progress measures for grokking via mechanistic
interpretability.” arXiv preprint arXiv:2301.05217 (2023).
<a href="https://arxiv.org/abs/2301.05217">https://arxiv.org/abs/2301.05217</a>.<a href="#fnref2" class="footnote-back"
role="doc-backlink">↩︎</a></p></li>
</ol>
</section>
</body>
</html>