aboutsummaryrefslogtreecommitdiff
path: root/transformer_shortest_paths.ipynb
blob: a158fcb7befeba8337cdb7ccffa773c251c2315e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LPphBnKR-aWF"
   },
   "source": [
    "# Step 0: Imports"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Questions:\n",
    "\n",
    "- Do the attention heads learn to attend to the same positional encodings?\n",
    "- Do interpretability work: what is the model doing, and can we figure it out?\n",
    "- Update: I think we should do the interpretability work once the model is bigger."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ge5QvElvhCOw",
    "outputId": "c7cdaefa-d6dc-44ad-c258-e4fb2aca97a5"
   },
   "outputs": [],
   "source": [
    "# using tqdm.auto glitches out collaborative editing\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from math import sqrt\n",
    "from collections import deque\n",
    "import os\n",
    "import random\n",
    "\n",
    "# seed Python's `random` (used by the graph samplers below) and torch\n",
    "random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# prefer the GPU; the assert then makes a GPU a hard requirement for this notebook\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "assert device.type == 'cuda', \"CUDA is not available. Please check your GPU setup.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "lylOX2POPwFL"
   },
   "outputs": [],
   "source": [
    "# vertex-count bounds for sampled graphs (both inclusive)\n",
    "MIN_VTXS = 3\n",
    "MAX_VTXS = 31\n",
    "MAX_TUNE_VTXS = 15 # upper bound on n for the batches built by mktunebatch\n",
    "AVG_DEG = 2 # random_graph samples AVG_DEG*n/2 vertex pairs, i.e. average degree ~AVG_DEG\n",
    "# each input is 2*SEQ_LEN-1 tokens: MAX_VTXS edge slots (2 tokens each) + the target vertex\n",
    "SEQ_LEN = MAX_VTXS + 1\n",
    "PAD_TOKEN = 0\n",
    "# vertices are labelled 1,2,...,MAX_VTXS (i.e. at most 1..31);\n",
    "# the padding token 0 is never a vertex label."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gKt-yIpDebF1"
   },
   "source": [
    "# Step 1: Generate synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1IbzGIWseK3E",
    "outputId": "a3cbc233-358c-4e17-ea6e-f4e9349d886b"
   },
   "outputs": [],
   "source": [
    "# Original task data.\n",
    "# Each input is a flat edge list, e.g. [1 3 1 5 2 4 0 0 0 0 2],\n",
    "# which encodes the edges (1,3), (1,5), (2,4); the zeros are padding tokens,\n",
    "# and the final token (always 2 in this task) is the target vertex.\n",
    "\n",
    "# The label is the shortest-path distance from vertex 1 to vertex 2,\n",
    "# or n (the number of vertices) if no path exists.\n",
    "\n",
    "def random_graph(n):\n",
    "    \"\"\"Sample a random graph on vertices 1..n.\n",
    "\n",
    "    Draws AVG_DEG * n // 2 uniformly random endpoint pairs and drops self-loops;\n",
    "    a repeated pair appears twice in the edge list but once in the adjacency sets.\n",
    "\n",
    "    Returns (edge_list, adjacencies): edge_list is [u1, v1, u2, v2, ...] with\n",
    "    u_i < v_i, padded with PAD_TOKEN to 2*SEQ_LEN - 1 tokens; adjacencies[v] is\n",
    "    the neighbour set of v (index 0 is unused).\n",
    "    \"\"\"\n",
    "    edge_list = []\n",
    "    adjacencies = [set() for _ in range(n+1)]\n",
    "    # draw both endpoints of each pair in turn; the pair count is floored,\n",
    "    # so an odd AVG_DEG * n cannot index past the end of the samples\n",
    "    for _ in range(AVG_DEG * n // 2):\n",
    "        u = random.randint(1, n)\n",
    "        v = random.randint(1, n)\n",
    "        if u != v:\n",
    "            edge_list += [min(u,v),max(u,v)]\n",
    "            adjacencies[u].add(v)\n",
    "            adjacencies[v].add(u)\n",
    "    seq_tokens = 2*SEQ_LEN - 1\n",
    "    # keep the last slot free: callers overwrite edge_list[-1] with the target vertex\n",
    "    assert len(edge_list) < seq_tokens, \"edge list does not fit in the sequence; lower AVG_DEG or MAX_VTXS\"\n",
    "    edge_list += [PAD_TOKEN]*(seq_tokens-len(edge_list))\n",
    "    return edge_list, adjacencies\n",
    "\n",
    "def SSSP(n, G, target=2):\n",
    "    \"\"\"Breadth-first shortest-path distances from vertex 1.\n",
    "\n",
    "    G is an adjacency list indexed by vertex (G[v] iterates over v's neighbours)\n",
    "    and n is the vertex count. All edges have length 1, so BFS order gives exact\n",
    "    distances; n doubles as the \"not reached\" value, since a real distance in an\n",
    "    n-vertex graph is at most n - 1.\n",
    "\n",
    "    If target is not None (the default is 2), return d(1, target), stopping the\n",
    "    search as soon as target is discovered; the result is n if target is\n",
    "    unreachable. If target is None, return the whole list dist, where\n",
    "    dist[i] = d(1, i) or n; for graphs from random_graph, index 0 is an unused\n",
    "    slot that stays n.\n",
    "    \"\"\"\n",
    "    dist = [n for _ in G]\n",
    "    dist[1] = 0\n",
    "    frontier = deque()\n",
    "    frontier.append(1)\n",
    "    while len(frontier) > 0:\n",
    "        vtx = frontier.popleft()\n",
    "        for x in G[vtx]:\n",
    "            if dist[x] == n:\n",
    "                dist[x] = 1 + dist[vtx]\n",
    "                frontier.append(x)\n",
    "                if x == target:\n",
    "                    return dist[target]\n",
    "    if target is not None:\n",
    "        return dist[target]\n",
    "    else:\n",
    "        return dist\n",
    "\n",
    "def mkbatch(size):\n",
    "    \"\"\"Build a batch of `size` graphs for the original task (target vertex 2).\n",
    "\n",
    "    Returns (data, labels, padding) on `device`: data holds the token sequences,\n",
    "    labels the distance d(1, 2) as bfloat16 (n when 2 is unreachable from 1),\n",
    "    and padding is the boolean mask data == PAD_TOKEN.\n",
    "    \"\"\"\n",
    "    token_rows, targets = [], []\n",
    "    for _ in range(size):\n",
    "        n = random.randint(MIN_VTXS, MAX_VTXS)\n",
    "        tokens, adjacency = random_graph(n)\n",
    "        tokens[-1] = 2 # target token\n",
    "        token_rows.append(tokens)\n",
    "        targets.append(SSSP(n, adjacency))\n",
    "\n",
    "    data = torch.tensor(token_rows, device=device)\n",
    "    return (\n",
    "        data,\n",
    "        torch.tensor(targets, dtype=torch.bfloat16, device=device),\n",
    "        data == PAD_TOKEN,\n",
    "    )\n",
    "\n",
    "def vertices_on_shortest_12_path(n, G, target=2):\n",
    "    \"\"\"Return the vertices of one shortest path from vertex 1 to target.\n",
    "\n",
    "    Runs BFS from vertex 1 over the adjacency list G (n is the vertex count and\n",
    "    also serves as the \"unvisited\" marker), recording each vertex's BFS parent.\n",
    "    As soon as target is discovered, the parent chain is walked back to vertex 1\n",
    "    and the path is returned as [1, ..., target]. Returns [] when target is\n",
    "    unreachable (and also when target == 1, since vertex 1 is never rediscovered).\n",
    "    \"\"\"\n",
    "    dist = [n] * len(G)\n",
    "    parent = [-1] * len(G)\n",
    "    dist[1] = 0\n",
    "    queue = deque([1])\n",
    "    while queue:\n",
    "        cur = queue.popleft()\n",
    "        for nbr in G[cur]:\n",
    "            if dist[nbr] != n:\n",
    "                continue  # already visited\n",
    "            parent[nbr] = cur\n",
    "            dist[nbr] = dist[cur] + 1\n",
    "            queue.append(nbr)\n",
    "            if nbr == target:\n",
    "                path = [nbr]\n",
    "                while parent[path[-1]] != -1:\n",
    "                    path.append(parent[path[-1]])\n",
    "                return path[::-1]\n",
    "    return []\n",
    "\n",
    "def mktunebatch(size):\n",
    "    \"\"\"Build a batch of `size` graphs for the tuning task.\n",
    "\n",
    "    For each sample, n is drawn from [MIN_VTXS, MAX_TUNE_VTXS] and graphs are\n",
    "    resampled until vertex 2 is reachable from vertex 1. The target token is a\n",
    "    uniformly chosen vertex other than 1 on a shortest 1->2 path, and its label is\n",
    "    its distance from vertex 1 (its index along that path).\n",
    "\n",
    "    Returns (data, labels, padding) on `device`: token sequences, bfloat16 labels,\n",
    "    and the boolean mask data == PAD_TOKEN.\n",
    "    \"\"\"\n",
    "    token_rows, targets = [], []\n",
    "    for _ in range(size):\n",
    "        n = random.randint(MIN_VTXS, MAX_TUNE_VTXS)\n",
    "        # rejection-sample until vertices 1 and 2 are connected\n",
    "        while True:\n",
    "            tokens, adjacency = random_graph(n)\n",
    "            path = vertices_on_shortest_12_path(n, adjacency)\n",
    "            if len(path) > 1:\n",
    "                break\n",
    "        hops = random.randrange(1, len(path))\n",
    "        tokens[-1] = path[hops]\n",
    "        token_rows.append(tokens)\n",
    "        targets.append(hops)\n",
    "\n",
    "    data = torch.tensor(token_rows, device=device)\n",
    "    labels = torch.tensor(targets, dtype=torch.bfloat16, device=device)\n",
    "    return data, labels, data == PAD_TOKEN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 1,  4,  8,  9,  5,  8,  4, 22,  3, 18, 14, 19,  1,  2,  3,  7,  8, 17,\n",
       "           1, 20,  7, 18, 21, 23, 18, 23,  8, 14, 15, 19,  1,  9,  6, 23, 11, 14,\n",
       "           5,  9,  7, 11,  3,  4,  4, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  2],\n",
       "         [ 2,  9, 15, 18,  4, 13,  3, 18, 10, 21, 12, 20,  7, 19,  2,  3,  8, 22,\n",
       "           3, 10,  4,  8,  9, 13, 15, 21,  6, 12,  7, 22,  9, 22,  3, 21, 20, 21,\n",
       "           6, 18,  6,  8, 13, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  2],\n",
       "         [ 9, 11,  4, 11,  1,  6,  1,  4,  6,  7,  2,  5,  4, 10,  4,  6,  8, 11,\n",
       "           7, 11,  3,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  2],\n",
       "         [ 3,  4,  5, 10,  7, 10,  6,  7,  3,  4,  8,  9,  1,  2,  2,  3,  3, 11,\n",
       "           7, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  2],\n",
       "         [ 3, 13, 13, 20, 15, 17,  9, 18,  1, 22,  4, 22,  9, 18, 11, 21,  4, 10,\n",
       "           6, 14,  1, 15,  9, 17,  6, 17,  4, 21, 10, 21, 17, 20,  5,  7,  6, 12,\n",
       "          17, 18,  1, 20, 11, 16,  1,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "           0,  0,  0,  0,  0,  0,  0,  0,  2]], device='cuda:0'),\n",
       " tensor([ 1., 22., 11.,  1., 22.], device='cuda:0', dtype=torch.bfloat16),\n",
       " tensor([[False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True, False],\n",
       "         [False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True, False],\n",
       "         [False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True, False],\n",
       "         [False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True, False],\n",
       "         [False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False, False, False, False, False, False, False,\n",
       "          False, False, False, False,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "           True,  True, False]], device='cuda:0'))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mkbatch(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([320.,   0., 310.,   0., 264.,   0., 178.,   0., 119.,   0.,  90.,\n",
       "          0.,  69.,   0.,  42.,   0.,   0.,  39.,   0.,  30.,   0.,  30.,\n",
       "          0.,  31.,   0.,  26.,   0.,  28.,   0.,  27.,   0.,   0.,  20.,\n",
       "          0.,  30.,   0.,  29.,   0.,  26.,   0.,  20.,   0.,  30.,   0.,\n",
       "         34.,   0.,  23.,   0.,   0.,  35.,   0.,  29.,   0.,  17.,   0.,\n",
       "         26.,   0.,  33.,   0.,  35.,   0.,  32.,   0.,  26.]),\n",
       " array([ 1.     ,  1.46875,  1.9375 ,  2.40625,  2.875  ,  3.34375,\n",
       "         3.8125 ,  4.28125,  4.75   ,  5.21875,  5.6875 ,  6.15625,\n",
       "         6.625  ,  7.09375,  7.5625 ,  8.03125,  8.5    ,  8.96875,\n",
       "         9.4375 ,  9.90625, 10.375  , 10.84375, 11.3125 , 11.78125,\n",
       "        12.25   , 12.71875, 13.1875 , 13.65625, 14.125  , 14.59375,\n",
       "        15.0625 , 15.53125, 16.     , 16.46875, 16.9375 , 17.40625,\n",
       "        17.875  , 18.34375, 18.8125 , 19.28125, 19.75   , 20.21875,\n",
       "        20.6875 , 21.15625, 21.625  , 22.09375, 22.5625 , 23.03125,\n",
       "        23.5    , 23.96875, 24.4375 , 24.90625, 25.375  , 25.84375,\n",
       "        26.3125 , 26.78125, 27.25   , 27.71875, 28.1875 , 28.65625,\n",
       "        29.125  , 29.59375, 30.0625 , 30.53125, 31.     ]),\n",
       " <BarContainer object of 64 artists>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAi0ElEQVR4nO3df1BVdeL/8RegXH9eCBUurIioJZJiLSneqVxXWH7IurrSjJZb1Do6utCsUqY0pmU7i2tNP4d0drbVmpEsd1JHK0sxcVtRk3L8VYw6tNjKhVZHrmLgD873jx3v93MTlavofd/r8zFzZrj3vO+973M8E8/OPfcSYlmWJQAAAIOE+nsCAAAAP0WgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADBOJ39P4Hq0trbq+PHj6tmzp0JCQvw9HQAA0A6WZen06dOKi4tTaOjVz5EEZKAcP35c8fHx/p4GAAC4DseOHVPfvn2vOiYgA6Vnz56S/reBdrvdz7MBAADt4Xa7FR8f7/k9fjUBGSiX3tax2+0ECgAAAaY9l2dwkSwAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzTyd8TMFH/+R9dcd13S3Jv4UwAALg9cQYFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxuGr7jsQX5EPAEDH4AwKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzjU6AsW7ZMKSkpstvtstvtcjqd+uSTTzzrm5ubVVBQoF69eqlHjx7Ky8tTfX2913PU1tYqNzdX3bp1U3R0tObOnasLFy50zNYAAICg4FOg9O3bV0uWLFFVVZX27NmjsWPHasKECTp48KAkac6cOdqwYYPWrFmjiooKHT9+XJMmTfI8/uLFi8rNzdW5c+e0Y8cOvfPOO1q5cqUWLlzYsVsFAAACWohlWdaNPEFUVJReeuklPfTQQ+rTp4/Kysr00EMPSZK+/fZbDRkyRJWVlRo1apQ++eQT/frXv9bx48cVExMjSVq+fLnmzZunH374QeHh4e16TbfbrYiICDU2Nsput9/I9Nt0vV+4xhe1AQBwZb78/r7ua1AuXryo1atXq6mpSU6nU1VVVTp//rwyMjI8Y5KSktSvXz9VVlZKkiorKzVs2DBPnEhSVlaW3G635yxMW1paWuR2u70WAAAQvHwOlP3796tHjx6y2WyaOXOm1q5dq+TkZLlcLoWHhysyMtJrfExMjFwulyTJ5XJ5xcml9ZfWXUlJSYkiIiI8S3x8vK/TBgAAAcTnQBk8eLD27t2rXbt2adasWcrPz9ehQ4duxtw8iouL1djY6FmOHTt2U18PAAD4l89/LDA8PFyDBg2SJKWmpurLL7/U66+/rsmTJ+vcuXM6deqU11mU+vp6ORwOSZLD4dDu3bu9nu/Sp3wujWmLzWaTzWbzdaoAACBA3fD3oLS2tqqlpUWpqanq3LmzysvLPeuqq6tVW1srp9MpSXI6ndq/f78aGho8YzZv3iy73a7k5OQbnQoAAAgSPp1BKS4uVk5Ojvr166fTp0+rrKxM27Zt06effqqIiAhNmzZNRUVFioqKkt1u15NPPimn06lRo0ZJkjIzM5WcnKxHH31US5culcvl0oIFC1RQUMAZEgAA4OFToDQ0NOixxx5TXV2dIiIilJKSok8//VS/+tWvJEmvvvqqQkNDlZeXp5aWFmVlZemtt97yPD4sLEwbN27UrFmz5HQ61b17d+Xn52vx4sUdu1UAACCg3fD3oPgD34MCAEDguSXfgwIAAHCzECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADA
OAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA43Ty9wQg9Z//0RXXfbck9xbOBAAAM3AGBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMbxKVBKSko0YsQI9ezZU9HR0Zo4caKqq6u9xowZM0YhISFey8yZM73G1NbWKjc3V926dVN0dLTmzp2rCxcu3PjWAACAoODT3+KpqKhQQUGBRowYoQsXLujZZ59VZmamDh06pO7du3vGTZ8+XYsXL/bc7tatm+fnixcvKjc3Vw6HQzt27FBdXZ0ee+wxde7cWX/+8587YJMAAECg8ylQNm3a5HV75cqVio6OVlVVlUaPHu25v1u3bnI4HG0+x2effaZDhw5py5YtiomJ0T333KMXX3xR8+bN0/PPP6/w8PDr2AwAABBMbugalMbGRklSVFSU1/2rVq1S7969NXToUBUXF+vs2bOedZWVlRo2bJhiYmI892VlZcntduvgwYNtvk5LS4vcbrfXAgAAgpdPZ1D+r9bWVs2ePVv333+/hg4d6rn/kUceUUJCguLi4rRv3z7NmzdP1dXV+vDDDyVJLpfLK04keW67XK42X6ukpEQvvPDC9U4VAAAEmOsOlIKCAh04cEBffPGF1/0zZszw/Dxs2DDFxsYqPT1dR48e1cCBA6/rtYqLi1VUVOS57Xa7FR8ff30TBwAAxruut3gKCwu1ceNGff755+rbt+9Vx6alpUmSjhw5IklyOByqr6/3GnPp9pWuW7HZbLLb7V4LAAAIXj4FimVZKiws1Nq1a7V161YlJiZe8zF79+6VJMXGxkqSnE6n9u/fr4aGBs+YzZs3y263Kzk52ZfpAACAIOXTWzwFBQUqKyvT+vXr1bNnT881IxEREeratauOHj2qsrIyjRs3Tr169dK+ffs0Z84cjR49WikpKZKkzMxMJScn69FHH9XSpUvlcrm0YMECFRQUyGazdfwWAgCAgOPTGZRly5apsbFRY8aMUWxsrGd5//33JUnh4eHasmWLMjMzlZSUpKeeekp5eXnasGGD5znCwsK0ceNGhYWFyel06ne/+50ee+wxr+9NAQAAtzefzqBYlnXV9fHx8aqoqLjm8yQkJOjjjz/25aUBAMBthL/FAwAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADCOT4FSUlKiESNGqGfPnoqOjtbEiRNVXV3tNaa5uVkFBQXq1auXevTooby8PNXX13uNqa2tVW5urrp166bo6GjNnTtXFy5cuPGtAQAAQcGnQKmoqFBBQYF27typzZs36/z588rMzFRTU5NnzJw5c7RhwwatWbNGFRUVOn78uCZNmuRZf/HiReXm5urcuXPasWOH3nnnHa1cuVILFy7suK0CAAABrZMvgzdt2uR1e+XKlYqOjlZVVZVGjx6txsZGvf322yorK9PYsWMlSStWrNCQIUO0c+dOjRo1Sp999pkOHTqkLVu2KCYmRvfcc49efPFFzZs3T88//7zCw8M7busAAEBAuqFrUBobGyVJUVFR
kqSqqiqdP39eGRkZnjFJSUnq16+fKisrJUmVlZUaNmyYYmJiPGOysrLkdrt18ODBNl+npaVFbrfbawEAAMHrugOltbVVs2fP1v3336+hQ4dKklwul8LDwxUZGek1NiYmRi6XyzPm/8bJpfWX1rWlpKREERERniU+Pv56pw0AAALAdQdKQUGBDhw4oNWrV3fkfNpUXFysxsZGz3Ls2LGb/poAAMB/fLoG5ZLCwkJt3LhR27dvV9++fT33OxwOnTt3TqdOnfI6i1JfXy+Hw+EZs3v3bq/nu/Qpn0tjfspms8lms13PVAEAQADy6QyKZVkqLCzU2rVrtXXrViUmJnqtT01NVefOnVVeXu65r7q6WrW1tXI6nZIkp9Op/fv3q6GhwTNm8+bNstvtSk5OvpFtAQAAQcKnMygFBQUqKyvT+vXr1bNnT881IxEREeratasiIiI0bdo0FRUVKSoqSna7XU8++aScTqdGjRolScrMzFRycrIeffRRLV26VC6XSwsWLFBBQQFnSQAAgCQfA2XZsmWSpDFjxnjdv2LFCj3++OOSpFdffVWhoaHKy8tTS0uLsrKy9NZbb3nGhoWFaePGjZo1a5acTqe6d++u/Px8LV68+Ma2BAAABA2fAsWyrGuO6dKli0pLS1VaWnrFMQkJCfr44499eWkAAHAb4W/xAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjdPL3BHD9+s//6IrrvluSewtnAgBAx+IMCgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjONzoGzfvl3jx49XXFycQkJCtG7dOq/1jz/+uEJCQryW7OxsrzEnT57U1KlTZbfbFRkZqWnTpunMmTM3tCEAACB4+BwoTU1NGj58uEpLS684Jjs7W3V1dZ7lvffe81o/depUHTx4UJs3b9bGjRu1fft2zZgxw/fZAwCAoNTJ1wfk5OQoJyfnqmNsNpscDkeb67755htt2rRJX375pe677z5J0ptvvqlx48bp5ZdfVlxcnK9TAgAAQeamXIOybds2RUdHa/DgwZo1a5ZOnDjhWVdZWanIyEhPnEhSRkaGQkNDtWvXrjafr6WlRW6322sBAADBq8MDJTs7W++++67Ky8v1l7/8RRUVFcrJydHFixclSS6XS9HR0V6P6dSpk6KiouRyudp8zpKSEkVERHiW+Pj4jp42AAAwiM9v8VzLlClTPD8PGzZMKSkpGjhwoLZt26b09PTres7i4mIVFRV5brvdbiIFAIAgdtM/ZjxgwAD17t1bR44ckSQ5HA41NDR4jblw4YJOnjx5xetWbDab7Ha71wIAAILXTQ+U77//XidOnFBsbKwkyel06tSpU6qqqvKM2bp1q1pbW5WWlnazpwMAAAKAz2/xnDlzxnM2RJJqamq0d+9eRUVFKSoqSi+88ILy8vLkcDh09OhRPfPMMxo0aJCysrIkSUOGDFF2dramT5+u5cuX6/z58yosLNSUKVP4BA8AAJB0HWdQ9uzZo3vvvVf33nuvJKmoqEj33nuvFi5cqLCwMO3bt0+/+c1vdNddd2natGlKTU3VP//5T9lsNs9zrFq1SklJSUpPT9e4ceP0wAMP6K9//WvHbRUAAAhoPp9BGTNmjCzLuuL6Tz/99JrPERUVpbKyMl9fGgAA3Cb4WzwAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACM
Q6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzTyd8TwK3Xf/5HV1z33ZLcWzgTAADaxhkUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxfA6U7du3a/z48YqLi1NISIjWrVvntd6yLC1cuFCxsbHq2rWrMjIydPjwYa8xJ0+e1NSpU2W32xUZGalp06bpzJkzN7QhAAAgePgcKE1NTRo+fLhKS0vbXL906VK98cYbWr58uXbt2qXu3bsrKytLzc3NnjFTp07VwYMHtXnzZm3cuFHbt2/XjBkzrn8rAABAUOnk6wNycnKUk5PT5jrLsvTaa69pwYIFmjBhgiTp3XffVUxMjNatW6cpU6bom2++0aZNm/Tll1/qvvvukyS9+eabGjdunF5++WXFxcXdwOYAAIBg0KHXoNTU1MjlcikjI8NzX0REhNLS0lRZWSlJqqysVGRkpCdOJCkjI0OhoaHatWtXm8/b0tIit9vttQAAgODVoYHicrkkSTExMV73x8TEeNa5XC5FR0d7re/UqZOioqI8Y36qpKREERERniU+Pr4jpw0AAAwTEJ/iKS4uVmNjo2c5duyYv6cEAABuog4NFIfDIUmqr6/3ur++vt6zzuFwqKGhwWv9hQsXdPLkSc+Yn7LZbLLb7V4LAAAIXh0aKImJiXI4HCovL/fc53a7tWvXLjmdTkmS0+nUqVOnVFVV5RmzdetWtba2Ki0trSOnAwAAApTPn+I5c+aMjhw54rldU1OjvXv3KioqSv369dPs2bP1pz/9SXfeeacSExP13HPPKS4uThMnTpQkDRkyRNnZ2Zo+fbqWL1+u8+fPq7CwUFOmTOETPAAAQNJ1BMqePXv0y1/+0nO7qKhIkpSfn6+VK1fqmWeeUVNTk2bMmKFTp07pgQce0KZNm9SlSxfPY1atWqXCwkKlp6crNDRUeXl5euONNzpgcwAAQDDwOVDGjBkjy7KuuD4kJESLFy/W4sWLrzgmKipKZWVlvr40AAC4TQTEp3gAAMDthUABAADGIVAAAIBxCBQAAGAcny+Sxe2r//yPrrjuuyW5t3AmAIBgxxkUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGKeTvyeA4Nd//kdXXPfdktxbOBMAQKDgDAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgdHijPP/+8QkJCvJakpCTP+ubmZhUUFKhXr17q0aOH8vLyVF9f39HTAAAAAeymnEG5++67VVdX51m++OILz7o5c+Zow4YNWrNmjSoqKnT8+HFNmjTpZkwDAAAEqE435Uk7dZLD4bjs/sbGRr399tsqKyvT2LFjJUkrVqzQkCFDtHPnTo0aNepmTAcAAASYm3IG5fDhw4qLi9OAAQM0depU1dbWSpKqqqp0/vx5ZWRkeMYmJSWpX79+qqysvOLztbS0yO12ey0AACB4dXigpKWlaeXKldq0aZOWLVummpoaPfjggzp9+rRcLpfCw8MVGRnp9ZiYmBi5XK4rPmdJSYkiIiI8S3x8fEdPGwAAGKTD3+LJycnx/JySkqK0tDQlJCTogw8+UNeuXa/rOYuLi1VUVOS57Xa7iRQAAILYTf+YcWRkpO666y4d
OXJEDodD586d06lTp7zG1NfXt3nNyiU2m012u91rAQAAweumB8qZM2d09OhRxcbGKjU1VZ07d1Z5eblnfXV1tWpra+V0Om/2VAAAQIDo8Ld4nn76aY0fP14JCQk6fvy4Fi1apLCwMD388MOKiIjQtGnTVFRUpKioKNntdj355JNyOp18ggcAAHh0eKB8//33evjhh3XixAn16dNHDzzwgHbu3Kk+ffpIkl599VWFhoYqLy9PLS0tysrK0ltvvdXR0wAAAAGswwNl9erVV13fpUsXlZaWqrS0tKNfGgAABAn+Fg8AADAOgQIAAIxzU77qHugI/ed/dMV13y3JvYUzAQDcapxBAQAAxiFQAACAcXiLB7cV3jYCgMDAGRQAAGAcAgUAABiHQAEAAMYhUAAAgHG4SBYAcEtxsXrb2C/eCBQAuA3wyw+BhkABAKADEYMdg0ABACCABWsQESiAoYL1PzoA0B4ECtBO1xsMhMatczvs69thGwGJQAEAfukHCP6dbi8ECgBJ/McfgFn4ojYAAGAczqAAuCGceQFwMxAoQJAhGG4d9nVwux3+fU3eRt7iAQAAxiFQAACAcXiLBwBwRSa/BYDgxhkUAABgHAIFAAAYh7d4ABgn2N9WuNr2ScGxjcCN4gwKAAAwDoECAACMQ6AAAADjcA0KACCoBfs1TcGKMygAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4fg2U0tJS9e/fX126dFFaWpp2797tz+kAAABD+C1Q3n//fRUVFWnRokX66quvNHz4cGVlZamhocFfUwIAAIbwW6C88sormj59up544gklJydr+fLl6tatm/7+97/7a0oAAMAQnfzxoufOnVNVVZWKi4s994WGhiojI0OVlZWXjW9paVFLS4vndmNjoyTJ7XbflPm1tpy94rqrvSaPM+NxV3Mjzxko2x8oj7uaQNmGm/E4f7wm/7635+Nu9LHX49JzWpZ17cGWH/znP/+xJFk7duzwun/u3LnWyJEjLxu/aNEiSxILCwsLCwtLECzHjh27Ziv45QyKr4qLi1VUVOS53draqpMnT6pXr14KCQm5bLzb7VZ8fLyOHTsmu91+K6cacNhX7ce+aj/2Vfuxr9qPfeUbE/eXZVk6ffq04uLirjnWL4HSu3dvhYWFqb6+3uv++vp6ORyOy8bbbDbZbDav+yIjI6/5Ona73Zh/FNOxr9qPfdV+7Kv2Y1+1H/vKN6btr4iIiHaN88tFsuHh4UpNTVV5ebnnvtbWVpWXl8vpdPpjSgAAwCB+e4unqKhI+fn5uu+++zRy5Ei99tprampq0hNPPOGvKQEAAEP4LVAmT56sH374QQsXLpTL5dI999yjTZs2KSYm5oaf22azadGiRZe9LYTLsa/aj33Vfuyr9mNftR/7yjeBvr9CLKs9n/UBAAC4dfhbPAAAwDgECgAAMA6BAgAAjEOgAAAA4wRdoJSWlqp///7q0qWL0tLStHv3bn9PyTjPP/+8QkJCvJakpCR/T8sY27dv1/jx4xUXF6eQkBCtW7fOa71lWVq4cKFiY2PVtWtXZWRk6PDhw/6ZrJ9da189/vjjlx1r2dnZ/pmsH5WUlGjEiBHq2bOnoqOjNXHiRFVXV3uNaW5uVkFBgXr16qUePXooLy/vsi+zvF20Z3+NGTPmsmNr5syZfpqx/yxbtkwpKSmeL2NzOp365JNPPOsD+bgKqkB5//33VVRUpEWLFumrr77S8OHDlZWVpYaGBn9PzTh333236urqPMsXX3zh7ykZo6mpScOHD1dpaWmb65cuXao33nhDy5cv165du9S9e3dlZWWpubn5Fs/U/661ryQpOzvb61h77733buEMzVBRUaGCggLt3LlTmzdv1vnz55WZmammpibPmDlz5mjDhg1as2aNKioqdPz4cU2aNMmPs/af9uwvSZo+
fbrXsbV06VI/zdh/+vbtqyVLlqiqqkp79uzR2LFjNWHCBB08eFBSgB9XHfLX/wwxcuRIq6CgwHP74sWLVlxcnFVSUuLHWZln0aJF1vDhw/09jYAgyVq7dq3ndmtrq+VwOKyXXnrJc9+pU6csm81mvffee36YoTl+uq8sy7Ly8/OtCRMm+GU+JmtoaLAkWRUVFZZl/e8Y6ty5s7VmzRrPmG+++caSZFVWVvprmsb46f6yLMv6xS9+Yf3xj3/036QMdscdd1h/+9vfAv64CpozKOfOnVNVVZUyMjI894WGhiojI0OVlZV+nJmZDh8+rLi4OA0YMEBTp05VbW2tv6cUEGpqauRyubyOs4iICKWlpXGcXcG2bdsUHR2twYMHa9asWTpx4oS/p+R3jY2NkqSoqChJUlVVlc6fP+91XCUlJalfv34cV7p8f12yatUq9e7dW0OHDlVxcbHOnj3rj+kZ4+LFi1q9erWamprkdDoD/rgKiL9m3B7//e9/dfHixcu+iTYmJkbffvutn2ZlprS0NK1cuVKDBw9WXV2dXnjhBT344IM6cOCAevbs6e/pGc3lcklSm8fZpXX4/7KzszVp0iQlJibq6NGjevbZZ5WTk6PKykqFhYX5e3p+0draqtmzZ+v+++/X0KFDJf3vuAoPD7/sj6ByXLW9vyTpkUceUUJCguLi4rRv3z7NmzdP1dXV+vDDD/04W//Yv3+/nE6nmpub1aNHD61du1bJycnau3dvQB9XQRMoaL+cnBzPzykpKUpLS1NCQoI++OADTZs2zY8zQ7CZMmWK5+dhw4YpJSVFAwcO1LZt25Senu7HmflPQUGBDhw4wHVf7XSl/TVjxgzPz8OGDVNsbKzS09N19OhRDRw48FZP068GDx6svXv3qrGxUf/4xz+Un5+viooKf0/rhgXNWzy9e/dWWFjYZVcn19fXy+Fw+GlWgSEyMlJ33XWXjhw54u+pGO/SscRxdn0GDBig3r1737bHWmFhoTZu3KjPP/9cffv29dzvcDh07tw5nTp1ymv87X5cXWl/tSUtLU2SbstjKzw8XIMGDVJqaqpKSko0fPhwvf766wF/XAVNoISHhys1NVXl5eWe+1pbW1VeXi6n0+nHmZnvzJkzOnr0qGJjY/09FeMlJibK4XB4HWdut1u7du3iOGuH77//XidOnLjtjjXLslRYWKi1a9dq69atSkxM9Fqfmpqqzp07ex1X1dXVqq2tvS2Pq2vtr7bs3btXkm67Y6stra2tamlpCfzjyt9X6Xak1atXWzabzVq5cqV16NAha8aMGVZkZKTlcrn8PTWjPPXUU9a2bdusmpoa61//+peVkZFh9e7d22poaPD31Ixw+vRp6+uvv7a+/vprS5L1yiuvWF9//bX173//27Isy1qyZIkVGRlprV+/3tq3b581YcIEKzEx0frxxx/9PPNb72r76vTp09bTTz9tVVZWWjU1NdaWLVusn//859add95pNTc3+3vqt9SsWbOsiIgIa9u2bVZdXZ1nOXv2rGfMzJkzrX79+llbt2619uzZYzmdTsvpdPpx1v5zrf115MgRa/HixdaePXusmpoaa/369daAAQOs0aNH+3nmt978+fOtiooKq6amxtq3b581f/58KyQkxPrss88sywrs4yqoAsWyLOvNN9+0+vXrZ4WHh1sjR460du7c6e8pGWfy5MlWbGysFR4ebv3sZz+zJk+ebB05csTf0zLG559/bkm6bMnPz7cs638fNX7uueesmJgYy2azWenp6VZ1dbV/J+0nV9tXZ8+etTIzM60+ffpYnTt3thISEqzp06fflv/D0NY+kmStWLHCM+bHH3+0/vCHP1h33HGH1a1bN+u3v/2tVVdX579J+9G19ldtba01evRoKyoqyrLZbNagQYOsuXPnWo2Njf6duB/8/ve/txISEqzw8HCrT58+Vnp6uidOLCuwj6sQy7KsW3e+BgAA4NqC5hoUAAAQPAgUAABgHAIFAAAY
h0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxvl/F3uXZdNo9IkAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# labels are integers in [1, MAX_VTXS], so use one unit-width bin per value;\n",
    "# equal-width float bins would alias against the integer values and leave gaps\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(mkbatch(2048)[1].to(torch.float32).cpu(), bins=range(1, MAX_VTXS + 2))\n",
    "ax.set(title=\"mkbatch label distribution\", xlabel=\"d(1, 2) (n if unreachable)\", ylabel=\"count\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1289.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,  477.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,  192.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,   55.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,   28.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    5.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           2.]),\n",
       " array([1.     , 1.09375, 1.1875 , 1.28125, 1.375  , 1.46875, 1.5625 ,\n",
       "        1.65625, 1.75   , 1.84375, 1.9375 , 2.03125, 2.125  , 2.21875,\n",
       "        2.3125 , 2.40625, 2.5    , 2.59375, 2.6875 , 2.78125, 2.875  ,\n",
       "        2.96875, 3.0625 , 3.15625, 3.25   , 3.34375, 3.4375 , 3.53125,\n",
       "        3.625  , 3.71875, 3.8125 , 3.90625, 4.     , 4.09375, 4.1875 ,\n",
       "        4.28125, 4.375  , 4.46875, 4.5625 , 4.65625, 4.75   , 4.84375,\n",
       "        4.9375 , 5.03125, 5.125  , 5.21875, 5.3125 , 5.40625, 5.5    ,\n",
       "        5.59375, 5.6875 , 5.78125, 5.875  , 5.96875, 6.0625 , 6.15625,\n",
       "        6.25   , 6.34375, 6.4375 , 6.53125, 6.625  , 6.71875, 6.8125 ,\n",
       "        6.90625, 7.     ]),\n",
       " <BarContainer object of 64 artists>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjSklEQVR4nO3dfVSUdf7/8dcgcpMyg1gMzobItm1K3mRSNNm9HEnJXU90Y1FRsbqnBcvsRvmW2o2F2XZHS5ieVj0nPd3sWa0sKfKOLRERl03JyDZTzAY6h5gJOiLK/P7oeP2a1EJ3aOZjz8c51znNdX1mrvdcx3N4NjMMNr/f7xcAAIBBIkI9AAAAwPEiYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYJzLUA/SUrq4u7du3T3FxcbLZbKEeBwAAdIPf79e3334rl8uliIhjv85y0gbMvn37lJycHOoxAADACWhsbNTpp59+zOMnbcDExcVJ+v4C2O32EE8DAAC6w+fzKTk52fo5fiwnbcAcftvIbrcTMAAAGObnPv7Bh3gBAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGCcyFAPYKJBM98+5rEv5mX/gpMAAPDrxCswAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxz3AFTWVmpCRMmyOVyyWazaeXKldaxzs5OzZgxQ8OGDVOfPn3kcrl0yy23aN++fQGP0dLSotzcXNntdsXHxys/P19tbW0Baz766CNdfPHFiomJUXJysubPn39izxAAAJx0jjtg2tvbNWLECJWWlh5x7LvvvtPWrVs1a9Ysbd26Vf/85z/V0NCgP/zhDwHrcnNzVV9fr4qKCq1atUqVlZWaMmWKddzn82ns2LFKSUlRbW2tnnzyST300ENauHDhCTxFAABwsrH5/X7/Cd/ZZtOKFSs0ceLEY66pqanR+eefr927d2vgwIHasWOH0tLSVFNTo/T0dElSeXm5xo8fr71798rlcqmsrEwPPPCAPB6PoqKiJEkzZ87UypUr9cknn3RrNp/PJ4fDIa/XK7vdfqJP8agGzXz7mMe+mJcd1HMBAPBr0t2f3z3+GRiv1yubzab4+HhJUlVVleLj4614kaTMzExFRESourraWnPJJZdY8SJJWVlZamho0DfffNPTIwMAgDAX2ZMPvn//fs2YMUM33HCDVVEej0eJiYmBQ0RGKiEhQR6Px1qTmpoasMbpdFrH+vXrd8S5Ojo61NHRYd32+XxBfS4AACB89NgrMJ2dnbruuuvk9/tVVlbWU6exFBcXy+FwWFtycnKPnxMAAIRGjwTM4XjZvXu3KioqAt7DSkpKUnNzc8D6gwcPqqWlRUlJSdaapqamgDWHbx9e82NFRUXyer3W1tjYGMynBAAAwkjQA+ZwvOzcuVPvv/+++vfvH3Dc7XartbVVtbW11r61a9eqq6tLGRkZ1prKykp1dnZaayoqKnTWWWcd9e0jSYqOjpbdbg/YAADAyem4A6atrU11dXWqq6uTJO3atUt1dXXas2ePOjs7dc0112jLli1atmyZDh06JI/HI4/HowMHDkiShgwZoiuvvFKTJ0/W5s2b9eGHH6qwsFCTJk2Sy+WSJN14442KiopSfn6+6uvr9eqrr+q5557T9OnTg/fMAQCAsY7716jXr1+vyy+//Ij9eXl5euihh4748O1h69at02WXXSbp+y+yKyws1FtvvaWIiAjl5OSopKREffv2tdZ/9NFHKigoUE1NjU499VRNnTpVM2bM6Pac/Bo1AADm6e7P
7//pe2DCGQEDAIB5wuZ7YAAAAIKNgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgnOMOmMrKSk2YMEEul0s2m00rV64MOO73+zV79mwNGDBAsbGxyszM1M6dOwPWtLS0KDc3V3a7XfHx8crPz1dbW1vAmo8++kgXX3yxYmJilJycrPnz5x//swMAACel4w6Y9vZ2jRgxQqWlpUc9Pn/+fJWUlGjBggWqrq5Wnz59lJWVpf3791trcnNzVV9fr4qKCq1atUqVlZWaMmWKddzn82ns2LFKSUlRbW2tnnzyST300ENauHDhCTxFAABwsrH5/X7/Cd/ZZtOKFSs0ceJESd+/+uJyuXTPPffo3nvvlSR5vV45nU4tWbJEkyZN0o4dO5SWlqaamhqlp6dLksrLyzV+/Hjt3btXLpdLZWVleuCBB+TxeBQVFSVJmjlzplauXKlPPvmkW7P5fD45HA55vV7Z7fYTfYpHNWjm28c89sW87KCeCwCAX5Pu/vwO6mdgdu3aJY/Ho8zMTGufw+FQRkaGqqqqJElVVVWKj4+34kWSMjMzFRERoerqamvNJZdcYsWLJGVlZamhoUHffPPNUc/d0dEhn88XsAEAgJNTUAPG4/FIkpxOZ8B+p9NpHfN4PEpMTAw4HhkZqYSEhIA1R3uMH57jx4qLi+VwOKwtOTn5f39CAAAgLJ00v4VUVFQkr9drbY2NjaEeCQAA9JCgBkxSUpIkqampKWB/U1OTdSwpKUnNzc0Bxw8ePKiWlpaANUd7jB+e48eio6Nlt9sDNgAAcHIKasCkpqYqKSlJa9assfb5fD5VV1fL7XZLktxut1pbW1VbW2utWbt2rbq6upSRkWGtqaysVGdnp7WmoqJCZ511lvr16xfMkQEAgIGOO2Da2tpUV1enuro6Sd9/cLeurk579uyRzWbTtGnTNHfuXL355pvatm2bbrnlFrlcLus3lYYMGaIrr7xSkydP1ubNm/Xhhx+qsLBQkyZNksvlkiTdeOONioqKUn5+vurr6/Xqq6/queee0/Tp04P2xAEAgLkij/cOW7Zs0eWXX27dPhwVeXl5WrJkie6//361t7drypQpam1t1UUXXaTy8nLFxMRY91m2bJkKCws1ZswYRUREKCcnRyUlJdZxh8Oh9957TwUFBRo1apROPfVUzZ49O+C7YgAAwK/X//Q9MOGM74EBAMA8IfkeGAAAgF8CAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOEEPmEOHDmnWrFlKTU1VbGyszjjjDD366KPy+/3WGr/fr9mzZ2vAgAGKjY1VZmamdu7cGfA4LS0tys3Nld1uV3x8vPLz89XW1hbscQEAgIGCHjBPPPGEysrK9Le//U07duzQE088ofnz5+v555+31syfP18lJSVasGCB
qqur1adPH2VlZWn//v3WmtzcXNXX16uiokKrVq1SZWWlpkyZEuxxAQCAgWz+H740EgRXXXWVnE6nXnrpJWtfTk6OYmNj9fLLL8vv98vlcumee+7RvffeK0nyer1yOp1asmSJJk2apB07digtLU01NTVKT0+XJJWXl2v8+PHau3evXC7Xz87h8/nkcDjk9Xplt9uD+RQ1aObbxzz2xbzsoJ4LAIBfk+7+/A76KzAXXnih1qxZo08//VSS9J///EcffPCBxo0bJ0natWuXPB6PMjMzrfs4HA5lZGSoqqpKklRVVaX4+HgrXiQpMzNTERERqq6uPup5Ozo65PP5AjYAAHByigz2A86cOVM+n0+DBw9Wr169dOjQIT322GPKzc2VJHk8HkmS0+kMuJ/T6bSOeTweJSYmBg4aGamEhARrzY8VFxfr4YcfDvbTAQAAYSjor8C89tprWrZsmZYvX66tW7dq6dKl+utf/6qlS5cG+1QBioqK5PV6ra2xsbFHzwcAAEIn6K/A3HfffZo5c6YmTZokSRo2bJh2796t4uJi5eXlKSkpSZLU1NSkAQMGWPdramrSOeecI0lKSkpSc3NzwOMePHhQLS0t1v1/LDo6WtHR0cF+OgAAIAwF/RWY7777ThERgQ/bq1cvdXV1SZJSU1OVlJSkNWvWWMd9Pp+qq6vldrslSW63W62traqtrbXWrF27Vl1dXcrIyAj2yAAAwDBBfwVmwoQJeuyxxzRw4ECdffbZ+ve//62nn35at99+uyTJZrNp2rRpmjt3rs4880ylpqZq1qxZcrlcmjhxoiRpyJAhuvLKKzV58mQtWLBAnZ2dKiws1KRJk7r1G0gAAODkFvSAef755zVr1iz95S9/UXNzs1wul/785z9r9uzZ1pr7779f7e3tmjJlilpbW3XRRRepvLxcMTEx1pply5apsLBQY8aMUUREhHJyclRSUhLscQEAgIGC/j0w4YLvgQEAwDwh+x4YAACAnkbAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4PRIwX375pW666Sb1799fsbGxGjZsmLZs2WId9/v9mj17tgYMGKDY2FhlZmZq586dAY/R0tKi3Nxc2e12xcfHKz8/X21tbT0xLgAAMEzQA+abb77R6NGj1bt3b61evVoff/yxnnrqKfXr189aM3/+fJWUlGjBggWqrq5Wnz59lJWVpf3791trcnNzVV9fr4qKCq1atUqVlZWaMmVKsMcFAAAGsvn9fn8wH3DmzJn68MMP9a9//euox/1+v1wul+655x7de++9kiSv1yun06klS5Zo0qRJ2rFjh9LS0lRTU6P09HRJUnl5ucaPH6+9e/fK5XL97Bw+n08Oh0Ner1d2uz14T1DSoJlvH/PYF/Oyg3ouAAB+Tbr78zvor8C8+eabSk9P17XXXqvExESNHDlSixYtso7v2rVLHo9HmZmZ1j6Hw6GMjAxVVVVJkqqqqhQfH2/FiyRlZmYqIiJC1dXVRz1vR0eHfD5fwAYAAE5OQQ+Yzz//XGVlZTrzzDP17rvv6o477tCdd96ppUuXSpI8Ho8kyel0BtzP6XRaxzwejxITEwOOR0ZGKiEhwVrzY8XFxXI4HNaWnJwc7KcGAADCRNADpqurS+eee64ef/xxjRw5UlOmTNHkyZO1YMGCYJ8qQFFRkbxer7U1Njb26PkAAEDoBD1gBgwYoLS0tIB9
Q4YM0Z49eyRJSUlJkqSmpqaANU1NTdaxpKQkNTc3Bxw/ePCgWlparDU/Fh0dLbvdHrABAICTU9ADZvTo0WpoaAjY9+mnnyolJUWSlJqaqqSkJK1Zs8Y67vP5VF1dLbfbLUlyu91qbW1VbW2ttWbt2rXq6upSRkZGsEcGAACGiQz2A95999268MIL9fjjj+u6667T5s2btXDhQi1cuFCSZLPZNG3aNM2dO1dnnnmmUlNTNWvWLLlcLk2cOFHS96/YXHnlldZbT52dnSosLNSkSZO69RtIAADg5Bb0gDnvvPO0YsUKFRUV6ZFHHlFqaqqeffZZ5ebmWmvuv/9+tbe3a8qUKWptbdVFF12k8vJyxcTEWGuWLVumwsJCjRkzRhEREcrJyVFJSUmwxwUAAAYK+vfAhAu+BwYAAPOE7HtgAAAAehoBAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA40SGegCcXAbNfPuYx76Yl/0LTgIAOJnxCgwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDg9HjDz5s2TzWbTtGnTrH379+9XQUGB+vfvr759+yonJ0dNTU0B99uzZ4+ys7N1yimnKDExUffdd58OHjzY0+MCAAAD9GjA1NTU6MUXX9Tw4cMD9t99991666239Prrr2vDhg3at2+frr76auv4oUOHlJ2drQMHDmjjxo1aunSplixZotmzZ/fkuAAAwBA9FjBtbW3Kzc3VokWL1K9fP2u/1+vVSy+9pKefflpXXHGFRo0apcWLF2vjxo3atGmTJOm9997Txx9/rJdfflnnnHOOxo0bp0cffVSlpaU6cOBAT40MAAAM0WMBU1BQoOzsbGVmZgbsr62tVWdnZ8D+wYMHa+DAgaqqqpIkVVVVadiwYXI6ndaarKws+Xw+1dfXH/V8HR0d8vl8ARsAADg5RfbEg77yyivaunWrampqjjjm8XgUFRWl+Pj4gP1Op1Mej8da88N4OXz88LGjKS4u1sMPPxyE6QEAQLgL+iswjY2Nuuuuu7Rs2TLFxMQE++GPqaioSF6v19oaGxt/sXMDAIBfVtADpra2Vs3NzTr33HMVGRmpyMhIbdiwQSUlJYqMjJTT6dSBAwfU2toacL+mpiYlJSVJkpKSko74raTDtw+v+bHo6GjZ7faADQAAnJyCHjBjxozRtm3bVFdXZ23p6enKzc21/rt3795as2aNdZ+Ghgbt2bNHbrdbkuR2u7Vt2zY1NzdbayoqKmS325WWlhbskQEAgGGC/hmYuLg4DR06NGBfnz591L9/f2t/fn6+pk+froSEBNntdk2dOlVut1sXXHCBJGns2LFKS0vTzTffrPnz58vj8ejBBx9UQUGBoqOjgz0yAAAwTI98iPfnPPPMM4qIiFBOTo46OjqUlZWlF154wTreq1cvrVq1SnfccYfcbrf69OmjvLw8PfLII6EYFwAAhJlfJGDWr18fcDsmJkalpaUqLS095n1SUlL0zjvv9PBkAADARPwtJAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBx
CBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGCXrAFBcX67zzzlNcXJwSExM1ceJENTQ0BKzZv3+/CgoK1L9/f/Xt21c5OTlqamoKWLNnzx5lZ2frlFNOUWJiou677z4dPHgw2OMCAAADBT1gNmzYoIKCAm3atEkVFRXq7OzU2LFj1d7ebq25++679dZbb+n111/Xhg0btG/fPl199dXW8UOHDik7O1sHDhzQxo0btXTpUi1ZskSzZ88O9rgAAMBANr/f7+/JE3z99ddKTEzUhg0bdMkll8jr9eq0007T8uXLdc0110iSPvnkEw0ZMkRVVVW64IILtHr1al111VXat2+fnE6nJGnBggWaMWOGvv76a0VFRf3seX0+nxwOh7xer+x2e1Cf06CZbx/z2BfzsoN6LtNwbQAA/4vu/vzu8c/AeL1eSVJCQoIkqba2Vp2dncrMzLTWDB48WAMHDlRVVZUkqaqqSsOGDbPiRZKysrLk8/lUX1/f0yMDAIAwF9mTD97V1aVp06Zp9OjRGjp0qCTJ4/EoKipK8fHxAWudTqc8Ho+15ofxcvj44WNH09HRoY6ODuu2z+cL1tMAAABhpkdfgSkoKND27dv1yiuv9ORpJH3/4WGHw2FtycnJPX5OAAAQGj0WMIWFhVq1apXWrVun008/3dqflJSkAwcOqLW1NWB9U1OTkpKSrDU//q2kw7cPr/mxoqIieb1ea2tsbAziswEAAOEk6G8h+f1+TZ06VStWrND69euVmpoacHzUqFHq3bu31qxZo5ycHElSQ0OD9uzZI7fbLUlyu9167LHH1NzcrMTERElSRUWF7Ha70tLSjnre6OhoRUdHB/vpAD2ODz4DwPELesAUFBRo+fLleuONNxQXF2d9ZsXhcCg2NlYOh0P5+fmaPn26EhISZLfbNXXqVLndbl1wwQWSpLFjxyotLU0333yz5s+fL4/HowcffFAFBQVECgAACH7AlJWVSZIuu+yygP2LFy/WrbfeKkl65plnFBERoZycHHV0dCgrK0svvPCCtbZXr15atWqV7rjjDrndbvXp00d5eXl65JFHgj0uAAAwUI+8hfRzYmJiVFpaqtLS0mOuSUlJ0TvvvBPM0QAAwEmCv4UEAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjBMZ6gEA4EQMmvn2MY99MS/7F5wEQCjwCgwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4/A9MABwEuH7cfBrwSswAADAOAQMAAAwTlgHTGlpqQYNGqSYmBhlZGRo8+bNoR4JAACEgbANmFdffVXTp0/XnDlztHXrVo0YMUJZWVlqbm4O9WgAACDEwjZgnn76aU2ePFm33Xab0tLStGDBAp1yyin6+9//HurRAABAiIXlbyEdOHBAtbW1KioqsvZFREQoMzNTVVVVR71PR0eHOjo6rNter1eS5PP5gj5fV8d3xzzWE+czCdfm+HHNTgzX7ei4Lidm6Jx3j3ls+8NZv+AkOPzv1O/3//RCfxj68ssv/ZL8GzduDNh/3333+c8///yj3mfOnDl+SWxsbGxsbGwnwdbY2PiTrRCWr8CciKKiIk2fPt263dXVpZaWFvXv3182my1o5/H5fEpOTlZjY6PsdnvQHvdkxfXqPq5V93Gtuo9r1X1cq+7ryWvl9/v17bffyuVy/eS6sAyYU089Vb169VJTU1PA/qamJiUlJR31PtHR
0YqOjg7YFx8f31Mjym638w/8OHC9uo9r1X1cq+7jWnUf16r7eupaORyOn10Tlh/ijYqK0qhRo7RmzRprX1dXl9asWSO32x3CyQAAQDgIy1dgJGn69OnKy8tTenq6zj//fD377LNqb2/XbbfdFurRAABAiIVtwFx//fX6+uuvNXv2bHk8Hp1zzjkqLy+X0+kM6VzR0dGaM2fOEW9X4ei4Xt3Hteo+rlX3ca26j2vVfeFwrWx+/8/9nhIAAEB4CcvPwAAAAPwUAgYAABiHgAEAAMYhYAAAgHEImONQWVmpCRMmyOVyyWazaeXKlaEeKSwVFxfrvPPOU1xcnBITEzVx4kQ1NDSEeqywVFZWpuHDh1tfBuV2u7V69epQj2WEefPmyWazadq0aaEeJSw99NBDstlsAdvgwYNDPVbY+vLLL3XTTTepf//+io2N1bBhw7Rly5ZQjxV2Bg0adMS/K5vNpoKCgl98FgLmOLS3t2vEiBEqLS0N9ShhbcOGDSooKNCmTZtUUVGhzs5OjR07Vu3t7aEeLeycfvrpmjdvnmpra7VlyxZdccUV+uMf/6j6+vpQjxbWampq9OKLL2r48OGhHiWsnX322frqq6+s7YMPPgj1SGHpm2++0ejRo9W7d2+tXr1aH3/8sZ566in169cv1KOFnZqamoB/UxUVFZKka6+99hefJWy/ByYcjRs3TuPGjQv1GGGvvLw84PaSJUuUmJio2tpaXXLJJSGaKjxNmDAh4PZjjz2msrIybdq0SWeffXaIpgpvbW1tys3N1aJFizR37txQjxPWIiMjj/nnV/D/PfHEE0pOTtbixYutfampqSGcKHyddtppAbfnzZunM844Q5deeukvPguvwKDHeb1eSVJCQkKIJwlvhw4d0iuvvKL29nb+ZMZPKCgoUHZ2tjIzM0M9StjbuXOnXC6Xfvvb3yo3N1d79uwJ9Uhh6c0331R6erquvfZaJSYmauTIkVq0aFGoxwp7Bw4c0Msvv6zbb789qH80ubt4BQY9qqurS9OmTdPo0aM1dOjQUI8TlrZt2ya32639+/erb9++WrFihdLS0kI9Vlh65ZVXtHXrVtXU1IR6lLCXkZGhJUuW6KyzztJXX32lhx9+WBdffLG2b9+uuLi4UI8XVj7//HOVlZVp+vTp+r//+z/V1NTozjvvVFRUlPLy8kI9XthauXKlWltbdeutt4bk/AQMelRBQYG2b9/Oe+8/4ayzzlJdXZ28Xq/+8Y9/KC8vTxs2bCBifqSxsVF33XWXKioqFBMTE+pxwt4P3+4ePny4MjIylJKSotdee035+fkhnCz8dHV1KT09XY8//rgkaeTIkdq+fbsWLFhAwPyEl156SePGjZPL5QrJ+XkLCT2msLBQq1at0rp163T66aeHepywFRUVpd/97ncaNWqUiouLNWLECD333HOhHivs1NbWqrm5Weeee64iIyMVGRmpDRs2qKSkRJGRkTp06FCoRwxr8fHx+v3vf6/PPvss1KOEnQEDBhzxPwxDhgzhLbefsHv3br3//vv605/+FLIZeAUGQef3+zV16lStWLFC69ev58Nwx6mrq0sdHR2hHiPsjBkzRtu2bQvYd9ttt2nw4MGaMWOGevXqFaLJzNDW1qb//ve/uvnmm0M9StgZPXr0EV/18OmnnyolJSVEE4W/xYsXKzExUdnZ2SGbgYA5Dm1tbQH/97Jr1y7V1dUpISFBAwcODOFk4aWgoEDLly/XG2+8obi4OHk8HkmSw+FQbGxsiKcLL0VFRRo3bpwGDhyob7/9VsuXL9f69ev17rvvhnq0sBMXF3fE56j69Omj/v378/mqo7j33ns1YcIEpaSkaN++fZozZ4569eqlG264IdSjhZ27775bF154oR5//HFdd9112rx5sxYuXKiFCxeGerSw1NXVpcWLFysvL0+RkSHMCD+6bd26dX5JR2x5eXmhHi2sHO0aSfIvXrw41KOFndtvv92fkpLij4qK8p922mn+
MWPG+N97771Qj2WMSy+91H/XXXeFeoywdP311/sHDBjgj4qK8v/mN7/xX3/99f7PPvss1GOFrbfeess/dOhQf3R0tH/w4MH+hQsXhnqksPXuu+/6JfkbGhpCOofN7/f7Q5NOAAAAJ4YP8QIAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIzz/wBRiRVxrEYfAQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of labels produced by the fine-tuning batch generator\n",
    "plt.hist(mktunebatch(2048)[1].to(torch.float32).cpu(), bins=64)\n",
    "plt.title('Fine-tuning label distribution')\n",
    "plt.xlabel('Label')\n",
    "plt.ylabel('Count')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: materialise a fixed dataset on disk (data.pkl) instead of\n",
    "# generating fresh batches on the fly. Disabled by default; set\n",
    "# SAVE_DATASET = True to (re)generate the file.\n",
    "SAVE_DATASET = False\n",
    "\n",
    "if SAVE_DATASET:\n",
    "    import pickle\n",
    "\n",
    "    graphs, labels, _ = mkbatch(3*10**5)\n",
    "    dataset_dump = {\n",
    "        \"data\": graphs,\n",
    "        \"labels\": labels,\n",
    "    }\n",
    "    with open('data.pkl', 'wb') as file:\n",
    "        pickle.dump(dataset_dump, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q3Cg_8UQep8g"
   },
   "source": [
    "# Step 2: Define Transformer Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "tLOWhg_CeWzH"
   },
   "outputs": [],
   "source": [
    "class TransformerModel(nn.Module):\n",
    "    \"\"\"Transformer encoder mapping a token sequence to an output_dim vector.\n",
    "\n",
    "    Adjacent token pairs (positions 2i, 2i+1) are embedded separately into\n",
    "    halves of width model_dim // 2 and concatenated, so each encoder position\n",
    "    sees one pair. The final token has no partner: its embedding is padded\n",
    "    with ones. The prediction is read from that final position.\n",
    "\n",
    "    Requires an odd input length (pairs followed by one trailing token).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, dropout):\n",
    "        super().__init__()\n",
    "        if model_dim % 2:\n",
    "            # Two half-width embeddings are concatenated to form each position.\n",
    "            raise ValueError(f\"model_dim must be even, got {model_dim}\")\n",
    "        self.model_dim = model_dim\n",
    "        # seq_len is currently unused; kept so existing call sites keep working.\n",
    "        self.embedding = nn.Embedding(input_dim, model_dim // 2, dtype=torch.bfloat16)\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,\n",
    "                                                   dim_feedforward=model_dim*4,\n",
    "                                                   dropout=dropout, batch_first=True, dtype=torch.bfloat16)\n",
    "        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
    "        self.fc_out = nn.Linear(model_dim, output_dim, dtype=torch.bfloat16)\n",
    "\n",
    "    def forward(self, src, key_padding_mask):\n",
    "        \"\"\"Return predictions of shape (batch, output_dim).\n",
    "\n",
    "        src: integer token ids, (batch, seq_len) with seq_len odd.\n",
    "        key_padding_mask: same shape as src; True marks positions to ignore\n",
    "            (PyTorch src_key_padding_mask convention).\n",
    "        \"\"\"\n",
    "        if src.size(1) % 2 == 0:\n",
    "            # With an even length the pair slices and the ::2 mask disagree in size.\n",
    "            raise ValueError(f\"src length must be odd, got {src.size(1)}\")\n",
    "        batch_sz = src.size(0)\n",
    "        # Embed each (even, odd) token pair into one model_dim-wide position.\n",
    "        pairs = torch.cat((self.embedding(src[:, :-1:2]), self.embedding(src[:, 1::2])), dim=2)\n",
    "        # The trailing token has no partner; fill its second half with ones.\n",
    "        # Allocate on src's device rather than relying on a global `device`.\n",
    "        filler = torch.ones((batch_sz, 1, self.model_dim // 2), dtype=torch.bfloat16, device=src.device)\n",
    "        last = torch.cat((self.embedding(src[:, -1:]), filler), dim=2)\n",
    "        final_embed = torch.cat((pairs, last), dim=1)\n",
    "        # One mask entry per encoder position: the first token of each pair,\n",
    "        # plus the trailing token.\n",
    "        output = self.transformer_encoder(final_embed, src_key_padding_mask=key_padding_mask[:, ::2])\n",
    "        return self.fc_out(output[:, -1, :])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bpIeg86S-hBb"
   },
   "source": [
    "# Step 3: Make Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kWXvJRDYgFVP",
    "outputId": "c13adb9d-6565-43b5-8437-20cef3dc0d16"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training data: 524M\n",
      "Trainable parameters in the model: 800K\n"
     ]
    }
   ],
   "source": [
    "# ---- Model hyperparameters ----\n",
    "VOCAB_SIZE = MAX_VTXS + 1  # one more than the max number of vertices\n",
    "MODEL_DIM = 64             # width of embeddings and encoder layers\n",
    "NHEADS = 4\n",
    "NLAYERS = 16\n",
    "DROPOUT = 0                # 0.2\n",
    "\n",
    "# ---- Optimisation hyperparameters ----\n",
    "NEPOCHS = 1000\n",
    "BSZ = 2**14                # batch size\n",
    "BPE = 32                   # batches per epoch\n",
    "LR = 5e-5\n",
    "WD = 1e-5\n",
    "\n",
    "model = TransformerModel(\n",
    "    input_dim=VOCAB_SIZE,\n",
    "    model_dim=MODEL_DIM,\n",
    "    output_dim=1,\n",
    "    num_heads=NHEADS,\n",
    "    num_layers=NLAYERS,\n",
    "    seq_len=SEQ_LEN,\n",
    "    dropout=DROPOUT,\n",
    ").to(device)\n",
    "# model = torch.compile(model)\n",
    "\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)\n",
    "\n",
    "samples_seen = NEPOCHS * BSZ * BPE\n",
    "trainable_params = sum(param.numel() for param in filter(lambda t: t.requires_grad, model.parameters()))\n",
    "print(f\"Training data: {samples_seen // 10**6}M\")\n",
    "print(f\"Trainable parameters in the model: {trainable_params // 1000}K\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: build a DataLoader over the fixed dataset in data.pkl instead of\n",
    "# drawing fresh batches. Disabled by default. To use it, iterate over\n",
    "# train_loader in the training loop instead of calling mkbatch.\n",
    "# Only load a data.pkl you created yourself: unpickling untrusted files can\n",
    "# execute arbitrary code.\n",
    "USE_SAVED_DATASET = False\n",
    "\n",
    "if USE_SAVED_DATASET:\n",
    "    import pickle\n",
    "    from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "    with open(\"data.pkl\", \"rb\") as f:\n",
    "        pickled_stuff = pickle.load(f)\n",
    "\n",
    "    data = pickled_stuff[\"data\"].to(device)\n",
    "    label = pickled_stuff[\"labels\"].to(device)\n",
    "    padding_mask = (data == PAD_TOKEN).bool().to(device)\n",
    "    dataset = TensorDataset(data, label, padding_mask)\n",
    "    # train_dataset, test_dataset = torch.utils.data.random_split(dataset, [.9, .1])\n",
    "    train_loader = DataLoader(dataset, batch_size=BSZ, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_short(label_value=1):\n",
    "    \"\"\"MSE restricted to samples of one fresh batch whose label == label_value.\n",
    "\n",
    "    Used to check whether the model does better on short paths (label 1 by\n",
    "    default). Returns NaN if the batch contains no such sample. Runs without\n",
    "    gradients and leaves the model in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "        selected = batch_labels == label_value\n",
    "        if not selected.any():\n",
    "            # Avoid averaging the loss over zero elements.\n",
    "            return float('nan')\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output[selected].squeeze(1), batch_labels[selected])\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f8Zn33m7CxL5"
   },
   "source": [
    "# Step 4: Train the Model for the first task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(batch_size=None):\n",
    "    \"\"\"Return the model's MSE on one freshly generated mkbatch batch.\n",
    "\n",
    "    batch_size defaults to BSZ (looked up at call time). Runs without\n",
    "    gradients and leaves the model in eval mode.\n",
    "    \"\"\"\n",
    "    if batch_size is None:\n",
    "        batch_size = BSZ\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(batch_size)\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 486
    },
    "id": "pvTfzGmCeXU4",
    "outputId": "0d3a20f3-23be-4c19-9eb6-46bfe11a48b1"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/1000 \t Train Err: 149.6562 \t Test Err: 137.0000, Test short loss: 0.3164\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/1000 \t Train Err: 129.4688 \t Test Err: 120.0000, Test short loss: 0.4512\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/1000 \t Train Err: 116.1719 \t Test Err: 110.5000, Test short loss: 2.2500\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/1000 \t Train Err: 107.9375 \t Test Err: 104.0000, Test short loss: 4.0625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/1000 \t Train Err: 103.0938 \t Test Err: 102.0000, Test short loss: 5.5625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/1000 \t Train Err: 99.9531 \t Test Err: 99.0000, Test short loss: 6.6250\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/1000 \t Train Err: 99.0469 \t Test Err: 101.0000, Test short loss: 7.2188\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:13<00:00,  2.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/1000 \t Train Err: 97.8594 \t Test Err: 97.5000, Test short loss: 7.6250\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|██████████████████████████████████████████████████████████████████                                                                                     | 14/32 [00:05<00:07,  2.41it/s]"
     ]
    }
   ],
   "source": [
    "train_err = []\n",
    "test_err = []\n",
    "\n",
    "# Truncate the loss log; one \"<train> <test>\" line is appended per epoch.\n",
    "with open('loss', 'w'):\n",
    "    pass\n",
    "\n",
    "for epoch in range(NEPOCHS):\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    # Fresh synthetic batches every step (a DataLoader could be iterated instead).\n",
    "    for _ in tqdm(range(BPE)):\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "        optimizer.zero_grad()\n",
    "        preds = model(batch_src, batch_padding_mask).squeeze(1)\n",
    "        loss = criterion(preds, batch_labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss += loss.item() / BPE\n",
    "\n",
    "    test_loss = evaluate()\n",
    "    test_short_loss = evaluate_short()\n",
    "    train_err.append(train_loss)\n",
    "    test_err.append(test_loss)\n",
    "\n",
    "    with open('loss', 'a') as f:\n",
    "        f.write(f\"{train_loss} {test_loss}\\n\")\n",
    "    print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}, Test short loss: {test_short_loss:.4f}\")\n",
    "\n",
    "    # Checkpoint once every 100 epochs.\n",
    "    if (epoch + 1) % 100 == 0:\n",
    "        torch.save(model.state_dict(), f\"model_weights_{epoch}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory (disabled): inspect first-layer attention scores for one example.\n",
    "#\n",
    "# NOTE(review): this sketch predates the current TransformerModel and will not\n",
    "# run as written:\n",
    "#   * TransformerModel has no `full_embedding`; the per-position embedding is\n",
    "#     built inside forward() from `model.embedding`.\n",
    "#   * nn.MultiheadAttention applies in_proj_weight as x @ W.T (plus\n",
    "#     in_proj_bias), so the projections below use the transpose.\n",
    "#   * scores are computed for all heads at once and scaled by MODEL_DIM rather\n",
    "#     than per head (head dim = MODEL_DIM // NHEADS).\n",
    "#   * `sqrt` must be in scope (e.g. `from math import sqrt`).\n",
    "#\n",
    "# example_graph, answer, padding = mkbatch(1)\n",
    "# sentence_embeddings = model.full_embedding(example_graph)[0,:,:][example_graph.flatten() != 0]\n",
    "# WQ, WK, WV = torch.split(model.transformer_encoder.layers[0].self_attn.in_proj_weight, (MODEL_DIM, MODEL_DIM, MODEL_DIM))\n",
    "#\n",
    "# Q = sentence_embeddings @ WQ.T\n",
    "# K = sentence_embeddings @ WK.T\n",
    "#\n",
    "# raw_scores = Q @ K.T / sqrt(MODEL_DIM)\n",
    "# soft = torch.softmax(raw_scores, dim=-1).detach().cpu().to(float).numpy()\n",
    "# plt.imshow(soft)\n",
    "# plt.show()\n",
    "#\n",
    "# print(example_graph)\n",
    "# print(Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "fig.suptitle('MSE vs Epochs')\n",
    "ax.plot(train_err, color='blue', label='Train')\n",
    "ax.plot(test_err, color='red', label='Test')\n",
    "ax.set(xlabel='Epochs', ylabel='MSE')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LoGEmM5lH7_A"
   },
   "outputs": [],
   "source": [
    "# Predicted vs. true label on a fresh batch; exact predictions lie on y = x.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    batch_src, batch_labels, batch_padding_mask = mkbatch(4096)\n",
    "    output = model(batch_src, batch_padding_mask)\n",
    "\n",
    "true_vals = batch_labels.to(torch.float16).cpu().numpy().flatten()\n",
    "pred_vals = output.to(torch.float16).cpu().numpy().flatten()\n",
    "plt.scatter(true_vals, pred_vals)\n",
    "plt.xlabel('True label')\n",
    "plt.ylabel('Prediction')\n",
    "plt.title('Predictions vs labels')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LC6Xv3YfC0Rm"
   },
   "source": [
    "# Step 5: Fine-Tune the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuning hyperparameters, loss, and optimizer.\n",
    "N_TUNE_EPOCHS = 100\n",
    "TUNE_LR, TUNE_WD = 1e-5, 1e-5\n",
    "\n",
    "tune_criterion = nn.MSELoss()\n",
    "tune_optimizer = torch.optim.Adam(\n",
    "    model.parameters(),\n",
    "    lr=TUNE_LR,\n",
    "    weight_decay=TUNE_WD,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tuneevaluate(batch_size=None):\n",
    "    \"\"\"Return the fine-tuning MSE on one freshly generated mktunebatch batch.\n",
    "\n",
    "    batch_size defaults to BSZ (looked up at call time). Runs without\n",
    "    gradients and leaves the model in eval mode.\n",
    "    \"\"\"\n",
    "    if batch_size is None:\n",
    "        batch_size = BSZ\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mktunebatch(batch_size)\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = tune_criterion(output.squeeze(1), batch_labels)\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tune_train_err = []\n",
    "tune_test_err = []\n",
    "\n",
    "# clear loss file\n",
    "open('tune_loss', 'w').close()\n",
    "\n",
    "for epoch in range(N_TUNE_EPOCHS):\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    for _ in tqdm(range(BPE)):\n",
    "        batch_src, batch_labels, batch_padding_mask = mktunebatch(BSZ)\n",
    "        # Use the fine-tuning optimizer/criterion so TUNE_LR and TUNE_WD apply,\n",
    "        # not the optimizer from the first training phase.\n",
    "        tune_optimizer.zero_grad()\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = tune_criterion(output.squeeze(1), batch_labels)\n",
    "        train_loss += loss.item() / BPE\n",
    "        loss.backward()\n",
    "        tune_optimizer.step()\n",
    "\n",
    "    test_loss = tuneevaluate()\n",
    "\n",
    "    tune_test_err.append(test_loss)\n",
    "    tune_train_err.append(train_loss)\n",
    "    with open('tune_loss', 'a') as f:\n",
    "        f.write(f\"{train_loss} {test_loss}\\n\")\n",
    "    print(f\"Epoch {epoch + 1}/{N_TUNE_EPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
    "\n",
    "    # Checkpoint every 10 epochs.\n",
    "    if epoch % 10 == 9:\n",
    "        torch.save(model.state_dict(), f\"tune_model_weights_{epoch}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "fig.suptitle('MSE vs Epochs')\n",
    "ax.plot(tune_train_err, color='blue', label='Train')\n",
    "ax.plot(tune_test_err, color='red', label='Test')\n",
    "ax.set(xlabel='Epochs', ylabel='MSE')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted vs. true label on a fresh fine-tuning batch; exact predictions lie on y = x.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    batch_src, batch_labels, batch_padding_mask = mktunebatch(2048)\n",
    "    output = model(batch_src, batch_padding_mask)\n",
    "\n",
    "true_vals = batch_labels.to(torch.float16).cpu().numpy().flatten()\n",
    "pred_vals = output.to(torch.float16).cpu().numpy().flatten()\n",
    "plt.scatter(true_vals, pred_vals)\n",
    "plt.xlabel('True label')\n",
    "plt.ylabel('Prediction')\n",
    "plt.title('Fine-tuning: predictions vs labels')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtTLXn4zC1z_"
   },
   "source": [
    "# Step 6: Test generalization"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}