aboutsummaryrefslogtreecommitdiff
path: root/transformer_shortest_paths.ipynb
blob: fcb24c7cef310270b076042f4dc84f866e028366 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LPphBnKR-aWF"
   },
   "source": [
    "# Step 0: Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "execution_state": "idle",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ge5QvElvhCOw",
    "outputId": "c7cdaefa-d6dc-44ad-c258-e4fb2aca97a5"
   },
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "import pickle\n",
    "# using tqdm.auto glitches out collaborative editing\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pickle\n",
    "from math import sqrt\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import matplotlib.pyplot as plt\n",
    "# %matplotlib widget\n",
    "torch.manual_seed(42)\n",
    "\n",
    "import os\n",
    "from IPython.display import clear_output\n",
    "import ipdb\n",
    "\n",
    "import random\n",
    "random.seed(42)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "assert device.type == 'cuda', \"CUDA is not available. Please check your GPU setup.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "execution_state": "idle",
   "metadata": {
    "id": "lylOX2POPwFL"
   },
   "outputs": [],
   "source": [
    "SEQ_LEN = 65 # means 32 edges, final token is the target vertex\n",
    "PAD_TOKEN = 0\n",
    "AVG_DEG = 2\n",
    "MAX_VTXS = SEQ_LEN//AVG_DEG + 1 # 32 (exclusive)\n",
    "MIN_VTXS = 8\n",
    "MAX_TUNE_VTXS = 16\n",
    "PAD_TOKEN = 0\n",
    "# vertices are labelled 1,2,...,63\n",
    "# we also have a padding token which is 0."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gKt-yIpDebF1"
   },
   "source": [
    "# Step 1: Generate synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "execution_state": "idle",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1IbzGIWseK3E",
    "outputId": "a3cbc233-358c-4e17-ea6e-f4e9349d886b"
   },
   "outputs": [],
   "source": [
    "# original task data\n",
    "# the data will be edge lists\n",
    "# like this: [1 3 1 5 2 4 0 0 0 0 2]\n",
    "# this represents edges (1,3), (1,5) (2,4)\n",
    "# (the zeros are just padding tokens)\n",
    "# the final 2 means which vertex we're going to \n",
    "\n",
    "# the label is the shortest distance from vtx 1 to vtx 2\n",
    "# or \"number of vertices\" if no path exists\n",
    "\n",
    "def random_graph(n):\n",
    "    edge_list = []\n",
    "    adjacencies = [set() for _ in range(n+1)]\n",
    "    indices = [random.randint(1, n-1) for _ in range(AVG_DEG * (n-1))]\n",
    "    for i in range(0, len(indices), 2):\n",
    "        u = indices[i]\n",
    "        v = indices[i + 1]\n",
    "        if u != v:\n",
    "            edge_list += [u,v]\n",
    "            adjacencies[u].add(v)\n",
    "            adjacencies[v].add(u)\n",
    "\n",
    "    edge_list += [PAD_TOKEN]*(SEQ_LEN-len(edge_list))\n",
    "    return edge_list, adjacencies\n",
    "\n",
    "\"\"\"\n",
    "input: G, represented as an adjacency list\n",
    "output: [number of vertices]+[d(1,i) for i in range(n)] if target=None\n",
    "if target is set to some value, then we instead just output that specific distance\n",
    "\"\"\"\n",
    "def SSSP(n, G, target=2):\n",
    "    dist = [n for _ in G]\n",
    "    dist[1] = 0\n",
    "    frontier = deque()\n",
    "    frontier.append(1)\n",
    "    while len(frontier) > 0:\n",
    "        vtx = frontier.popleft()\n",
    "        for x in G[vtx]:\n",
    "            if dist[x] == n:\n",
    "                dist[x] = 1 + dist[vtx]\n",
    "                frontier.append(x)\n",
    "                if x == target:\n",
    "                    return dist[target]\n",
    "    if target is not None:\n",
    "        return dist[target]\n",
    "    else:\n",
    "        return dist\n",
    "\n",
    "def mkbatch(size):\n",
    "    graphs1 = []\n",
    "    distance1 = []\n",
    "    \n",
    "    for i in range(size):\n",
    "        n = random.randrange(MIN_VTXS, MAX_VTXS)\n",
    "        edge_list, adj_list = random_graph(n)\n",
    "        dist = SSSP(n, adj_list)\n",
    "        edge_list[-1] = 2 # target token\n",
    "        graphs1.append(edge_list)\n",
    "        distance1.append(dist)\n",
    "    \n",
    "    data = torch.tensor(graphs1, device=device)\n",
    "    labels = torch.tensor(distance1, dtype=torch.float32, device=device)\n",
    "    padding = data == PAD_TOKEN\n",
    "    return data, labels, padding\n",
    "\n",
    "def vertices_on_shortest_12_path(n, G, target=2):\n",
    "    dist = [n for _ in G]\n",
    "    parent = [-1 for _ in G]\n",
    "    dist[1] = 0\n",
    "    frontier = deque()\n",
    "    frontier.append(1)\n",
    "    while len(frontier) > 0:\n",
    "        vtx = frontier.popleft()\n",
    "        for x in G[vtx]:\n",
    "            if dist[x] == n:\n",
    "                parent[x] = vtx\n",
    "                dist[x] = 1 + dist[vtx]\n",
    "                frontier.append(x)\n",
    "                if x == target:\n",
    "                    path = [x]\n",
    "                    while parent[x] != -1:\n",
    "                        x = parent[x]\n",
    "                        path.append(x)\n",
    "                    return list(reversed(path))\n",
    "    return []\n",
    "\n",
    "def mktunebatch(size):\n",
    "    graphs2 = []\n",
    "    distance2 = []\n",
    "    \n",
    "    for i in range(size):\n",
    "        n = random.randrange(MIN_VTXS, MAX_TUNE_VTXS)\n",
    "        while True:\n",
    "            edge_list, adj_list = random_graph(n)\n",
    "            path = vertices_on_shortest_12_path(n, adj_list)\n",
    "            if len(path) > 1:\n",
    "                target_vtx_idx = random.randrange(1, len(path))\n",
    "                target_vtx = path[target_vtx_idx]\n",
    "                edge_list[-1] = target_vtx\n",
    "                graphs2.append(edge_list)\n",
    "                distance2.append(target_vtx_idx)\n",
    "                break\n",
    "    \n",
    "    data = torch.tensor(graphs2, device=device)\n",
    "    labels = torch.tensor(distance2, dtype=torch.float32, device=device)\n",
    "    padding = data == PAD_TOKEN\n",
    "    return data, labels, padding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "execution_state": "idle",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([255.,   0., 298.,   0., 231.,   0., 210.,   0., 123.,   0.,  63.,\n",
       "          0.,  31.,   0.,  46.,   0.,  39.,   0.,  35.,   0.,  40.,   0.,\n",
       "         35.,   0.,  44.,   0.,  24.,   0.,  37.,   0.,  26.,   0.,   0.,\n",
       "         39.,   0.,  31.,   0.,  31.,   0.,  38.,   0.,  34.,   0.,  36.,\n",
       "          0.,  33.,   0.,  33.,   0.,  22.,   0.,  38.,   0.,  28.,   0.,\n",
       "         34.,   0.,  26.,   0.,  30.,   0.,  26.,   0.,  32.]),\n",
       " array([ 1.      ,  1.484375,  1.96875 ,  2.453125,  2.9375  ,  3.421875,\n",
       "         3.90625 ,  4.390625,  4.875   ,  5.359375,  5.84375 ,  6.328125,\n",
       "         6.8125  ,  7.296875,  7.78125 ,  8.265625,  8.75    ,  9.234375,\n",
       "         9.71875 , 10.203125, 10.6875  , 11.171875, 11.65625 , 12.140625,\n",
       "        12.625   , 13.109375, 13.59375 , 14.078125, 14.5625  , 15.046875,\n",
       "        15.53125 , 16.015625, 16.5     , 16.984375, 17.46875 , 17.953125,\n",
       "        18.4375  , 18.921875, 19.40625 , 19.890625, 20.375   , 20.859375,\n",
       "        21.34375 , 21.828125, 22.3125  , 22.796875, 23.28125 , 23.765625,\n",
       "        24.25    , 24.734375, 25.21875 , 25.703125, 26.1875  , 26.671875,\n",
       "        27.15625 , 27.640625, 28.125   , 28.609375, 29.09375 , 29.578125,\n",
       "        30.0625  , 30.546875, 31.03125 , 31.515625, 32.      ]),\n",
       " <BarContainer object of 64 artists>)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAivElEQVR4nO3df1BVdeL/8RegXH/BJVS4sAL+KpEUa0nxTuXHFVZE1tWVZrTcstbR0YVmlTKlMTV3Z3Gt6eeQzs62UjOS5U7qSGUpJm4ralKOv4pRhxZbvNDqwFWMH8L5/vHJ+/3c/EEocN/A8zFzZrjnvO+973M85dNzf+BnWZYlAAAAg/j7egIAAAA/RqAAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAME4PX0/gVjQ3N6uiokJBQUHy8/Pz9XQAAMBPYFmWLl68qMjISPn73/waSacMlIqKCkVFRfl6GgAA4BacPXtWgwYNuumYThkoQUFB0g87GBwc7OvpAACAn8DtdisqKsrz9/jNdMpAufqyTnBwMIECAEAn81PensGbZAEAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGaVWgrF+/XvHx8Z4vSHM6nfroo4882+vq6pSRkaH+/furX79+Sk9PV2VlpddjlJeXKy0tTX369FFYWJiWLl2qK1eutN0eAQCATq9VgTJo0CCtXbtWJSUlOnz4sCZNmqTp06frxIkTkqQlS5Zox44d2rJli4qKilRRUaGZM2d67t/U1KS0tDQ1NDRo//79euutt5SXl6eVK1e2/Z4BAIBOy8+yLOt2HiA0NFQvvPCCHnroIQ0cOFD5+fl66KGHJElff/21Ro4cqeLiYo0fP14fffSRfvWrX6miokLh4eGSpA0bNmjZsmX67rvvFBgY+JOe0+12y263q6amhq+6BwCgk2jN39+3/B6UpqYmbd68WbW1tXI6nSopKVFjY6OSk5M9Y2JjYxUdHa3i4mJJUnFxsUaPHu2JE0lKSUmR2+32XIW5nvr6erndbq8FAAB0Xa0OlGPHjqlfv36y2WxauHChtm7dqri4OLlcLgUGBiokJMRrfHh4uFwulyTJ5XJ5xcnV7Ve33UhOTo7sdrtniYqKau20AQBAJ9LqQBkxYoSOHDmigwcPatGiRZo7d65OnjzZPrP7QXZ2tmpqajzL2bNn2/X5AACAb/Vo7R0CAwM1fPhwSVJCQoI+//xzvfrqq5o1a5YaGhpUXV3tdRWlsrJSDodDkuRwOHTo0CGvx7v6KZ+rY67HZrPJZrO1dqodbvDyD2647Zu1aR06FwAAOrPb/h6U5uZm1dfXKyEhQT179lRhYaFnW2lpqcrLy+V0OiVJTqdTx44dU1VVlWfMrl27FBwcrLi4uNudCgAA6CJadQUlOztbqampio6O1sWLF5Wfn6+9e/fq448/lt1u17x585SVlaXQ0FAFBwfrySeflNPp1Pjx4yVJkydPVlxcnB599FGtW7dOLpdLK1asUEZGRqe4QgIAADpGqwKlqqpKjz32mM6dOye73a74+Hh9/PHH+uUvfylJevnll+Xv76/09HTV19crJSVFb7zxhuf+AQEBKigo0KJFi+R0OtW3b1/NnTtXa9asafs9AwAAndZtfw+KL5j6PSi8BwUAgBvrkO9BAQAAaC8ECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgE
CgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADBOD19PwESDl39ww23frE3r0LkAANAdcQUFAAAYp1WBkpOTo7FjxyooKEhhYWGaMWOGSktLvcZMnDhRfn5+XsvChQu9xpSXlystLU19+vRRWFiYli5dqitXrrTNHgEAgE6vVS/xFBUVKSMjQ2PHjtWVK1f07LPPavLkyTp58qT69u3rGTd//nytWbPGc7tPnz6en5uampSWliaHw6H9+/fr3Llzeuyxx9SzZ0/9+c9/bqv9AgAAnVirAmXnzp1et/Py8hQWFqaSkhJNmDDBs75Pnz5yOBzXfYxPPvlEJ0+e1O7duxUeHq577rlHf/zjH7Vs2TKtXr1agYGBt7ovAACgi7it96DU1NRIkkJDQ73Wb9q0SQMGDNCoUaOUnZ2ty5cve7YVFxdr9OjRCg8P96xLSUmR2+3WiRMnrvs89fX1crvdXgsAAOi6bvlTPM3NzVq8eLHuv/9+jRo1yrP+kUceUUxMjCIjI3X06FEtW7ZMpaWlev/99yVJLpfLK04keW67XK7rPldOTo6ef/75W50qAADoZG45UDIyMnT8+HF99tlnXusXLFjg+Xn06NGKiIhQUlKSzpw5o2HDht3Sc2VnZysrK8tz2+12Kyoq6lanDgAADHdLL/FkZmaqoKBAn376qQYNGnTTsYmJiZKk06dPS5IcDocqKyu9xly9faP3rdhsNgUHB3stAACg62pVoFiWpczMTG3dulV79uzRkCFDWrzPkSNHJEkRERGSJKfTqWPHjqmqqsozZteuXQoODlZcXFzr9wAAAHQ5rXqJJyMjQ/n5+dq+fbuCgoI87xmx2+3q3bu3zpw5o/z8fE2dOlX9+/fX0aNHtWTJEk2YMEHx8fGSpMmTJysuLk6PPvqo1q1bJ5fLpRUrVigjI0M2m6199hIAAHQqrbqCsn79etXU1GjixImKiIjwLO+++64kKTAwULt379bkyZMVGxurp556Sunp6dqxY4fnMQICAlRQUKCAgAA5nU799re/1WOPPeb1vSkAAKB7a9UVFMuybro9KipKRUVFLT5OTEyMPvzww9Y8NQAA6Eb4XTwAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzTw9cTgDR4+Qc33PbN2rQOnQsAACbgCgoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4PXw9Ady6wcs/uOG2b9amdehcAABoS1xBAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGCcVgVKTk6Oxo4dq6CgIIWFhWnGjBkqLS31GlNXV6eMjAz1799f/fr1U3p6uiorK73GlJeXKy0tTX369FFYWJiWLl2qK1eutM0eAQCATq9VgVJUVKSMjAwdOHBAu3btUmNjoyZPnqza2lrPmCVLlmjHjh3asmWLioqKVFFRoZkzZ3q2NzU1KS0tTQ0NDdq/f7/eeust5eXlaeXKlW27ZwAAoNNq1TfJ7ty50+t2Xl6ewsLCVFJSogkTJqimpkZvvvmm8vPzNWnSJEnSxo0bNXLkSB04cEDjx4/XJ598opMnT2r37t0KDw/XPffcoz/+8Y9atmyZVq9ercDAwLbdQwAA0Onc1ntQampqJEmhoaGSpJKSEjU2Nio5OdkzJjY2VtHR0SouLpYkFRcXa/To0QoPD/eMSUlJkdvt1okTJ677PPX19XK73V4LAADoum45UJqbm7V48WLdf//9GjVq
lCTJ5XIpMDBQISEhXmPDw8Plcrk8Y/5vnFzdfnXb9eTk5Mhut3uWqKioW502AADoBG45UDIyMnT8+HFt3ry5bWd0HdnZ2aqpqfEsZ8+ebffnBAAAvnNLv804MzNTBQUF2rdvnwYNGuRZ73A41NDQoOrqaq+rKJWVlXI4HJ4xhw4d8nq8q5/yuTrmx2w2m2w2261MFQAAdEKtuoJiWZYyMzO1detW7dmzR0OGDPHanpCQoJ49e6qwsNCzrrS0VOXl5XI6nZIkp9OpY8eOqaqqyjNm165dCg4OVlxc3O3vEQAA6PRadQUlIyND+fn52r59u4KCgjzvGbHb7erdu7fsdrvmzZunrKwshYaGKjg4WE8++aScTqfGjx8vSZo8ebLi4uL06KOPat26dXK5XFqxYoUyMjK4SgIAAKTWBsr69eslSRMnTvRav3HjRj3++OOSpJdffln+/v5KT09XfX29UlJS9MYbb3jGBgQEqKCgQIsWLZLT6VTfvn01d+5crVmzpm32CAAAdHqtChTLsloc06tXL+Xm5io3N/eGY2JiYvThhx+25qkBAEA3wu/iAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGKfVgbJv3z5NmzZNkZGR8vPz07Zt27y2P/744/Lz8/NapkyZ4jXmwoULmjNnjoKDgxUSEqJ58+bp0qVLt783AACgS2h1oNTW1mrMmDHKzc294ZgpU6bo3LlznuWdd97x2j5nzhydOHFCu3btUkFBgfbt26cFCxbc2h4AAIAup0dr75CamqrU1NSbjrHZbHI4HNfd9tVXX2nnzp36/PPPdd9990mSXn/9dU2dOlUvvviiIiMjWzslAADQxbTLe1D27t2rsLAwjRgxQosWLdL58+c924qLixUSEuKJE0lKTk6Wv7+/Dh48eN3Hq6+vl9vt9loAAEDX1eaBMmXKFL399tsqLCzUX/7yFxUVFSk1NVVNTU2SJJfLpbCwMK/79OjRQ6GhoXK5XNd9zJycHNntds8SFRXV1tMGAAAGafVLPC2ZPXu25+fRo0crPj5ew4YN0969e5WUlHRLj5mdna2srCzPbbfbTaQAANCFtfvHjIcOHaoBAwbo9OnTkiSHw6GqqiqvMVeuXNGFCxdu+L4Vm82m4OBgrwUAAHRd7R4o3377rc6fP6+IiAhJktPpVHV1tUpKSjxj9uzZo+bmZiUmJrb3dAAAQCfQ6pd4Ll265LkaIkllZWU6cuSIQkNDFRoaqueff17p6elyOBw6c+aMnnnmGQ0fPlwpKSmSpJEjR2rKlCmaP3++NmzYoMbGRmVmZmr27Nl8ggcAAEi3cgXl8OHDuvfee3XvvfdKkrKysnTvvfdq5cqVCggI0NGjR/XrX/9ad911l+bNm6eEhAT985//lM1m8zzGpk2bFBsbq6SkJE2dOlUPPPCA/vrXv7btngEAgE6r1VdQJk6cKMuybrj9448/bvExQkNDlZ+f39qnBgAA3QS/iwcAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQA
AGAcAgUAABinh68ngI43ePkHN9z2zdq0Dp0LAADXwxUUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGKfVgbJv3z5NmzZNkZGR8vPz07Zt27y2W5allStXKiIiQr1791ZycrJOnTrlNebChQuaM2eOgoODFRISonnz5unSpUu3vzcAAKBLaHWg1NbWasyYMcrNzb3u9nXr1um1117Thg0bdPDgQfXt21cpKSmqq6vzjJkzZ45OnDihXbt2qaCgQPv27dOCBQtub08AAECX0aO1d0hNTVVqaup1t1mWpVdeeUUrVqzQ9OnTJUlvv/22wsPDtW3bNs2ePVtfffWVdu7cqc8//1z33XefJOn111/X1KlT9eKLLyoyMvJ29wkAAHRybfoelLKyMrlcLiUnJ3vW2e12JSYmqri4WJJUXFyskJAQT5xIUnJysvz9/XXw4MHrPm59fb3cbrfXAgAAuq42DRSXyyVJCg8P91ofHh7u2eZyuRQWFua1vUePHgoNDfWM+bGcnBzZ7XbPEhUV1ZbTBgAAhukUn+LJzs5WTU2NZzl79qyvpwQAANpRmwaKw+GQJFVWVnqtr6ys9GxzOByqqqry2n7lyhVduHDBM+bHbDabgoODvRYAANB1tWmgDBkyRA6HQ4WFhZ51brdbBw8elNPplCQ5nU5VV1erpKTEM2bPnj1qbm5WYmJiW04HAAB0Uq3+FM+lS5d0+vRpz+2ysjIdOXJEoaGhio6O1uLFi/WnP/1Jd955p4YMGaLnnntOkZGRmjFjhiRp5MiRmjJliubPn68NGzaosbFRmZmZmj17Np/gAQAA0q0EyuHDh/WLX/zCczsrK0uSNHfuXOXl5emZZ55RbW2tFixYoOrqaj3wwAPauXOnevXq5bnPpk2blJmZqaSkJPn7+ys9PV2vvfZaW+0TAADo5FodKBMnTpRlWTfc7ufnpzVr1mjNmjU3HBMaGqr8/PzWPjUAAOgmOsWneAAAQPdCoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzTw9cTQOcxePkHN9z2zdq0Dp0LAKBr4woKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjNPmgbJ69Wr5+fl5LbGxsZ7tdXV1ysjIUP/+/dWvXz+lp6ersrKyracBAAA6sXa5gnL33Xfr3LlznuWzzz7zbFuyZIl27NihLVu2qKioSBUVFZo5c2Z7TAMAAHRS7fLLAnv06CGHw3HN+pqaGr355pvKz8/XpEmTJEkbN27UyJEjdeDAAY0fP749pgMAADqZdrmCcurUKUVGRmro0KGaM2eOysvLJUklJSVqbGxUcnKyZ2xsbKyio6NVXFx8w8err6+X2+32WgAAQNfV5oGSmJiovLw87dy5U+vXr1dZWZkefPBBXbx4US6XS4GBgQoJCfG6T3h4uFwu1w0fMycnR3a73bNERUW19bQBAIBB2vwlntTUVM/P8fHxSkxMVExMjN577z317t37lh4zOztbWVlZnttut5tIAQCgC2v3jxmHhITorrvu0unTp+VwONTQ0KDq6mqvMZWVldd9z8pVNptNwcHBXgsAAOi62j1QLl26pDNn
zigiIkIJCQnq2bOnCgsLPdtLS0tVXl4up9PZ3lMBAACdRJu/xPP0009r2rRpiomJUUVFhVatWqWAgAA9/PDDstvtmjdvnrKyshQaGqrg4GA9+eSTcjqdfIIHAAB4tHmgfPvtt3r44Yd1/vx5DRw4UA888IAOHDiggQMHSpJefvll+fv7Kz09XfX19UpJSdEbb7zR1tNAFzB4+Qc33PbN2rQOnQsAoGO1eaBs3rz5ptt79eql3Nxc5ebmtvVTAwCALqJdvqgNQPvi6hKAro5fFggAAIxDoAAAAOMQKAAAwDgECgAAMA5vkgXaAG9aBYC2RaCgyyEWujb+fIHugUABgHZASLU9jmn3wntQAACAcbiCAqBF/MsVQEcjUACgiyAk0ZUQKMAPbvY/d7XT/+D5CwUAro9AAQDDEK5m4M/BtwgUAOjmbvUvYv4CR3siUAAA6KZMjkw+ZgwAAIzDFRS0O5MLHd0HL2MAnQuBAgCAAYhhb7zEAwAAjMMVFABAl8aVic6JKygAAMA4BAoAADAOL/EAaDdcWkd31NHnfVf974wrKAAAwDhcQQG6ka76Ly0AXQ9XUAAAgHG4ggIA6FBcycNPwRUUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYByfBkpubq4GDx6sXr16KTExUYcOHfLldAAAgCF8FijvvvuusrKytGrVKn3xxRcaM2aMUlJSVFVV5aspAQAAQ/gsUF566SXNnz9fTzzxhOLi4rRhwwb16dNHf//73301JQAAYIgevnjShoYGlZSUKDs727PO399fycnJKi4uvmZ8fX296uvrPbdramokSW63u13m11x/+Ybbbvac3K/r3s8Xz9md73cznWUfbmffO8tcuV/nvt/t3vdWXH1My7JaHmz5wH/+8x9LkrV//36v9UuXLrXGjRt3zfhVq1ZZklhYWFhYWFi6wHL27NkWW8EnV1BaKzs7W1lZWZ7bzc3NunDhgvr37y8/P79rxrvdbkVFRens2bMKDg7u4Nl2DhyjlnGMbo7j0zKOUcs4RjfX1Y6PZVm6ePGiIiMjWxzrk0AZMGCAAgICVFlZ6bW+srJSDofjmvE2m002m81rXUhISIvPExwc3CX+QNsTx6hlHKOb4/i0jGPUMo7RzXWl42O323/SOJ+8STYwMFAJCQkqLCz0rGtublZhYaGcTqcvpgQAAAzis5d4srKyNHfuXN13330aN26cXnnlFdXW1uqJJ57w1ZQAAIAhfBYos2bN0nfffaeVK1fK5XLpnnvu0c6dOxUeHn7bj22z2bRq1aprXhbC/8cxahnH6OY4Pi3jGLWMY3Rz3fn4+Fk/6bM+AAAAHYffxQMAAIxDoAAAAOMQKAAAwDgECgAAME6XDJTc3FwNHjxYvXr1UmJiog4dOuTrKRlj9erV8vPz81piY2N9PS2f2bdvn6ZNm6bIyEj5+flp27ZtXtsty9LKlSsVERGh3r17Kzk5WadOnfLZfH2hpWP0+OOPX3NOTZkyxWfz7Wg5OTkaO3asgoKCFBYWphkzZqi0tNRrTF1dnTIyMtS/f3/169dP6enp13xRZVf2U47RxIkTrzmPFi5c6LM5d7T169crPj7e84VsTqdTH330kWd7dzyHulygvPvuu8rKytKqVav0xRdfaMyYMUpJSVFVVZWvp2aMu+++W+fOnfMsn332ma+n5DO1tbUaM2aMcnNzr7t93bp1eu2117RhwwYdPHhQffv2VUpKiurq6jp8rr7S0jGSpClTpnidU++8806HztGXioqKlJGRoQMHDmjXrl1qbGzU5MmTVVtb6xmzZMkS7dixQ1u2bFFRUZEqKio0c+ZMn867I/2UYyRJ8+fP9zqP1q1b57M5d7RBgwZp7dq1
Kikp0eHDhzVp0iRNnz5dJ06ckLrrOdSWvwTQBOPGjbMyMjI8t5uamqzIyEgrJyfHp/MyxapVq6wxY8b4ehpGkmRt3brVc7u5udlyOBzWCy+84FlXXV1t2Ww265133vHRLH3rx8fIsixr7ty51vTp0302J9NUVVVZkqyioiLL+uGc6dmzp7VlyxbPmK+++sqSZBUXF/twpr7z42NkWZb1P//zP9Yf/vAHn87LNHfccYf1t7/9rdueQ13qCkpDQ4NKSkqUnJzsWefv76/k5GQVFxf7dG4mOXXqlCIjIzV06FDNmTNH5eXlvp6SkcrKyuRyubzOJ7vdrsTERM6nH9m7d6/CwsI0YsQILVq0SOfPn/f1lHympqZGkhQaGipJKikpUWNjo9d5FBsbq+jo6G57Hv34GF21adMmDRgwQKNGjVJ2drYuX77soxn6VlNTkzZv3qza2lo5nc5uew51it9m/FP997//VVNT0zXfRhseHq6vv/7aZ/MySWJiovLy8jRixAidO3dOzz//vB588EEdP35cQUFBvp6eUVwul/TD+fN/hYeHe7bhf1/emTlzpoYMGaIzZ87o2WefVWpqqoqLixUQEODr6XWo5uZmLV68WPfff79GjRol/XAeBQYGXvMLTrvreXS9YyRJjzzyiGJiYhQZGamjR49q2bJlKi0t1fvvv+/T+XakY8eOyel0qq6uTv369dPWrVsVFxenI0eOdMtzqEsFClqWmprq+Tk+Pl6JiYmKiYnRe++9p3nz5vl0buicZs+e7fl59OjRio+P17Bhw7R3714lJSX5dG4dLSMjQ8ePH+/W7+tqyY2O0YIFCzw/jx49WhEREUpKStKZM2c0bNgwH8y0440YMUJHjhxRTU2N/vGPf2ju3LkqKiry9bR8pku9xDNgwAAFBARc887myspKORwOn83LZCEhIbrrrrt0+vRpX0/FOFfPGc6n1hk6dKgGDBjQ7c6pzMxMFRQU6NNPP9WgQYM86x0OhxoaGlRdXe01vjueRzc6RteTmJgoSd3qPAoMDNTw4cOVkJCgnJwcjRkzRq+++mq3PYe6VKAEBgYqISFBhYWFnnXNzc0qLCyU0+n06dxMdenSJZ05c0YRERG+nopxhgwZIofD4XU+ud1uHTx4kPPpJr799ludP3++25xTlmUpMzNTW7du1Z49ezRkyBCv7QkJCerZs6fXeVRaWqry8vJucx61dIyu58iRI5LUbc6j62lublZ9fX33PYd8/S7dtrZ582bLZrNZeXl51smTJ60FCxZYISEhlsvl8vXUjPDUU09Ze/futcrKyqx//etfVnJysjVgwACrqqrK11PziYsXL1pffvml9eWXX1qSrJdeesn68ssvrX//+9+WZVnW2rVrrZCQEGv79u3W0aNHrenTp1tDhgyxvv/+e19PvcPc7BhdvHjRevrpp63i4mKrrKzM2r17t/Xzn//cuvPOO626ujpfT71DLFq0yLLb7dbevXutc+fOeZbLly97xixcuNCKjo629uzZYx0+fNhyOp2W0+n06bw7UkvH6PTp09aaNWusw4cPW2VlZdb27dutoUOHWhMmTPD11DvM8uXLraKiIqusrMw6evSotXz5csvPz8/65JNPLKubnkNdLlAsy7Jef/11Kzo62goMDLTGjRtnHThwwNdTMsasWbOsiIgIKzAw0PrZz35mzZo1yzp9+rSvp+Uzn376qSXpmmXu3LmW9cNHjZ977jkrPDzcstlsVlJSklVaWurraXeomx2jy5cvW5MnT7YGDhxo9ezZ04qJibHmz5/frf5BcL1jI8nauHGjZ8z3339v/f73v7fuuOMOq0+fPtZvfvMb69y5cz6dd0dq6RiVl5dbEyZMsEJDQy2bzWYNHz7cWrp0qVVTU+PrqXeY3/3ud1ZMTIwVGBhoDRw40EpKSvLEidVNzyE/639PHgAAAGN0qfegAACAroFAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAA
AIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYJz/B1g6lyYuVsaFAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist(mkbatch(2048)[1].cpu(), bins=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "execution_state": "idle",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1162.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,  525.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,  242.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,   87.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,   22.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    7.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           3.]),\n",
       " array([1.     , 1.09375, 1.1875 , 1.28125, 1.375  , 1.46875, 1.5625 ,\n",
       "        1.65625, 1.75   , 1.84375, 1.9375 , 2.03125, 2.125  , 2.21875,\n",
       "        2.3125 , 2.40625, 2.5    , 2.59375, 2.6875 , 2.78125, 2.875  ,\n",
       "        2.96875, 3.0625 , 3.15625, 3.25   , 3.34375, 3.4375 , 3.53125,\n",
       "        3.625  , 3.71875, 3.8125 , 3.90625, 4.     , 4.09375, 4.1875 ,\n",
       "        4.28125, 4.375  , 4.46875, 4.5625 , 4.65625, 4.75   , 4.84375,\n",
       "        4.9375 , 5.03125, 5.125  , 5.21875, 5.3125 , 5.40625, 5.5    ,\n",
       "        5.59375, 5.6875 , 5.78125, 5.875  , 5.96875, 6.0625 , 6.15625,\n",
       "        6.25   , 6.34375, 6.4375 , 6.53125, 6.625  , 6.71875, 6.8125 ,\n",
       "        6.90625, 7.     ]),\n",
       " <BarContainer object of 64 artists>)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjF0lEQVR4nO3dfVSUdf7/8RdI3GQMiMXgbIps26rkXUnRZPdyRGXd9cRWFLVUrO5pwTK6gy3vysRsu5GWMDuteE56rPasllYkYcqWiIjLpmRkmyVlA3sOMRN0RJT5/vFbr19TWtgOznzo+TjnOqe5rs/MvOc6nsOzi5khxOv1egUAAGCQ0EAPAAAAcLIIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGCQv0AH2lp6dHBw8eVHR0tEJCQgI9DgAA6AWv16uvvvpKDodDoaEnvs7SbwPm4MGDGjp0aKDHAAAAP0Jzc7POPvvsEx7vtwETHR0t/fcE2Gy2QI8DAAB6wePxaOjQodbP8RM56YCprq7WY489pvr6en3xxRdat26dZsyYIUnq7u7Wgw8+qNdff10ff/yxYmJilJaWpiVLlsjhcFiP0dbWptmzZ2vDhg0KDQ1VZmamli1bpjPOOMNa89577ykvL091dXU666yzNHv2bN133329nvPYr41sNhsBAwCAYX7o7R8n/Sbezs5OjRs3TqWlpd859vXXX2vXrl2aO3eudu3apb///e9qamrSr3/9a5912dnZamxsVGVlpTZu3Kjq6mrNmjXLOu7xeDR58mQlJiaqvr5ejz32mBYsWKAVK1ac7LgAAKAfCvlf/hp1SEiIzxWY46mrq9NFF12kTz/9VMOGDdPevXuVnJysuro6paSkSJIqKio0bdo0ffbZZ3I4HCorK9MDDzwgl8ul8PBwSVJhYaHWr1+vDz74oFezeTwexcTEyO12cwUGAABD9Pbnd59/jNrtdiskJESxsbGSpJqaGsXGxlrxIklpaWkKDQ1VbW2ttebyyy+34kWS0tPT1dTUpC+//PK4z9PV1SWPx+OzAQCA/qlPA+bQoUO6//77dcMNN1gV5XK5FB8f77MuLCxMcXFxcrlc1hq73e6z5tjtY2u+rbi4WDExMdbGJ5AAAOi/+ixguru7dd1118nr9aqsrKyvnsZSVFQkt9ttbc3NzX3+nAAAIDD65GPUx+Ll008/1ebNm31+h5WQkKDW1laf9UeOHFFbW5sSEhKsNS0tLT5rjt0+tubbIiIiFBER0QevBgAABBu/X4E5Fi/79u3TW2+9pcGDB/scdzqdam9vV319vbVv8+bN6unpUWpqqrWmurpa3d3d1prKykqNGDFCgwYN8vfIAADAMCcdMB0dHWpoaFBDQ4Mkaf/+/WpoaNCBAwfU3d2t3/72t9q5c6dWr16to0ePyuVyyeVy6fDhw5KkUaNGacqUKZo5c6Z27Nihd999V/n5+crKyrK+K+bGG29UeHi4cnNz1djYqBdffFHLli1TQUGBv18/AAAw0El/jHrLli266qqrvrM/JydHCxYsUFJS0nHv9/bbb+vKK6+U/vtFdvn5+T5fZFdSUnLCL7I788wzNXv2bN1///29npOPUQMAYJ7e/vz+n74HJpgRMAAAmCdovgcGAADA3wgYAABgHAIGAAAYh4ABAADG6ZMvsuvvhhe+dsJjnyzJOKWzAADwU8QVGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiH
gAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxTjpgqqurNX36dDkcDoWEhGj9+vU+x71er+bNm6chQ4YoKipKaWlp2rdvn8+atrY2ZWdny2azKTY2Vrm5uero6PBZ89577+myyy5TZGSkhg4dqqVLl/7Y1wgAAPqZkw6Yzs5OjRs3TqWlpcc9vnTpUpWUlGj58uWqra3VwIEDlZ6erkOHDllrsrOz1djYqMrKSm3cuFHV1dWaNWuWddzj8Wjy5MlKTExUfX29HnvsMS1YsEArVqz4sa8TAAD0IyFer9f7o+8cEqJ169ZpxowZ0n+vvjgcDt1999265557JElut1t2u13l5eXKysrS3r17lZycrLq6OqWkpEiSKioqNG3aNH322WdyOBwqKyvTAw88IJfLpfDwcElSYWGh1q9frw8++KBXs3k8HsXExMjtdstms/3Yl3hcwwtfO+GxT5Zk+PW5AAD4Kentz2+/vgdm//79crlcSktLs/bFxMQoNTVVNTU1kqSamhrFxsZa8SJJaWlpCg0NVW1trbXm8ssvt+JFktLT09XU1KQvv/zyuM/d1dUlj8fjswEAgP7JrwHjcrkkSXa73We/3W63jrlcLsXHx/scDwsLU1xcnM+a4z3GN5/j24qLixUTE2NtQ4cO9eMrAwAAwaTffAqpqKhIbrfb2pqbmwM9EgAA6CN+DZiEhARJUktLi8/+lpYW61hCQoJaW1t9jh85ckRtbW0+a473GN98jm+LiIiQzWbz2QAAQP/k14BJSkpSQkKCqqqqrH0ej0e1tbVyOp2SJKfTqfb2dtXX11trNm/erJ6eHqWmplprqqur1d3dba2prKzUiBEjNGjQIH+ODAAADHTSAdPR0aGGhgY1NDRI/33jbkNDgw4cOKCQkBDNmTNHixYt0quvvqrdu3frd7/7nRwOh/VJpVGjRmnKlCmaOXOmduzYoXfffVf5+fnKysqSw+GQJN14440KDw9Xbm6uGhsb9eKLL2rZsmUqKCjw9+sHAAAGCjvZO+zcuVNXXXWVdftYVOTk5Ki8vFz33XefOjs7NWvWLLW3t+vSSy9VRUWFIiMjrfusXr1a+fn5mjRpkkJDQ5WZmamSkhLreExMjDZt2qS8vDxNmDBBZ555pubNm+fzXTEAAOCn63/6HphgxvfAAABgnoB8DwwAAMCpQMAAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADj+D1gjh49qrlz5yopKUlRUVE655xz9PDDD8vr9VprvF6v5s2bpyFDhigqKkppaWnat2+fz+O0tbUpOztbNptNsbGxys3NVUdHh7/HBQAABvJ7wDz66KMqKyvTX/7yF+3du1ePPvqoli5dqqefftpas3TpUpWUlGj58uWqra3VwIEDlZ6erkOHDllrsrOz1djYqMrKSm3cuFHV1dWaNWuWv8cFAAAGCvF+89KIH/zqV7+S3W7X888/b+3LzMxUVFSUXnjhBXm9XjkcDt1999265557JElut1t2u13l5eXKysrS3r17lZycrLq6OqWkpEiSKioqNG3aNH322WdyOBw/OIfH41FMTIzcbrdsNps/X6KGF752wmOfLMnw63MBAPBT0tuf336/AnPJJZeo
qqpKH374oSTpX//6l9555x1NnTpVkrR//365XC6lpaVZ94mJiVFqaqpqamokSTU1NYqNjbXiRZLS0tIUGhqq2tra4z5vV1eXPB6PzwYAAPqnMH8/YGFhoTwej0aOHKkBAwbo6NGjeuSRR5SdnS1JcrlckiS73e5zP7vdbh1zuVyKj4/3HTQsTHFxcdaabysuLtbChQv9/XIAAEAQ8vsVmJdeekmrV6/WmjVrtGvXLq1atUp//vOftWrVKn8/lY+ioiK53W5ra25u7tPnAwAAgeP3KzD33nuvCgsLlZWVJUkaM2aMPv30UxUXFysnJ0cJCQmSpJaWFg0ZMsS6X0tLi8aPHy9JSkhIUGtrq8/jHjlyRG1tbdb9vy0iIkIRERH+fjkAACAI+f0KzNdff63QUN+HHTBggHp6eiRJSUlJSkhIUFVVlXXc4/GotrZWTqdTkuR0OtXe3q76+nprzebNm9XT06PU1FR/jwwAAAzj9ysw06dP1yOPPKJhw4bpvPPO0z//+U898cQTuu222yRJISEhmjNnjhYtWqRzzz1XSUlJmjt3rhwOh2bMmCFJGjVqlKZMmaKZM2dq+fLl6u7uVn5+vrKysnr1CSQAANC/+T1gnn76ac2dO1d//OMf1draKofDoT/84Q+aN2+etea+++5TZ2enZs2apfb2dl166aWqqKhQZGSktWb16tXKz8/XpEmTFBoaqszMTJWUlPh7XAAAYCC/fw9MsOB7YAAAME/AvgcGAACgrxEwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4/RJwHz++ee66aabNHjwYEVFRWnMmDHauXOnddzr9WrevHkaMmSIoqKilJaWpn379vk8Rltbm7Kzs2Wz2RQbG6vc3Fx1dHT0xbgAAMAwfg+YL7/8UhMnTtRpp52mN954Q++//74ef/xxDRo0yFqzdOlSlZSUaPny5aqtrdXAgQOVnp6uQ4cOWWuys7PV2NioyspKbdy4UdXV1Zo1a5a/xwUAAAYK8Xq9Xn8+YGFhod5991394x//OO5xr9crh8Ohu+++W/fcc48kye12y263q7y8XFlZWdq7d6+Sk5NVV1enlJQUSVJFRYWmTZumzz77TA6H4wfn8Hg8iomJkdvtls1m8+dL1PDC10547JMlGX59LgAAfkp6+/Pb71dgXn31VaWkpOjaa69VfHy8zj//fD333HPW8f3798vlciktLc3aFxMTo9TUVNXU1EiSampqFBsba8WLJKWlpSk0NFS1tbX+HhkAABjG7wHz8ccfq6ysTOeee67efPNN3X777brjjju0atUqSZLL5ZIk2e12n/vZ7XbrmMvlUnx8vM/xsLAwxcXFWWu+raurSx6Px2cDAAD9U5i/H7Cnp0cpKSlavHixJOn888/Xnj17tHz5cuXk5Pj76SzFxcVauHBhnz0+AAAIHn6/AjNkyBAlJyf77Bs1apQOHDggSUpISJAktbS0+KxpaWmxjiUkJKi1tdXn+JEjR9TW1mat+baioiK53W5ra25u9uvrAgAAwcPvATNx4kQ1NTX57Pvwww+VmJgoSUpKSlJCQoKqqqqs4x6PR7W1tXI6nZIkp9Op9vZ21dfXW2s2b96snp4epaamHvd5IyIiZLPZfDYAANA/+f1XSHfddZcuueQSLV68WNddd512
7NihFStWaMWKFZKkkJAQzZkzR4sWLdK5556rpKQkzZ07Vw6HQzNmzJD+e8VmypQpmjlzppYvX67u7m7l5+crKyurV59AAgAA/ZvfA+bCCy/UunXrVFRUpIceekhJSUl66qmnlJ2dba2577771NnZqVmzZqm9vV2XXnqpKioqFBkZaa1ZvXq18vPzNWnSJIWGhiozM1MlJSX+HhcAABjI798DEyz4HhgAAMwTsO+BAQAA6GsEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA44QFegD0L8MLXzvhsU+WZJzSWQAA/RdXYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBx+jxglixZopCQEM2ZM8fad+jQIeXl5Wnw4ME644wzlJmZqZaWFp/7HThwQBkZGTr99NMVHx+ve++9V0eOHOnrcQEAgAH6NGDq6ur07LPPauzYsT7777rrLm3YsEEvv/yytm7dqoMHD+qaa66xjh89elQZGRk6fPiwtm3bplWrVqm8vFzz5s3ry3EBAIAh+ixgOjo6lJ2dreeee06DBg2y9rvdbj3//PN64okndPXVV2vChAlauXKltm3bpu3bt0uSNm3apPfff18vvPCCxo8fr6lTp+rhhx9WaWmpDh8+3FcjAwAAQ/RZwOTl5SkjI0NpaWk+++vr69Xd3e2zf+TIkRo2bJhqamokSTU1NRozZozsdru1Jj09XR6PR42Njcd9vq6uLnk8Hp8NAAD0T2F98aBr167Vrl27VFdX951jLpdL4eHhio2N9dlvt9vlcrmsNd+Ml2PHjx07nuLiYi1cuNCPrwIAAAQrv1+BaW5u1p133qnVq1crMjLS3w9/QkVFRXK73dbW3Nx8yp4bAACcWn4PmPr6erW2tuqCCy5QWFiYwsLCtHXrVpWUlCgsLEx2u12HDx9We3u7z/1aWlqUkJAgSUpISPjOp5KO3T625tsiIiJks9l8NgAA0D/5PWAmTZqk3bt3q6GhwdpSUlKUnZ1t/fdpp52mqqoq6z5NTU06cOCAnE6nJMnpdGr37t1qbW211lRWVspmsyk5OdnfIwMAAMP4/T0w0dHRGj16tM++gQMHavDgwdb+3NxcFRQUKC4uTjabTbNnz5bT6dTFF18sSZo8ebKSk5N18803a+nSpXK5XHrwwQeVl5eniIgIf48MAAAM0ydv4v0hTz75pEJDQ5WZmamuri6lp6frmWeesY4PGDBAGzdu1O233y6n06mBAwcqJydHDz30UCDGBQAAQeaUBMyWLVt8bkdGRqq0tFSlpaUnvE9iYqJef/31UzAdAAAwDX8LCQAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYJyzQAwA/dcMLXzvhsU+WZJzSWQDAFFyBAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgA
AGAcAgYAABjH7wFTXFysCy+8UNHR0YqPj9eMGTPU1NTks+bQoUPKy8vT4MGDdcYZZygzM1MtLS0+aw4cOKCMjAydfvrpio+P17333qsjR474e1wAAGAgvwfM1q1blZeXp+3bt6uyslLd3d2aPHmyOjs7rTV33XWXNmzYoJdffllbt27VwYMHdc0111jHjx49qoyMDB0+fFjbtm3TqlWrVF5ernnz5vl7XAAAYCC//ymBiooKn9vl5eWKj49XfX29Lr/8crndbj3//PNas2aNrr76aknSypUrNWrUKG3fvl0XX3yxNm3apPfff19vvfWW7Ha7xo8fr4cfflj333+/FixYoPDwcH+PDQAADNLn74Fxu92SpLi4OElSfX29uru7lZaWZq0ZOXKkhg0bppqaGklSTU2NxowZI7vdbq1JT0+Xx+NRY2PjcZ+nq6tLHo/HZwMAAP1TnwZMT0+P5syZo4kTJ2r06NGSJJfLpfDwcMXGxvqstdvtcrlc1ppvxsux48eOHU9xcbFiYmKsbejQoX30qgAAQKD1acDk5eVpz549Wrt2bV8+jSSpqKhIbrfb2pqbm/v8OQEAQGD4/T0wx+Tn52vjxo2qrq7W2Wefbe1PSEjQ4cOH1d7e7nMVpqWlRQkJCdaaHTt2+DzesU8pHVvzbREREYqIiOijVwMAAIKJ36/AeL1e5efna926ddq8ebOSkpJ8jk+YMEGnnXaaqqqqrH1NTU06cOCAnE6nJMnpdGr37t1qbW211lRWVspmsyk5OdnfIwMAAMP4/QpMXl6e1qxZo1deeUXR0dHWe1ZiYmIUFRWlmJgY5ebmqqCgQHFxcbLZbJo9e7acTqcuvvhiSdLkyZOVnJysm2++WUuXLpXL5dKDDz6ovLw8rrIAAAD/B0xZWZkk6corr/TZv3LlSt1yyy2SpCeffFKhoaHKzMxUV1eX0tPT9cwzz1hrBwwYoI0bN+r222+X0+nUwIEDlZOTo4ceesjf4wIAAAP5PWC8Xu8PromMjFRpaalKS0tPuCYxMVGvv/66n6cDAAD9AX8LCQAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAccICPQAA/BjDC1874bFPlmSc0lkAnHpcgQEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgnLBADwAA8J/hha+d8NgnSzJO6SxAX+IKDAAAME5QB0xpaamGDx+uyMhIpaamaseOHYEeCQAABIGgDZgXX3xRBQUFmj9/vnbt2qVx48YpPT1dra2tgR4NAAAEWNC+B+aJJ57QzJkzdeutt0qSli9frtdee01//etfVVhYGOjxAAD9CO8dMk9QBszhw4dVX1+voqIia19oaKjS0tJUU1Nz3Pt0dXWpq6vLuu12uyVJHo/H7/P1dH19wmN98Xwm4dycPM7Zj8N5Oz7Oy4/DeTu+0fPfPOGxPQvT++Q5j51vr9f7/Qu9Qejzzz/3SvJu27bNZ/+9997rveiii457n/nz53slsbGxsbGxsfWDrbm5+XtbISivwPwYRUVFKigosG739PSora1NgwcPVkhIiN+ex+PxaOjQoWpubpbNZvPb4/ZXnK/e41z1Hueq9zhXvce56r2+PFder1dfffWVHA7H964LyoA588wzNWDAALW0tPjsb2lpUUJCwnHvExERoYiICJ99sbGxfTajzWbjH/hJ4Hz1Hueq9zhXvce56j3OVe/11bmKiYn5wTVB
+Smk8PBwTZgwQVVVVda+np4eVVVVyel0BnQ2AAAQeEF5BUaSCgoKlJOTo5SUFF100UV66qmn1NnZaX0qCQAA/HQFbcBcf/31+s9//qN58+bJ5XJp/PjxqqiokN1uD+hcERERmj9//nd+XYXj43z1Hueq9zhXvce56j3OVe8Fw7kK8f7g55QAAACCS1C+BwYAAOD7EDAAAMA4BAwAADAOAQMAAIxDwJyE6upqTZ8+XQ6HQyEhIVq/fn2gRwpKxcXFuvDCCxUdHa34+HjNmDFDTU1NgR4rKJWVlWns2LHWl0E5nU698cYbgR7LCEuWLFFISIjmzJkT6FGC0oIFCxQSEuKzjRw5MtBjBa3PP/9cN910kwYPHqyoqCiNGTNGO3fuDPRYQWf48OHf+XcVEhKivLy8Uz4LAXMSOjs7NW7cOJWWlgZ6lKC2detW5eXlafv27aqsrFR3d7cmT56szs7OQI8WdM4++2wtWbJE9fX12rlzp66++mr95je/UWNjY6BHC2p1dXV69tlnNXbs2ECPEtTOO+88ffHFF9b2zjvvBHqkoPTll19q4sSJOu200/TGG2/o/fff1+OPP65BgwYFerSgU1dX5/NvqrKyUpJ07bXXnvJZgvZ7YILR1KlTNXXq1ECPEfQqKip8bpeXlys+Pl719fW6/PLLAzZXMJo+fbrP7UceeURlZWXavn27zjvvvIDNFcw6OjqUnZ2t5557TosWLQr0OEEtLCzshH9+Bf/fo48+qqFDh2rlypXWvqSkpIDOFKzOOussn9tLlizROeecoyuuuOKUz8IVGPQ5t9stSYqLiwv0KEHt6NGjWrt2rTo7O/mTGd8jLy9PGRkZSktLC/QoQW/fvn1yOBz6+c9/ruzsbB04cCDQIwWlV199VSkpKbr22msVHx+v888/X88991ygxwp6hw8f1gsvvKDbbrvNr380ube4AoM+1dPTozlz5mjixIkaPXp0oMcJSrt375bT6dShQ4d0xhlnaN26dUpOTg70WEFp7dq12rVrl+rq6gI9StBLTU1VeXm5RowYoS+++EILFy7UZZddpj179ig6OjrQ4wWVjz/+WGVlZSooKNCf/vQn1dXV6Y477lB4eLhycnICPV7QWr9+vdrb23XLLbcE5PkJGPSpvLw87dmzh9+9f48RI0aooaFBbrdbf/vb35STk6OtW7cSMd/S3NysO++8U5WVlYqMjAz0OEHvm7/uHjt2rFJTU5WYmKiXXnpJubm5AZ0t2PT09CglJUWLFy+WJJ1//vnas2ePli9fTsB8j+eff15Tp06Vw+EIyPPzKyT0mfz8fG3cuFFvv/22zj777ECPE7TCw8P1i1/8QhMmTFBxcbHGjRunZcuWBXqsoFNfX6/W1lZdcMEFCgsLU1hYmLZu3aqSkhKFhYXp6NGjgR4xqMXGxuqXv/ylPvroo0CPEnSGDBnynf9hGDVqFL9y+x6ffvqp3nrrLf3+978P2AxcgYHfeb1ezZ49W+vWrdOWLVt4M9xJ6unpUVdXV6DHCDqTJk3S7t27ffbdeuutGjlypO6//34NGDAgYLOZoKOjQ//+97918803B3qUoDNx4sTvfNXDhx9+qMTExIDNFOxWrlyp+Ph4ZWRkBGwGAuYkdHR0+Pzfy/79+9XQ0KC4uDgNGzYsoLMFk7y8PK1Zs0avvPKKoqOj5XK5JEkxMTGKiooK9HhBpaioSFOnTtWwYcP01Vdfac2aNdqyZYvefPPNQI8WdKKjo7/zPqqBAwdq8ODBvL/qOO655x5Nnz5diYmJOnjwoObPn68BAwbohhtuCPRoQeeuu+7SJZdcosWLF+u6667Tjh07tGLFCq1YsSLQowWlnp4erVy5Ujk5OQoLC2BGeNFrb7/9tlfSd7acnJxAjxZUjneOJHlXrlwZ6NGCzm233eZNTEz0hoeHe8866yzvpEmTvJs2bQr0WMa44oorvHfeeWegxwhK119/vXfIkCHe8PBw789+9jPv9ddf7/3o
o48CPVbQ2rBhg3f06NHeiIgI78iRI70rVqwI9EhB68033/RK8jY1NQV0jhDv//uBAwAAYAzexAsAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADDO/wGGkh4mNSHckgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Histogram of the labels (element 1 of the returned tuple) of one 2048-sample batch from mktunebatch\n",
    "plt.hist(mktunebatch(2048)[1].cpu(), bins=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q3Cg_8UQep8g"
   },
   "source": [
    "# Step 2: Define Transformer Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "execution_state": "idle",
   "metadata": {
    "id": "tLOWhg_CeWzH"
   },
   "outputs": [],
   "source": [
    "class TransformerModel(nn.Module):\n",
    "    \"\"\"Transformer encoder followed by a linear head over the flattened sequence.\n",
    "\n",
    "    Each token embedding (model_dim // 2 features) is concatenated with a fixed,\n",
    "    randomly drawn positional code (model_dim // 2 features). Consecutive pairs of\n",
    "    positions share one code; seq_len is expected to be odd, so the final position\n",
    "    ends up with a code of its own.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the token vocabulary.\n",
    "        model_dim: encoder width; half token embedding, half positional code.\n",
    "        output_dim: number of values predicted per sequence.\n",
    "        num_heads: attention heads per encoder layer.\n",
    "        num_layers: number of stacked encoder layers.\n",
    "        seq_len: fixed length of every input sequence.\n",
    "        device: device on which the positional code is created.\n",
    "        dropout: dropout probability used inside the encoder layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, device, dropout):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(input_dim, model_dim//2)\n",
    "        # seq_len is odd\n",
    "        pair_codes = torch.rand((1, seq_len // 2 + 1, model_dim // 2), device=device)\n",
    "        # cut off last element since the target vertex is not repeated\n",
    "        fancy_encoding = torch.repeat_interleave(pair_codes, 2, dim=1)[:, :seq_len, :]\n",
    "        # Registered as a buffer (not a plain attribute) so the random code is saved in\n",
    "        # state_dict() and follows .to(device); otherwise a model rebuilt from saved\n",
    "        # weights would draw a different code. Checkpoints written before this change\n",
    "        # lack the key and need load_state_dict(..., strict=False).\n",
    "        self.register_buffer(\"fancy_encoding\", fancy_encoding)\n",
    "\n",
    "        self.model_dim = model_dim\n",
    "        self.seq_len = seq_len\n",
    "        self.device = device\n",
    "\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,\n",
    "                                                   dim_feedforward=model_dim*4,\n",
    "                                                   dropout=dropout, batch_first=True)\n",
    "        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
    "\n",
    "        # The head sees all positions at once, so inputs must have exactly seq_len tokens\n",
    "        self.fc_out = nn.Linear(model_dim*seq_len, output_dim)\n",
    "\n",
    "    def forward(self, src, key_padding_mask):\n",
    "        \"\"\"Map a batch of token sequences to predictions.\n",
    "\n",
    "        Args:\n",
    "            src: integer token ids of shape (batch, seq_len).\n",
    "            key_padding_mask: mask of shape (batch, seq_len) marking the padded\n",
    "                positions; it is also used to index the encoder output, so it is\n",
    "                presumably boolean with True at padding (TODO confirm against mkbatch).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, output_dim).\n",
    "        \"\"\"\n",
    "        batch_size = src.size(0)\n",
    "        embed = self.embedding(src)\n",
    "        # expand() broadcasts the shared code over the batch without copying it\n",
    "        positions = self.fancy_encoding.expand(batch_size, -1, -1)\n",
    "        src = torch.cat((embed * sqrt(self.model_dim), positions), dim=2)\n",
    "\n",
    "        output = self.transformer_encoder(src, src_key_padding_mask=key_padding_mask)\n",
    "        # Zero the padded positions before they reach the linear head\n",
    "        output[key_padding_mask] = 0 # Hack to stop no_grad problem\n",
    "        flat_output = torch.flatten(output, start_dim=1, end_dim=2)\n",
    "        return self.fc_out(flat_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bpIeg86S-hBb"
   },
   "source": [
    "# Step 3: Make Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "execution_state": "idle",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kWXvJRDYgFVP",
    "outputId": "c13adb9d-6565-43b5-8437-20cef3dc0d16"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training data: 131M\n",
      "Trainable parameters in the model: 605K\n"
     ]
    }
   ],
   "source": [
    "# PARAMS\n",
    "# -- architecture --\n",
    "VOCAB_SIZE = 1 + MAX_VTXS # one more than the max number of vertices\n",
    "MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
    "NHEADS = 4\n",
    "NLAYERS = 12\n",
    "DROPOUT = 0.2\n",
    "# -- training schedule --\n",
    "NEPOCHS = 1000\n",
    "BSZ = 2048 # Batch size\n",
    "BPE = 64 # Batches per epoch\n",
    "# -- optimizer --\n",
    "LR = 0.003\n",
    "WD = 0.002\n",
    "\n",
    "model = TransformerModel(\n",
    "    input_dim=VOCAB_SIZE, model_dim=MODEL_DIM, output_dim=1,\n",
    "    num_heads=NHEADS, num_layers=NLAYERS, seq_len=SEQ_LEN,\n",
    "    device=device, dropout=DROPOUT,\n",
    ").to(device)\n",
    "\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)\n",
    "\n",
    "# Size of the run: samples drawn over all epochs, and number of learnable weights\n",
    "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"Training data: {NEPOCHS*BSZ*BPE//10**6}M\")\n",
    "print(f\"Trainable parameters in the model: {trainable_params//1000}K\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f8Zn33m7CxL5"
   },
   "source": [
    "# Step 4: Train the Model for the first task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "execution_state": "idle",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate():\n",
    "    \"\"\"Return the loss of `model` on one batch of BSZ samples drawn with `mkbatch`.\n",
    "\n",
    "    Side effect: leaves the model in eval mode, so a caller that keeps training\n",
    "    has to call model.train() again.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        # output is (batch, 1); drop the trailing axis to line up with the labels\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "execution_state": "idle",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3ba829714ada43c184a04b0a0b4d06f2",
       "version_major": 2,
       "version_minor": 0
      },
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApBUlEQVR4nO3de1SVdb7H8Q+ge2MpaCFXt5LOqI0XaDARzTFnYZQeGmfVkbSj5LHSdByTVZN2EZtKTLPFjDfKyWpOmZSnPK0kOkQ5HY3GCWVWTV7G8DYVKJmgmCDs3/nj1D6zAwwM9tb9e7/W2n/w43nY34dnbL/n2ReCjDFGAAAAsEawvwcAAACAbxGAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgABwkQgKCtKvfvUrf48BIAAQgAAkSc8995yCgoIUFBSkbdu2Nfm+MUYul0tBQUH6l3/5F6/vnTp1StnZ2Ro8eLAuvfRSXX755UpMTNS8efP0+eefe7ZbvHix5z6au1VUVPjkWFtyrtlmzZrl19kAoD118vcAAC4soaGh2rBhg6655hqv9T/96U/6xz/+IafT6bV+9uxZ/exnP9OePXuUmZmpuXPn6tSpU/rb3/6mDRs26Je//KViY2O99lm7dq26du3a5L67d+/eQUfVeuPGjdO0adOarPfv398v8wBARyAAAXgZP368XnnlFf3+979Xp07//5+IDRs2KCkpSVVVVV7bb968Wbt27dKLL76oKVOmeH3vzJkzqq+vb3IfN998syIiIjrwKM5f//799W//9m/+HgMAOhRPAQPwMnnyZH355ZcqKiryrNXX12vTpk1NAk+SPv30U0nSqFGjmnwvNDRUYWFh7TLX4MGDNXbs2CbrbrdbcXFxuvnmmz1rGzduVFJSkrp166awsDANGTJEv/vd79plDkm69tprNXjwYJWWlmrkyJHq0qWLrrjiCuXl5TXZ9ujRo5oxY4aioqIUGhqqhIQEPf/8880ex+9+9zsNGTJEoaGh6tmzp66//np9+OGHTbbdvHmzBg8eLKfTqUGDBqmwsNDr+ydPntTdd9+t+Ph4OZ1ORUZGaty4cdq5c2e7/Q4AXNwIQABe4uPjlZKSopdeesmz9uabb6q6ulq33HJLk+379OkjSfrjH/8oY0yr7uP48eOqqqryup04ceKc+2RkZOi9995r8jrBbdu26fPPP/fMVlRUpMmTJ6tHjx56/PHHtXTpUl177bXavn17q2Y7c+ZMk9mqqqqaXMn86quvNH78eCUlJWnZsmXq1auX7rrrLq1fv96zzddff61rr71W//Ef/6Fbb71Vy5cvV3h4uG677bYmQTpjxgzdfffdcrlcevzxx7VgwQKFhobqgw8+aHK8s2fP1i233KJly5bpzJkzuummm/Tll196tpk1a5bWrl2rm266SWvWrNE999yjLl26aPfu3a36HQCwgAEAY8yzzz5rJJm//OUvZtWqVaZbt27m9OnTxhhj/vVf/9WMHTvWGGNMnz59zIQJEzz7nT592gwYMMBIMn369DG33XabeeaZZ0xlZWWT+8jOzjaSmr0NGDDgnPPt3bvXSDIrV670Wp89e7bp2rWrZ9Z58+aZsLAw09DQ0ObfQUuzSTIvvfSSZ7sxY8YYSWbFihWetbq6OpOYmGgiIyNNfX29McaY3NxcI8m88MILnu3q6+tNSkqK6dq1q6mpqTHGGPPOO+8YSebXv/51k5ncbrfXfA6Hw+zfv9+z9te//rXJ7yU8PNzMmTOnzccPwB5cAQTQxKRJk/T111/rjTfe0MmTJ/XGG280+/SvJHXp0kV//vOfde+990rfvJt4xowZiomJ0dy5c1VXV9dkn//8z/9UUVGR1+3ZZ58950z9+/dXYmKi8vPzPWuNjY3atGmT0tPT1aVLF+mbN5LU1tZ6PYXdFr/4xS+azFZUVNTk6edO
nTpp5syZnq8dDodmzpypo0ePqrS0VJJUUFCg6OhoTZ482bNd586d9etf/1qnTp3Sn/70J8/vIygoSNnZ2U3mCQoK8vo6NTVV/fr183w9dOhQhYWFqby83LPWvXt3/fnPf/Z6BzYA/DPeBAKgiZ49eyo1NVUbNmzQ6dOn1djY6PUau+8KDw/XsmXLtGzZMh06dEjFxcV64okntGrVKoWHh+vRRx/12v5nP/vZeb0JJCMjQ/fff78+++wzxcXFaevWrTp69KgyMjI828yePVsvv/yybrjhBsXFxem6667TpEmTdP3117fqPnr16qXU1NTv3S42NlaXXnqp19q37xQ+ePCgRowYoUOHDunHP/6xgoO9/7/2lVdeKUk6dOiQ9M3rKGNjY3XZZZd97/327t27yVqPHj301Vdfeb5etmyZMjMz5XK5lJSUpPHjx2vatGnq27fv9/58AHbgCiCAZk2ZMkVvvvmm8vLydMMNN7T6I1r69Omjf//3f9f27dvVvXt3vfjii+02U0ZGhowxeuWVVyRJL7/8ssLDw73iLjIyUmVlZXr99dd144036t1339UNN9ygzMzMdpvDn0JCQppd/+fXX06aNEnl5eVauXKlYmNjtXz5cg0aNEhvvvmmDycFcCEjAAE065e//KWCg4P1wQcftPj077n06NFD/fr10xdffNFuM11xxRUaPny48vPz1dDQoFdffVUTJ05s8tmEDodD6enpWrNmjT799FPNnDlTf/zjH7V///52m+Xzzz9XbW2t19q+ffukb95Io29i+O9//7vcbrfXdnv27PF8X5L69eunzz//XMePH2+3+WJiYjR79mxt3rxZBw4c0OWXX67HHnus3X4+gIsbAQigWV27dtXatWu1ePFipaent7jdX//61yafDahvnt785JNPNGDAgHadKyMjQx988IHWr1+vqqoqr6d/JXm9G1aSgoODNXToUElq9vWI56uhoUFPPfWU5+v6+no99dRT6tmzp5KSkqRvPlOxoqLC63WLDQ0NWrlypbp27aoxY8ZIkm666SYZY/Twww83uZ/WvrP6W42NjaqurvZai4yMVGxsbLseP4CLG68BBNCi1jxtWlRUpOzsbN14440aMWKEunbtqvLycq1fv151dXVavHhxk302bdrU7F8CGTdunKKios55f5MmTdI999yje+65R5dddlmT1+vdfvvtOn78uH7+85+rV69eOnTokFauXKnExETPa+/OZd++fXrhhRearEdFRWncuHGer2NjY/X444/r4MGD6t+/v/Lz81VWVqann35anTt3liTdeeedeuqpp3TbbbeptLRU8fHx2rRpk7Zv367c3Fx169ZNkjR27FhNnTpVv//97/X3v/9d119/vdxut/7nf/5HY8eObdPf/z158qR69eqlm2++WQkJCeratavefvtt/eUvf9GKFSta/XMABDh/vw0ZwIXhnz8G5ly++zEw5eXlZtGiRWbEiBEmMjLSdOrUyfTs2dNMmDDBvPPOO177nutjYCSZd999t1Wzjho1ykgyt99+e5Pvbdq0yVx33XUmMjLSOBwO07t3bzNz5kzzxRdffO/PPddsY8aM8Ww3ZswYM2jQIPPhhx+alJQUExoaavr06WNWrVrV5GdWVlaa6dOnm4iICONwOMyQIUPMs88+22S7hoYGs3z5cjNw4EDjcDhMz549zQ033GBKS0u95mvu41369OljMjMzjfnm42juvfdek5CQYLp162YuvfRSk5CQYNasWfO9xw/AHkGmrc8vAIDlrr32WlVVVenjjz/29ygAcF54DSAAAIBlCEAAAADLEIAAAACW4TWAAAAAluEKIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAA
LEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIBE4Dvvfee0tPTFRsbq6CgIG3evPl799m6dat++tOfyul06kc/+pGee+45n8wKAADgTwETgLW1tUpISNDq1atbtf2BAwc0YcIEjR07VmVlZbr77rt1++2366233urwWQEAAPwpyBhj/D1EewsKCtJrr72miRMntrjNfffdpy1btujjjz/2rN1yyy06ceKECgsLfTQpAACA7wXMFcC2KikpUWpqqtdaWlqaSkpK/DYTAACAL3Ty9wD+UlFRoaioKK+1qKgo1dTU6Ouvv1aXLl2a7FNXV6e6ujrP1263W8ePH9fll1+uoKAgn8wNAAB+GGOMTp48qdjYWAUH23ktzNoAPB85OTl6+OGH/T0GAABoB0eOHFGvXr38PYZfWBuA0dHRqqys9FqrrKxUWFhYs1f/JGnhwoXKysryfF1dXa3evXvryJEjCgsL6/CZAQDAD1dTUyOXy6Vu3br5exS/sTYAU1JSVFBQ4LVWVFSklJSUFvdxOp1yOp1N1sPCwghAAAAuMja/fCtgnvg+deqUysrKVFZWJn3zMS9lZWU6fPiw9M3Vu2nTpnm2nzVrlsrLy/Wb3/xGe/bs0Zo1a/Tyyy9r/vz5fjsGAAAAXwiYAPzwww911VVX6aqrrpIkZWVl6aqrrtKiRYskSV988YUnBiXpiiuu0JYtW1RUVKSEhAStWLFCf/jDH5SWlua3YwAAAPCFgPwcQF+pqalReHi4qqureQoYAICLBI/fAXQFEAAAAK1DAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAywRUAK5evVrx8fEKDQ1VcnKyduzYcc7tc3NzNWDAAHXp0kUul0vz58/XmTNnfDYvAACAPwRMAObn5ysrK0vZ2dnauXOnEhISlJaWpqNHjza7/YYNG7RgwQJlZ2dr9+7deuaZZ5Sfn6/777/f57MDAAD4UsAE4JNPPqk77rhD06dP109+8hPl5eXpkksu0fr165vd/v3339eoUaM0ZcoUxcfH67rrrtPkyZO/96ohAADAxS4gArC+vl6lpaVKTU31rAUHBys1NVUlJSXN7jNy5EiVlpZ6gq+8vFwFBQUaP368z+YGAADwh07+HqA9VFVVqbGxUVFRUV7rUVFR2rNnT7P7TJkyRVVVVbrmmmtkjFFDQ4NmzZp1zqeA6+rqVFdX5/m6pqamHY8CAADANwLiCuD52Lp1q5YsWaI1a9Zo586devXVV7VlyxY98sgjLe6Tk5Oj8PBwz83lcvl0ZgAAgPYQZIwx/h7ih6qvr9cll1yiTZs2aeLEiZ71zMxMnThxQv/1X//VZJ/Ro0drxIgRWr58uWfthRde0J133qlTp04pOLhpGzd3BdDlcqm6ulphYWEdcmwAAKB91dTUKDw83OrH74C4AuhwOJSUlKTi4mLPmtvtVnFxsVJSUprd5/Tp000iLyQkRJLUUhM7nU6FhYV53QAAAC42AfEaQEnKyspSZmamhg0bpuHDhys3N1e1tbWaPn26
JGnatGmKi4tTTk6OJCk9PV1PPvmkrrrqKiUnJ2v//v166KGHlJ6e7glBAACAQBQwAZiRkaFjx45p0aJFqqioUGJiogoLCz1vDDl8+LDXFb8HH3xQQUFBevDBB/XZZ5+pZ8+eSk9P12OPPebHowAAAOh4AfEaQH/hNQQAAFx8ePwOkNcAAgAAoPUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgmYAKwNWrVys+Pl6hoaFKTk7Wjh07zrn9iRMnNGfOHMXExMjpdKp///4qKCjw2bwAAAD+0MnfA7SX/Px8ZWVlKS8vT8nJycrNzVVaWpr27t2ryMjIJtvX19dr3LhxioyM1KZNmxQXF6dDhw6pe/fufpkfAADAV4KMMcbfQ7SH5ORkXX311Vq1apUkye12y+Vyae7cuVqwYEGT7fPy8rR8+XLt2bNHnTt3Pq/7rKmpUXh4uKqrqxUWFvaDjwEAAHQ8Hr8D5Cng+vp6lZaWKjU11bMWHBys1NRUlZSUNLvP66+/rpSUFM2ZM0dRUVEaPHiwlixZosbGxhbvp66uTjU1NV43AACAi01ABGBVVZUaGxsVFRXltR4VFaWKiopm9ykvL9emTZvU2NiogoICPfTQQ1qxYoUeffTRFu8nJydH4eHhnpvL5Wr3YwEAAOhoARGA58PtdisyMlJPP/20kpKSlJGRoQceeEB5eXkt7rNw4UJVV1d7bkeOHPHpzAAAAO0hIN4EEhERoZCQEFVWVnqtV1ZWKjo6utl9YmJi1LlzZ4WEhHjWrrzySlVUVKi+vl4Oh6PJPk6nU06nswOOAAAAwHcC4gqgw+FQUlKSiouLPWtut1vFxcVKSUlpdp9Ro0Zp//79crvdnrV9+/YpJiam2fgDAAAIFAERgJKUlZWldevW6fnnn9fu3bt11113qba2VtOnT5ckTZs2TQsXLvRsf9ddd+n48eOaN2+e9u3bpy1btmjJkiWaM2eOH48CAACg4wXEU8CSlJGRoWPHjmnRokWqqKhQYmKiCgsLPW8MOXz4sIKD/793XS6X3nrrLc2fP19Dhw5VXFyc5s2bp/vuu8+PRwEAANDxAuZzAP2BzxECAODiw+N3AD0FDAAAgNYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAM
AQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQIqAFevXq34+HiFhoYqOTlZO3bsaNV+GzduVFBQkCZOnNjhMwIAAPhbwARgfn6+srKylJ2drZ07dyohIUFpaWk6evToOfc7ePCg7rnnHo0ePdpnswIAAPhTwATgk08+qTvuuEPTp0/XT37yE+Xl5emSSy7R+vXrW9ynsbFRt956qx5++GH17dvXp/MCAAD4S0AEYH19vUpLS5WamupZCw4OVmpqqkpKSlrc77e//a0iIyM1Y8aMVt1PXV2dampqvG4AAAAXm4AIwKqqKjU2NioqKsprPSoqShUVFc3us23bNj3zzDNat25dq+8nJydH4eHhnpvL5frBswMAAPhaQARgW508eVJTp07VunXrFBER0er9Fi5cqOrqas/tyJEjHTonAABAR+jk7wHaQ0REhEJCQlRZWem1XllZqejo6Cbbf/rppzp48KDS09M9a263W5LUqVMn7d27V/369Wuyn9PplNPp7JBjAAAA8JWAuALocDiUlJSk4uJiz5rb7VZxcbFSUlKabD9w4EB99NFHKisr89xuvPFGjR07VmVlZTy1CwAAAlpAXAGUpKysLGVmZmrYsGEaPny4cnNzVVtbq+nTp0uSpk2bpri4OOXk5Cg0NFSDBw/22r979+6S1GQdAAAg0ARMAGZkZOjYsWNatGiRKioqlJiYqMLCQs8bQw4fPqzg4IC44AkAAPCDBBljjL+HuFjV1NQoPDxc1dXVCgsL8/c4AACgFXj8DpDXAAIAAKD1CEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYJmACsDVq1crPj5eoaGhSk5O1o4dO1rcdt26dRo9erR69OihHj16KDU19ZzbAwAABIqACcD8/HxlZWUpOztbO3fuVEJCgtLS0nT06NFmt9+6dasmT56sd999VyUlJXK5XLruuuv02Wef+Xx2AAAAXwoyxhh/D9EekpOTdfXVV2vVqlWSJLfbLZfLpblz52rBggXfu39jY6N69OihVatWadq0aa26z5qaGoWHh6u6ulphYWE/+BgAAEDH4/E7QK4A1tfXq7S0VKmpqZ614OBgpaamqqSkpFU/4/Tp0zp79qwuu+yyFrepq6tTTU2N1w0AAOBiExABWFVVpcbGRkVFRXmtR0VFqaKiolU/47777lNsbKxXRH5XTk6OwsPDPTeXy/WDZwcAAPC1gAjAH2rp0qXauHGjXnvtNYWGhra43cKFC1VdXe25HTlyxKdzAgAAtIdO/h6gPURERCgkJESVlZVe65WVlYqOjj7nvk888YSWLl2qt99+W0OHDj3ntk6nU06ns11mBgAA8JeAuALocDiUlJSk4uJiz5rb7VZxcbFSUlJa3G/ZsmV65JFHVFhYqGHDhvloWgAAAP8KiCuAkpSVlaXMzEwNGzZMw4cPV25urmprazV9+nRJ0rRp0xQXF6ecnBxJ0uOPP65F
ixZpw4YNio+P97xWsGvXruratatfjwUAAKAjBUwAZmRk6NixY1q0aJEqKiqUmJiowsJCzxtDDh8+rODg/7/guXbtWtXX1+vmm2/2+jnZ2dlavHixz+cHAADwlYD5HEB/4HOEAAC4+PD4HSCvAQQAAEDrEYAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIBFYCrV69WfHy8QkNDlZycrB07dpxz+1deeUUDBw5UaGiohgwZooKCAp/NCgAA4C8BE4D5+fnKyspSdna2du7cqYSEBKWlpeno0aPNbv/+++9r8uTJmjFjhnbt2qWJEydq4sSJ+vjjj30+OwAAgC8FGWOMv4doD8nJybr66qu1atUqSZLb7ZbL5dLcuXO1YMGCJttnZGSotrZWb7zxhmdtxIgRSkxMVF5eXqvus6amRuHh4aqurlZYWFg7Hg0AAOgoPH5Lnfw9QHuor69XaWmpFi5c6FkLDg5WamqqSkpKmt2npKREWVlZXmtpaWnavHlzi/dTV1enuro6z9fV1dXSN/9DAgAAF4dvH7cD5BrYeQmIAKyqqlJjY6OioqK81qOiorRnz55m96moqGh2+4qKihbvJycnRw8//HCTdZfLdd6zAwAA//jyyy8VHh7u7zH8IiAC0FcWLlzoddXwxIkT6tOnjw4fPmzt/4AuFDU1NXK5XDpy5Ii1l/MvFJyLCwvn48LBubhwVFdXq3fv3rrsssv8PYrfBEQARkREKCQkRJWVlV7rlZWVio6Obnaf6OjoNm0vSU6nU06ns8l6eHg4/5gvEGFhYZyLCwTn4sLC+bhwcC4uHMHBAfNe2DYLiCN3OBxKSkpScXGxZ83tdqu4uFgpKSnN7pOSkuK1vSQVFRW1uD0AAECgCIgrgJKUlZWlzMxMDRs2TMOHD1dubq5qa2s1ffp0SdK0adMUFxennJwcSdK8efM0ZswYrVixQhMmTNDGjRv14Ycf6umnn/bzkQAAAHSsgAnAjIwMHTt2TIsWLVJFRYUSExNVWFjoeaPH4cOHvS71jhw5Uhs2bNCDDz6o+++/Xz/+8Y+1efNmDR48uNX36XQ6lZ2d3ezTwvAtzsWFg3NxYeF8XDg4FxcOzkUAfQ4gAAAAWicgXgMIAACA1iMAAQAALEMAAgAAWIYABAAAsAwB+D1Wr16t+Ph4hYaGKjk5WTt27Djn9q+88ooGDhyo0NBQDRkyRAUFBT6bNdC15VysW7dOo0ePVo8ePdSjRw+lpqZ+77lD67X138W3Nm7cqKCgIE2cOLHDZ7RFW8/FiRMnNGfOHMXExMjpdKp///78d6odtfV85ObmasCAAerSpYtcLpfmz5+vM2fO+GzeQPXee+8pPT1dsbGxCgoK0ubNm793n61bt+qnP/2pnE6nfvSjH+m5557zyax+Y9CijRs3GofDYdavX2/+9re/mTvuuMN0797dVFZWNrv99u3bTUhIiFm2bJn55JNPzIMPPmg6d+5sPvroI5/PHmjaei6mTJliVq9ebXbt2mV2795tbrvtNhMeHm7+
8Y9/+Hz2QNPWc/GtAwcOmLi4ODN69Gjzi1/8wmfzBrK2nou6ujozbNgwM378eLNt2zZz4MABs3XrVlNWVubz2QNRW8/Hiy++aJxOp3nxxRfNgQMHzFtvvWViYmLM/PnzfT57oCkoKDAPPPCAefXVV40k89prr51z+/LycnPJJZeYrKws88knn5iVK1eakJAQU1hY6LOZfY0APIfhw4ebOXPmeL5ubGw0sbGxJicnp9ntJ02aZCZMmOC1lpycbGbOnNnhswa6tp6L72poaDDdunUzzz//fAdOaYfzORcNDQ1m5MiR5g9/+IPJzMwkANtJW8/F2rVrTd++fU19fb0Pp7RHW8/HnDlzzM9//nOvtaysLDNq1KgOn9UmrQnA3/zmN2bQoEFeaxkZGSYtLa2Dp/MfngJuQX19vUpLS5WamupZCw4OVmpqqkpKSprdp6SkxGt7SUpLS2txe7TO+ZyL7zp9+rTOnj1r9R/+bg/ney5++9vfKjIyUjNmzPDRpIHvfM7F66+/rpSUFM2ZM0dRUVEaPHiwlixZosbGRh9OHpjO53yMHDlSpaWlnqeJy8vLVVBQoPHjx/tsbvwfGx+/A+YvgbS3qqoqNTY2ev6SyLeioqK0Z8+eZvepqKhodvuKiooOnTXQnc+5+K777rtPsbGxTf6Bo23O51xs27ZNzzzzjMrKynw0pR3O51yUl5frnXfe0a233qqCggLt379fs2fP1tmzZ5Wdne2jyQPT+ZyPKVOmqKqqStdcc42MMWpoaNCsWbN0//33+2hqfKulx++amhp9/fXX6tKli99m6yhcAUTAW7p0qTZu3KjXXntNoaGh/h7HKidPntTUqVO1bt06RURE+Hsc67ndbkVGRurpp59WUlKSMjIy9MADDygvL8/fo1lp69atWrJkidasWaOdO3fq1Vdf1ZYtW/TII4/4ezRYgCuALYiIiFBISIgqKyu91isrKxUdHd3sPtHR0W3aHq1zPufiW0888YSWLl2qt99+W0OHDu3gSQNfW8/Fp59+qoMHDyo9Pd2z5na7JUmdOnXS3r171a9fPx9MHnjO599FTEyMOnfurJCQEM/alVdeqYqKCtXX18vhcHT43IHqfM7HQw89pKlTp+r222+XJA0ZMkS1tbW688479cADD3j9/Xp0rJYev8PCwgLy6p+4Atgyh8OhpKQkFRcXe9bcbreKi4uVkpLS7D4pKSle20tSUVFRi9ujdc7nXEjSsmXL9Mgjj6iwsFDDhg3z0bSBra3nYuDAgfroo49UVlbmud14440aO3asysrK5HK5fHwEgeN8/l2MGjVK+/fv90S4JO3bt08xMTHE3w90Pufj9OnTTSLv2zj/v/cuwFesfPz297tQLmQbN240TqfTPPfcc+aTTz4xd955p+nevbupqKgwxhgzdepUs2DBAs/227dvN506dTJPPPGE2b17t8nOzuZjYNpJW8/F0qVLjcPhMJs2bTJffPGF53by5Ek/HkVgaOu5+C7eBdx+2nouDh8+bLp162Z+9atfmb1795o33njDREZGmkcffdSPRxE42no+srOzTbdu3cxLL71kysvLzX//93+bfv36mUmTJvnxKALDyZMnza5du8yuXbuMJPPkk0+aXbt2mUOHDhljjFmwYIGZOnWqZ/tvPwbm3nvvNbt37zarV6/mY2Bst3LlStO7d2/jcDjM8OHDzQcffOD53pgxY0xmZqbX9i+//LLp37+/cTgcZtCgQWbLli1+mDowteVc9OnTx0hqcsvOzvbT9IGlrf8u/hkB2L7aei7ef/99k5ycbJxOp+nbt6957LHHTENDgx8mD0xtOR9nz541ixcvNv369TOhoaHG5XKZ2bNnm6+++spP0weOd999t9nHgG9//5mZmWbMmDFN9klMTDQOh8P07dvXPPvss36a3jeCDNeZAQAArMJrAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACW
IQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADL/C88oWm6KoRVvwAAAABJRU5ErkJggg==",
      "text/html": [
       "\n",
       "            <div style=\"display: inline-block;\">\n",
       "                <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
       "                    Figure\n",
       "                </div>\n",
       "                <img src='' width=640.0/>\n",
       "            </div>\n",
       "        "
      ],
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# This has to be in a separate cell for some weird event loop reasons\n",
    "# Use the interactive widget backend so the training loop can redraw this figure in place\n",
    "%matplotlib widget\n",
    "fig,ax = plt.subplots()\n",
    "fig.suptitle('MSE vs Epochs')\n",
    "plt.show()\n",
    "# Back to static inline rendering for the plots in later cells\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "execution_state": "idle",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 486
    },
    "id": "pvTfzGmCeXU4",
    "outputId": "0d3a20f3-23be-4c19-9eb6-46bfe11a48b1"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                    | 0/64 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'mkbatch' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[12], line 8\u001b[0m\n\u001b[1;32m      6\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(BPE)):\n\u001b[0;32m----> 8\u001b[0m     batch_src, batch_labels, batch_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[43mmkbatch\u001b[49m(BSZ)\n\u001b[1;32m      9\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m     10\u001b[0m     output \u001b[38;5;241m=\u001b[39m model(batch_src, batch_padding_mask)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'mkbatch' is not defined"
     ]
    }
   ],
   "source": [
    "# Per-epoch loss history, used for the live plot below\n",
    "train_err = []\n",
    "test_err = []\n",
    "\n",
    "for epoch in range(NEPOCHS):\n",
    "    model.train()  # evaluate() at the end of the previous epoch switched to eval mode\n",
    "    train_loss = 0\n",
    "    for _ in tqdm(range(BPE)):\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "        # Accumulate the mean loss over this epoch's BPE batches\n",
    "        train_loss += loss.item() / BPE\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    test_loss = evaluate()\n",
    "\n",
    "    test_err.append(test_loss)\n",
    "    train_err.append(train_loss)\n",
    "    # Opened in append mode: the log keeps growing across runs and kernel restarts\n",
    "    with open(\"loss\", \"a\") as f:\n",
    "        f.write(f\"{train_loss} {test_loss}\\n\")\n",
    "    print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
    "    # Redraw from scratch; without clear() every epoch stacks two more lines on the axes\n",
    "    ax.clear()\n",
    "    ax.plot(train_err, label='Train', color='blue')\n",
    "    ax.plot(test_err, label='Test', color='red')\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel('MSE')\n",
    "    ax.legend()\n",
    "    fig.canvas.draw()\n",
    "\n",
    "    # Checkpoint every 100 epochs\n",
    "    if epoch % 100 == 99:\n",
    "        torch.save(model.state_dict(), f\"model_weights_{epoch}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "execution_state": "idle",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "177.23538208007812"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss of the current model on one batch (see evaluate above)\n",
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {
    "id": "LoGEmM5lH7_A"
   },
   "outputs": [],
   "source": [
    "# Inspect the model on one batch drawn with mkbatch, the same source the training loop uses\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "    output = model(batch_src, batch_padding_mask)\n",
    "# First sample: input tokens, true label, prediction\n",
    "batch_src[0], batch_labels[0], output[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the model's predictions stored in `output`\n",
    "plt.hist(output.detach().cpu().numpy().flatten(),bins=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the true labels of the same batch, for comparison with the predictions\n",
    "plt.hist(batch_labels.detach().cpu().numpy().flatten(),bins=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True label (x) against prediction (y) for every sample in the batch\n",
    "plt.scatter(batch_labels.detach().cpu().numpy().flatten(),output.detach().cpu().numpy().flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same inspection on a second batch from mkbatch, this time also reporting the loss\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    batch_src2, batch_labels2, batch_padding_mask2 = mkbatch(BSZ)\n",
    "    output2 = model(batch_src2, batch_padding_mask2)\n",
    "    loss = criterion(output2.squeeze(1), batch_labels2)\n",
    "# First sample: input tokens, true label, prediction; then the batch loss\n",
    "batch_src2[0], batch_labels2[0], output2[0], loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True label (x) against prediction (y) for the second batch\n",
    "plt.scatter(batch_labels2.detach().cpu().numpy().flatten(),output2.detach().cpu().numpy().flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LC6Xv3YfC0Rm"
   },
   "source": [
    "# Step 5: Fine Tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuning hyperparameters\n",
    "N_TUNE_EPOCHS = 100\n",
    "TUNE_LR = 0.003\n",
    "TUNE_WD = 0.002\n",
    "\n",
    "# Dedicated loss and a new Adam instance for the fine-tuning phase; it is built over\n",
    "# model.parameters(), i.e. the same model object used in the earlier cells\n",
    "tune_criterion = nn.MSELoss()\n",
    "tune_optimizer = torch.optim.Adam(model.parameters(), lr=TUNE_LR, weight_decay=TUNE_WD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tuneevaluate():\n",
    "    \"\"\"Return the fine-tuning loss of `model` on one batch of BSZ samples from `mktunebatch`.\n",
    "\n",
    "    Side effect: leaves the model in eval mode, so a caller that keeps training\n",
    "    has to call model.train() again.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mktunebatch(BSZ)\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        # Score with the fine-tuning criterion, not the pre-training one\n",
    "        loss = tune_criterion(output.squeeze(1), batch_labels)\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "execution_state": "running",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This has to be in a separate cell for some weird event loop reasons\n",
    "# Use the interactive widget backend so the fine-tuning loop can redraw this figure in place\n",
    "%matplotlib widget\n",
    "# Rebinds fig and ax, so later drawing goes to this new figure\n",
    "fig,ax = plt.subplots()\n",
    "fig.suptitle('MSE vs Epochs')\n",
    "plt.show()\n",
    "# Back to static inline rendering for the plots in later cells\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "execution_state": "idle",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'N_TUNE_EPOCHS' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 4\u001b[0m\n\u001b[1;32m      1\u001b[0m tune_train_err \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      2\u001b[0m tune_test_err \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[43mN_TUNE_EPOCHS\u001b[49m):\n\u001b[1;32m      5\u001b[0m     model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m      6\u001b[0m     train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'N_TUNE_EPOCHS' is not defined"
     ]
    }
   ],
   "source": [
    "# Per-epoch loss history of the fine-tuning phase, used for the live plot below\n",
    "tune_train_err = []\n",
    "tune_test_err = []\n",
    "\n",
    "for epoch in range(N_TUNE_EPOCHS):\n",
    "    model.train()  # tuneevaluate() at the end of the previous epoch switched to eval mode\n",
    "    train_loss = 0\n",
    "    for _ in tqdm(range(BPE)):\n",
    "        batch_src, batch_labels, batch_padding_mask = mktunebatch(BSZ)\n",
    "        # Step the fine-tuning optimizer (TUNE_LR / TUNE_WD), not the pre-training one\n",
    "        tune_optimizer.zero_grad()\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = tune_criterion(output.squeeze(1), batch_labels)\n",
    "        # Accumulate the mean loss over this epoch's BPE batches\n",
    "        train_loss += loss.item()/BPE\n",
    "        loss.backward()\n",
    "        tune_optimizer.step()\n",
    "\n",
    "    test_loss = tuneevaluate()\n",
    "\n",
    "    tune_test_err.append(test_loss)\n",
    "    tune_train_err.append(train_loss)\n",
    "    # Opened in append mode: the log keeps growing across runs and kernel restarts\n",
    "    with open(\"tune_loss\", \"a\") as f:\n",
    "        f.write(f\"{train_loss} {test_loss}\\n\")\n",
    "    # Redraw from scratch; without clear() every epoch stacks two more lines on the axes\n",
    "    ax.clear()\n",
    "    ax.plot(tune_train_err, label='Train', color='blue')\n",
    "    ax.plot(tune_test_err, label='Test', color='red')\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel('MSE')\n",
    "    ax.legend()\n",
    "    fig.canvas.draw()\n",
    "    # The total shown is the number of fine-tuning epochs, not NEPOCHS\n",
    "    print(f\"Epoch {epoch + 1}/{N_TUNE_EPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
    "\n",
    "    # Checkpoint every 10 epochs\n",
    "    if epoch % 10 == 9:\n",
    "        torch.save(model.state_dict(), f\"tune_model_weights_{epoch}.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtTLXn4zC1z_"
   },
   "source": [
    "# Step 6: Test generalization"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}