aboutsummaryrefslogtreecommitdiff
path: root/transformer_shortest_paths.ipynb
blob: 0adb5dbfdf3660c59a2208773be257c3d33285cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LPphBnKR-aWF"
   },
   "source": [
    "# Step 0: Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ge5QvElvhCOw",
    "outputId": "c7cdaefa-d6dc-44ad-c258-e4fb2aca97a5"
   },
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "# using tqdm.auto glitches out collaborative editing\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from math import sqrt\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import matplotlib.pyplot as plt\n",
    "# %matplotlib widget\n",
    "torch.manual_seed(42)\n",
    "\n",
    "import os\n",
    "import random\n",
    "random.seed(42)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "assert device.type == 'cuda', \"CUDA is not available. Please check your GPU setup.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "lylOX2POPwFL"
   },
   "outputs": [],
   "source": [
    "SEQ_LEN = 65 # means 32 edges, final token is the target vertex\n",
    "PAD_TOKEN = 0 # token id used to right-pad edge lists\n",
    "AVG_DEG = 2\n",
    "MAX_VTXS = SEQ_LEN//AVG_DEG + 1 # 33; exclusive upper bound on n, so n <= 32\n",
    "MIN_VTXS = 8 # inclusive lower bound on n\n",
    "MAX_TUNE_VTXS = 16 # exclusive upper bound on n for the tuning batches\n",
    "# random_graph(n) labels vertices 1,2,...,n-1 (so at most 31 with the settings above)\n",
    "# we also have a padding token which is 0."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gKt-yIpDebF1"
   },
   "source": [
    "# Step 1: Generate synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1IbzGIWseK3E",
    "outputId": "a3cbc233-358c-4e17-ea6e-f4e9349d886b"
   },
   "outputs": [],
   "source": [
    "# original task data\n",
    "# the data will be edge lists\n",
    "# like this: [1 3 1 5 2 4 0 0 0 0 2]\n",
    "# this represents edges (1,3), (1,5) (2,4)\n",
    "# (the zeros are just padding tokens)\n",
    "# the final 2 means which vertex we're going to \n",
    "\n",
    "# the label is the shortest distance from vtx 1 to the target vtx (vtx 2 by default),\n",
    "# or n (the size argument passed to random_graph / SSSP) if no path exists\n",
    "\n",
    "def random_graph(n):\n",
    "    \"\"\"Sample a random undirected graph as a padded edge list plus adjacency sets.\n",
    "\n",
    "    Endpoints are drawn uniformly from 1..n-1 and paired up consecutively;\n",
    "    pairs with u == v (self-loops) are dropped. A repeated edge appears again\n",
    "    in edge_list but only once in the adjacency sets.\n",
    "\n",
    "    Returns (edge_list, adjacencies):\n",
    "      edge_list   -- flat [u1, v1, u2, v2, ...], padded with PAD_TOKEN up to SEQ_LEN tokens\n",
    "      adjacencies -- list of n+1 sets; adjacencies[v] holds the neighbours of v\n",
    "    \"\"\"\n",
    "    edge_list = []\n",
    "    adjacencies = [set() for _ in range(n+1)]\n",
    "    indices = [random.randint(1, n-1) for _ in range(AVG_DEG * (n-1))]\n",
    "    # pair consecutive draws; zip silently drops a trailing unpaired draw, so an\n",
    "    # odd AVG_DEG * (n-1) no longer raises IndexError on indices[i + 1]\n",
    "    for u, v in zip(indices[::2], indices[1::2]):\n",
    "        if u != v:\n",
    "            edge_list += [u, v]\n",
    "            adjacencies[u].add(v)\n",
    "            adjacencies[v].add(u)\n",
    "\n",
    "    # no padding is added if the edge tokens already fill SEQ_LEN or more\n",
    "    edge_list += [PAD_TOKEN]*(SEQ_LEN-len(edge_list))\n",
    "    return edge_list, adjacencies\n",
    "\n",
    "def SSSP(n, G, target=2):\n",
    "    \"\"\"Breadth-first search from vertex 1 over an unweighted graph.\n",
    "\n",
    "    n      -- value stored for vertices that have not been reached; it doubles as\n",
    "              the 'not visited' marker, so it should exceed every real distance\n",
    "    G      -- adjacency list, G[v] is an iterable of the neighbours of v\n",
    "    target -- vertex whose distance is wanted, or None for all distances\n",
    "\n",
    "    Returns dist[target] (which is n when target is never reached) if target is\n",
    "    not None; otherwise the whole list dist, where dist[v] = d(1, v) and\n",
    "    unreached entries hold n.\n",
    "    \"\"\"\n",
    "    dist = [n for _ in G]\n",
    "    dist[1] = 0\n",
    "    frontier = deque([1])\n",
    "    while frontier:\n",
    "        vtx = frontier.popleft()\n",
    "        for x in G[vtx]:\n",
    "            if dist[x] == n:\n",
    "                dist[x] = 1 + dist[vtx]\n",
    "                frontier.append(x)\n",
    "                # BFS settles a vertex the first time it is seen, so stop early\n",
    "                if x == target:\n",
    "                    return dist[target]\n",
    "    if target is not None:\n",
    "        return dist[target]\n",
    "    return dist\n",
    "\n",
    "def mkbatch(size):\n",
    "    \"\"\"Generate `size` random graphs whose target vertex is 2.\n",
    "\n",
    "    Returns (data, labels, padding): data holds one token row per graph with the\n",
    "    last token set to the target vertex, labels holds the SSSP result for each\n",
    "    graph as bfloat16, and padding is the boolean mask data == PAD_TOKEN.\n",
    "    \"\"\"\n",
    "    token_rows = []\n",
    "    answers = []\n",
    "\n",
    "    for _ in range(size):\n",
    "        n = random.randrange(MIN_VTXS, MAX_VTXS)\n",
    "        tokens, adjacency = random_graph(n)\n",
    "        answers.append(SSSP(n, adjacency))\n",
    "        tokens[-1] = 2 # target token\n",
    "        token_rows.append(tokens)\n",
    "\n",
    "    data = torch.tensor(token_rows, device=device)\n",
    "    labels = torch.tensor(answers, dtype=torch.bfloat16, device=device)\n",
    "    return data, labels, data == PAD_TOKEN\n",
    "\n",
    "def vertices_on_shortest_12_path(n, G, target=2):\n",
    "    \"\"\"Return the vertices of one shortest path from vertex 1 to target.\n",
    "\n",
    "    Runs a BFS from vertex 1 that records parent pointers; n marks vertices not\n",
    "    yet visited. The result lists the path in order, starting at 1 and ending at\n",
    "    target. An empty list is returned when the search never discovers target\n",
    "    (this includes target == 1, since the source is never re-discovered).\n",
    "    \"\"\"\n",
    "    dist = [n] * len(G)\n",
    "    parent = [-1] * len(G)\n",
    "    dist[1] = 0\n",
    "    frontier = deque([1])\n",
    "    while frontier:\n",
    "        vtx = frontier.popleft()\n",
    "        for nbr in G[vtx]:\n",
    "            if dist[nbr] != n:\n",
    "                continue\n",
    "            parent[nbr] = vtx\n",
    "            dist[nbr] = dist[vtx] + 1\n",
    "            frontier.append(nbr)\n",
    "            if nbr == target:\n",
    "                # follow parent pointers back to the source, then flip the order\n",
    "                path = [nbr]\n",
    "                cur = nbr\n",
    "                while parent[cur] != -1:\n",
    "                    cur = parent[cur]\n",
    "                    path.append(cur)\n",
    "                path.reverse()\n",
    "                return path\n",
    "    return []\n",
    "\n",
    "def mktunebatch(size):\n",
    "    \"\"\"Generate `size` graphs whose target is a random vertex on a shortest 1-2 path.\n",
    "\n",
    "    For each sample, graphs are redrawn until vertex 2 is reachable from vertex 1.\n",
    "    A position >= 1 on the returned path is then picked at random; the vertex at\n",
    "    that position becomes the target token and the position itself is the label.\n",
    "\n",
    "    Returns (data, labels, padding): token rows, bfloat16 labels, and the boolean\n",
    "    mask data == PAD_TOKEN.\n",
    "    \"\"\"\n",
    "    token_rows = []\n",
    "    path_positions = []\n",
    "\n",
    "    for _ in range(size):\n",
    "        n = random.randrange(MIN_VTXS, MAX_TUNE_VTXS)\n",
    "        # rejection-sample: keep drawing graphs until a 1 -> 2 path exists\n",
    "        path = []\n",
    "        while len(path) <= 1:\n",
    "            tokens, adjacency = random_graph(n)\n",
    "            path = vertices_on_shortest_12_path(n, adjacency)\n",
    "        hop = random.randrange(1, len(path))\n",
    "        tokens[-1] = path[hop]\n",
    "        token_rows.append(tokens)\n",
    "        path_positions.append(hop)\n",
    "\n",
    "    data = torch.tensor(token_rows, device=device)\n",
    "    labels = torch.tensor(path_positions, dtype=torch.bfloat16, device=device)\n",
    "    return data, labels, data == PAD_TOKEN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([252.,   0., 309.,   0., 278.,   0., 199.,   0., 119.,   0.,  63.,\n",
       "          0.,  23.,   0.,  45.,   0.,  36.,   0.,  36.,   0.,  28.,   0.,\n",
       "         37.,   0.,  23.,   0.,  33.,   0.,  24.,   0.,  34.,   0.,   0.,\n",
       "         30.,   0.,  33.,   0.,  34.,   0.,  34.,   0.,  25.,   0.,  33.,\n",
       "          0.,  39.,   0.,  33.,   0.,  20.,   0.,  29.,   0.,  27.,   0.,\n",
       "         35.,   0.,  37.,   0.,  30.,   0.,  33.,   0.,  37.]),\n",
       " array([ 1.      ,  1.484375,  1.96875 ,  2.453125,  2.9375  ,  3.421875,\n",
       "         3.90625 ,  4.390625,  4.875   ,  5.359375,  5.84375 ,  6.328125,\n",
       "         6.8125  ,  7.296875,  7.78125 ,  8.265625,  8.75    ,  9.234375,\n",
       "         9.71875 , 10.203125, 10.6875  , 11.171875, 11.65625 , 12.140625,\n",
       "        12.625   , 13.109375, 13.59375 , 14.078125, 14.5625  , 15.046875,\n",
       "        15.53125 , 16.015625, 16.5     , 16.984375, 17.46875 , 17.953125,\n",
       "        18.4375  , 18.921875, 19.40625 , 19.890625, 20.375   , 20.859375,\n",
       "        21.34375 , 21.828125, 22.3125  , 22.796875, 23.28125 , 23.765625,\n",
       "        24.25    , 24.734375, 25.21875 , 25.703125, 26.1875  , 26.671875,\n",
       "        27.15625 , 27.640625, 28.125   , 28.609375, 29.09375 , 29.578125,\n",
       "        30.0625  , 30.546875, 31.03125 , 31.515625, 32.      ]),\n",
       " <BarContainer object of 64 artists>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiy0lEQVR4nO3df1BVdeL/8RegXH9xL6HChRXxV4nkj1pSvFO5rrAisq6uNJPllrWOTu6lWaVMaUzLdhbXmn4O6exsqzUjWe6kjlSWYuK2oiXF+KsYdWixhQutjlzFQITz/WPH+/3c/EEoct9cno+ZM8M9533vfZ+zp+XpuT8IsSzLEgAAgEFCAz0BAACAHyNQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABinW6AncD1aWlpUVVWliIgIhYSEBHo6AADgJ7AsS2fPnlVcXJxCQ699jaRTBkpVVZXi4+MDPQ0AAHAdTp48qQEDBlxzTKcMlIiICEn/20G73R7g2QAAgJ/C6/UqPj7e93v8WjploFx6WcdutxMoAAB0Mj/l7Rm8SRYAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMbpFugJBJNBSz+46rZvV2V24EwAAOjcuIICAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzTpkBZs2aNRo8eLbvdLrvdLpfLpY8++si3vaGhQW63W3379lWfPn2UlZWlmpoav8eorKxUZmamevXqpejoaC1evFgXL15sn70BAABBoU2BMmDAAK1atUqlpaU6cOCAJk2apOnTp+vIkSOSpEWLFmnbtm3atGmTiouLVVVVpZkzZ/ru39zcrMzMTF24cEF79+7VW2+9pfXr12v58uXtu1cAAKBTC7Esy7qRB4iKitILL7yg++67T/3791dBQYHuu+8+SdI333yjESNGqKSkROPHj9dHH32kX//616qqqlJMTIwkae3atVqyZIm+//57hYeH/6Tn9Hq9cjgcqqurk91uv5Hptyu+BwUAgKtry+/v634PSnNzszZu3Kj6+nq5XC6VlpaqqalJaWlpvjGJiYkaOHCgSkpKJEklJSUaNWqUL04kKT09XV6v13cVBgAAoM3fJHvo0CG5XC41NDSoT58+2rx5s5KSklRWVqbw8HBFRkb6jY+JiZHH45EkeTwevzi5tP3StqtpbGxUY2Oj77bX623rtAEAQCfS5isow4cPV1lZmfbv368FCxZozpw5Onr06M2Ym09eXp4cDodviY+Pv6nPBwAAAqvNgRIeHq5hw4YpOTlZeXl5GjNmjF599VU5nU5duHBBZ86c8RtfU1Mjp9MpSXI6nZd9qufS7UtjriQ3N1d1dXW+5eTJk22dNgAA6ERu+HtQWlpa1NjYqOTkZHXv3l1FRUW+beXl5aqsrJTL5ZIkuVwuHTp0SLW1tb4xO3bskN1uV1JS0lWfw2az+T7afGkBAADBq03vQcnNzVVGRoYGDhyos2fPqqCgQLt379bHH38sh8OhuXPnKicnR1FRUbLb7Xr88cflcrk0fvx4SdLkyZOVlJSkhx56SKtXr5bH49GyZcvkdrtls9luyg4CAIDOp02BUltbq4cffljV1dVyOBwaPXq0Pv74Y/3qV7+SJL388ssKDQ1VVlaWGhsblZ6erjfeeMN3/7CwMBUWFmrBggVyuVzq3bu35syZo5UrV7bvXgEAgE7thr8HJRD4HhQAADqfDvkeFAAAgJuFQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcboFegKQBi394Krbvl2V2YEzAQDADFxB
AQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcdoUKHl5eRo7dqwiIiIUHR2tGTNmqLy83G/MxIkTFRIS4rc89thjfmMqKyuVmZmpXr16KTo6WosXL9bFixdvfG8AAEBQ6NaWwcXFxXK73Ro7dqwuXryop59+WpMnT9bRo0fVu3dv37h58+Zp5cqVvtu9evXy/dzc3KzMzEw5nU7t3btX1dXVevjhh9W9e3f9+c9/boddAgAAnV2bAmX79u1+t9evX6/o6GiVlpZqwoQJvvW9evWS0+m84mN88sknOnr0qHbu3KmYmBjdcccdev7557VkyRI9++yzCg8Pv47daF+Dln5w1W3frsrswJkAANA13dB7UOrq6iRJUVFRfus3bNigfv36aeTIkcrNzdX58+d920pKSjRq1CjFxMT41qWnp8vr9erIkSNXfJ7GxkZ5vV6/BQAABK82XUH5v1paWrRw4ULdfffdGjlypG/9gw8+qISEBMXFxengwYNasmSJysvL9f7770uSPB6PX5xI8t32eDxXfK68vDw999xz1ztVAADQyVx3oLjdbh0+fFifffaZ3/r58+f7fh41apRiY2OVmpqqEydOaOjQodf1XLm5ucrJyfHd9nq9io+Pv76JAwAA413XSzzZ2dkqLCzUp59+qgEDBlxzbEpKiiTp+PHjkiSn06mamhq/MZduX+19KzabTXa73W8BAADBq02BYlmWsrOztXnzZu3atUuDBw9u9T5lZWWSpNjYWEmSy+XSoUOHVFtb6xuzY8cO2e12JSUltWU6AAAgSLXpJR63262CggJt3bpVERERvveMOBwO9ezZUydOnFBBQYGmTp2qvn376uDBg1q0aJEmTJig0aNHS5ImT56spKQkPfTQQ1q9erU8Ho+WLVsmt9stm83W/nsIAAA6nTZdQVmzZo3q6uo0ceJExcbG+pZ3331XkhQeHq6dO3dq8uTJSkxM1BNPPKGsrCxt27bN9xhhYWEqLCxUWFiYXC6Xfve73+nhhx/2+94UAADQtbXpCoplWdfcHh8fr+Li4lYfJyEhQR9++GFbnhoAAHQh/C0eAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGKdNgZKXl6exY8cqIiJC0dHRmjFjhsrLy/3GNDQ0yO12q2/fvurTp4+ysrJUU1PjN6ayslKZmZnq1auXoqOjtXjxYl28ePHG9wYAAASFNgVKcXGx3G639u3bpx07dqipqUmTJ09WfX29b8yiRYu0bds2bdq0ScXFxaqqqtLMmTN925ubm5WZmakLFy5o7969euutt7R+/XotX768/fYKAAB0at3aMnj79u1+t9evX6/o6GiVlpZqwoQJqqur05tvvqmCggJNmjRJkrRu3TqNGDFC+/bt0/jx4/XJJ5/o6NGj2rlzp2JiYnTHHXfo+eef15IlS/Tss88qPDy8/fYOAAB0Sjf0HpS6ujpJUlRUlCSptLRUTU1NSktL841JTEzUwIEDVVJSIkkqKSnRqFGjFBMT4xuTnp4u
r9erI0eOXPF5Ghsb5fV6/RYAABC8rjtQWlpatHDhQt19990aOXKkJMnj8Sg8PFyRkZF+Y2NiYuTxeHxj/m+cXNp+aduV5OXlyeFw+Jb4+PjrnTYAAOgErjtQ3G63Dh8+rI0bN7bnfK4oNzdXdXV1vuXkyZM3/TkBAEDgtOk9KJdkZ2ersLBQe/bs0YABA3zrnU6nLly4oDNnzvhdRampqZHT6fSN+fzzz/0e79KnfC6N+TGbzSabzXY9Uw1qg5Z+cNVt367K7MCZAADQvtp0BcWyLGVnZ2vz5s3atWuXBg8e7Lc9OTlZ3bt3V1FRkW9deXm5Kisr5XK5JEkul0uHDh1SbW2tb8yOHTtkt9uVlJR0I/sCAACCRJuuoLjdbhUUFGjr1q2KiIjwvWfE4XCoZ8+ecjgcmjt3rnJychQVFSW73a7HH39cLpdL48ePlyRNnjxZSUlJeuihh7R69Wp5PB4tW7ZMbrebqyQAAEBSGwNlzZo1kqSJEyf6rV+3bp0eeeQRSdLLL7+s0NBQZWVlqbGxUenp6XrjjTd8Y8PCwlRYWKgFCxbI5XKpd+/emjNnjlauXHljewIAAIJGmwLFsqxWx/To0UP5+fnKz8+/6piEhAR9+OGHbXlqAADQhfC3eAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGCcNgfKnj17NG3aNMXFxSkkJERbtmzx2/7II48oJCTEb5kyZYrfmNOnT2v27Nmy2+2KjIzU3Llzde7cuRvaEQAAEDzaHCj19fUaM2aM8vPzrzpmypQpqq6u9i3vvPOO3/bZs2fryJEj2rFjhwoLC7Vnzx7Nnz+/7bMHAABBqVtb75CRkaGMjIxrjrHZbHI6nVfc9vXXX2v79u364osvdNddd0mSXn/9dU2dOlUvvvii4uLi2jolAAAQZG7Ke1B2796t6OhoDR8+XAsWLNCpU6d820pKShQZGemLE0lKS0tTaGio9u/ff8XHa2xslNfr9VsAAEDwavdAmTJlit5++20VFRXpL3/5i4qLi5WRkaHm5mZJksfjUXR0tN99unXrpqioKHk8nis+Zl5enhwOh2+Jj49v72kDAACDtPklntbMmjXL9/OoUaM0evRoDR06VLt371Zqaup1PWZubq5ycnJ8t71eL5ECAEAQu+kfMx4yZIj69eun48ePS5KcTqdqa2v9xly8eFGnT5++6vtWbDab7Ha73wIAAILXTQ+U7777TqdOnVJsbKwkyeVy6cyZMyotLfWN2bVrl1paWpSSknKzpwMAADqBNr/Ec+7cOd/VEEmqqKhQWVmZoqKiFBUVpeeee05ZWVlyOp06ceKEnnrqKQ0bNkzp6emSpBEjRmjKlCmaN2+e1q5dq6amJmVnZ2vWrFl8ggcAAEi6jisoBw4c0J133qk777xTkpSTk6M777xTy5cvV1hYmA4ePKjf/OY3uu222zR37lwlJyfrn//8p2w2m+8xNmzYoMTERKWmpmrq1Km655579Ne//rX99goAAHRqbb6CMnHiRFmWddXtH3/8cauPERUVpYKCgrY+NQAA6CL4WzwAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAA
wDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDjdAj0BdLxBSz+46rZvV2V24EwAALgyrqAAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzT5kDZs2ePpk2bpri4OIWEhGjLli1+2y3L0vLlyxUbG6uePXsqLS1Nx44d8xtz+vRpzZ49W3a7XZGRkZo7d67OnTt3QzsCAACCR5sDpb6+XmPGjFF+fv4Vt69evVqvvfaa1q5dq/3796t3795KT09XQ0ODb8zs2bN15MgR7dixQ4WFhdqzZ4/mz59//XsBAACCSre23iEjI0MZGRlX3GZZll555RUtW7ZM06dPlyS9/fbbiomJ0ZYtWzRr1ix9/fXX2r59u7744gvdddddkqTXX39dU6dO1Ysvvqi4uLgb2B0AABAM2vU9KBUVFfJ4PEpLS/OtczgcSklJUUlJiSSppKREkZGRvjiRpLS0NIWGhmr//v1XfNzGxkZ5vV6/BQAABK92DRSPxyNJiomJ8VsfExPj2+bxeBQdHe23vVu3boqKivKN+bG8vDw5HA7fEh8f357TBgAAhukUn+LJzc1VXV2dbzl58mSgpwQAAG6idg0Up9MpSaqpqfFbX1NT49vmdDpVW1vrt/3ixYs6ffq0b8yP2Ww22e12vwUAAASvdg2UwYMHy+l0qqioyLfO6/Vq//79crlckiSXy6UzZ86otLTUN2bXrl1qaWlRSkpKe04HAAB0Um3+FM+5c+d0/Phx3+2KigqVlZUpKipKAwcO1MKFC/WnP/1Jt956qwYPHqxnnnlGcXFxmjFjhiRpxIgRmjJliubNm6e1a9eqqalJ2dnZmjVrFp/gAQAAkq4jUA4cOKBf/vKXvts5OTmSpDlz5mj9+vV66qmnVF9fr/nz5+vMmTO65557tH37dvXo0cN3nw0bNig7O1upqakKDQ1VVlaWXnvttXbYHQAAEAzaHCgTJ06UZVlX3R4SEqKVK1dq5cqVVx0TFRWlgoKCtj41AADoIjrFp3gAAEDXQqAAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwTrdATwCdx6ClH1x127erMjtwJgCAYMcVFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGCcdg+UZ599ViEhIX5LYmKib3tDQ4Pcbrf69u2rPn36KCsrSzU1Ne09DQAA0IndlCsot99+u6qrq33LZ5995tu2aNEibdu2TZs2bVJxcbGqqqo0c+bMmzENAADQSd2Uv8XTrVs3OZ3Oy9bX1dXpzTffVEFBgSZNmiRJWrdunUaMGKF9+/Zp/PjxN2M6AACgk7kpV1COHTumuLg4DRkyRLNnz1ZlZaUkqbS0VE1NTUpLS/ONTUxM1MCBA1VSUnLVx2tsbJTX6/VbAABA8Gr3QElJSdH69eu1fft2rVmzRhUVFbr33nt19uxZeTwehYeHKzIy0u8+MTEx8ng8V33MvLw8ORwO3xIfH9/e0wYAAAZp95d4MjIyfD+PHj1aKSkpSkhI0HvvvaeePXte12Pm5uYqJyfHd9vr9RIpAAAEsZv+MePIyEjddtttOn78uJxOpy5cuKAzZ874jampqbnie1Yusdls
stvtfgsAAAheNz1Qzp07pxMnTig2NlbJycnq3r27ioqKfNvLy8tVWVkpl8t1s6cCAAA6iXZ/iefJJ5/UtGnTlJCQoKqqKq1YsUJhYWF64IEH5HA4NHfuXOXk5CgqKkp2u12PP/64XC4Xn+ABAAA+7R4o3333nR544AGdOnVK/fv31z333KN9+/apf//+kqSXX35ZoaGhysrKUmNjo9LT0/XGG2+09zQAAEAn1u6BsnHjxmtu79Gjh/Lz85Wfn9/eT40gM2jpB1fd9u2qzA6cCQCgo/G3eAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgnHb/mDEAoHPhI/0wEVdQAACAcbiCAgQQ/3IFgk9n+u/a5LkSKAg61/sf3LXu19p9AQDti0ABAOAKTL660BXwHhQAAGAcrqAAuGmC4V+gN+Mlw86y70AgEShAF9KVf2l2hX3v6H3sCse0MwjW/x14iQcAABiHQAEAAMbhJR4AnUqwXs4GOLf9ESgAgE6BX+BdC4ECtAP+jxMA2hfvQQEAAMbhCgoAIKhxhbNz4goKAAAwDldQAMAw/IsfIFAA/AT8wgTQ0QgU3HT8cmt/HFMAwY73oAAAAOMQKAAAwDi8xAMA6FC8RImfgisoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADBOQAMlPz9fgwYNUo8ePZSSkqLPP/88kNMBAACGCFigvPvuu8rJydGKFSv05ZdfasyYMUpPT1dtbW2gpgQAAAwRsEB56aWXNG/ePD366KNKSkrS2rVr1atXL/39738P1JQAAIAhugXiSS9cuKDS0lLl5ub61oWGhiotLU0lJSWXjW9sbFRjY6Pvdl1dnSTJ6/XelPm1NJ6/6rZrPSf3C977BeI5u/L9rqWz7MON7HtnmSv369z3u9H7Xo9Lj2lZVuuDrQD4z3/+Y0my9u7d67d+8eLF1rhx4y4bv2LFCksSCwsLCwsLSxAsJ0+ebLUVAnIFpa1yc3OVk5Pju93S0qLTp0+rb9++CgkJuWy81+tVfHy8Tp48Kbvd3pFT7TQ4Rq3jGF0bx6d1HKPWcYyuLdiOj2VZOnv2rOLi4lodG5BA6devn8LCwlRTU+O3vqamRk6n87LxNptNNpvNb11kZGSrz2O324Pif9CbiWPUOo7RtXF8Wscxah3H6NqC6fg4HI6fNC4gb5INDw9XcnKyioqKfOtaWlpUVFQkl8sViCkBAACDBOwlnpycHM2ZM0d33XWXxo0bp1deeUX19fV69NFHAzUlAABgiIAFyv3336/vv/9ey5cvl8fj0R133KHt27crJibmhh/bZrNpxYoVl70shP+PY9Q6jtG1cXxaxzFqHcfo2rry8QmxrJ/yWR8AAICOw9/iAQAAxiFQAACAcQgUAABgHAIFAAAYJygDJT8/X4MGDVKPHj2UkpKizz//PNBTMsazzz6rkJAQvyUxMTHQ0wqYPXv2aNq0aYqLi1NISIi2bNnit92yLC1fvlyxsbHq2bOn0tLSdOzYscBMNkBaO0aPPPLIZefUlClTAjPZAMjLy9PYsWMVERGh6OhozZgxQ+Xl5X5jGhoa5Ha71bdvX/Xp00dZWVmXfVFlMPspx2jixImXnUePPfZYgGbc8dasWaPRo0f7vpDN5XLpo48+8m3viudQ0AXKu+++q5ycHK1YsUJffvmlxowZo/T0dNXW1gZ6asa4/fbbVV1d7Vs+++yzQE8pYOrr6zVmzBjl5+dfcfvq1av12muvae3atdq/f7969+6t9PR0NTQ0dPBMA6e1YyRJU6ZM8Tun3nnnnQ6cYWAVFxfL7XZr37592rFjh5qamjR58mTV19f7xixatEjbtm3Tpk2bVFxcrKqqKs2cOTOAs+5YP+UYSdK8efP8zqPVq1cH
aMYdb8CAAVq1apVKS0t14MABTZo0SdOnT9eRI0ckddFzqF3++p9Bxo0bZ7ndbt/t5uZmKy4uzsrLywvgrMyxYsUKa8yYMYGehpEkWZs3b/bdbmlpsZxOp/XCCy/41p05c8ay2WzWO++8E4AZBt6Pj5FlWdacOXOs6dOnB2Q+JqqtrbUkWcXFxZZl/e+c6d69u7Vp0ybfmK+//tqSZJWUlARqmgH142NkWZb1i1/8wvrjH/8YuEkZ6JZbbrH+9re/ddlzKKiuoFy4cEGlpaVKS0vzrQsNDVVaWppKSkoCODOzHDt2THFxcRoyZIhmz56tysrKQE/JSBUVFfJ4PH7nk8PhUEpKCufTj+zevVvR0dEaPny4FixYoFOnTgV6SgFTV1cnSYqKipIklZaWqqmpye88SkxM1MCBA7vsefTjY3TJhg0b1K9fP40cOVK5ubk6f/58IKYXcM3Nzdq4caPq6+vlcrm67DnUKf6a8U/13//+V83NzZd9G21MTIy++eabAM3KLCkpKVq/fr2GDx+u6upqPffcc7r33nt1+PBhRUREBHp6RvF4PJJ0xfPp0jb87+WdmTNnavDgwTpx4oSefvppZWRkqKSkRGFhYYGeXodqaWnRwoULdffdd2vkyJGS/ncehYeHX/YHTrvqeXSlYyRJDz74oBISEhQXF6eDBw9qyZIlKi8v1/vvvx/A2XasQ4cOyeVyqaGhQX369NHmzZuVlJSksrKyLnkOBVWgoHUZGRm+n0ePHq2UlBQlJCTovffe09y5cwM4M3RWs2bN8v08atQojR49WkOHDtXu3buVmpoawJl1PLfbrcOHD3fp93W15mrHaP78+b6fR40apdjYWKWmpurEiRMaOnRoR08zIIYPH66ysjLV1dXpH//4h+bMmaPi4uJATytgguolnn79+iksLOyydzbX1NTI6XQGaFZmi4yM1G233abjx48HeirGuXTOcD61zZAhQ9SvX78ud05lZ2ersLBQn376qQYMGOBb73Q6deHCBZ05c8ZvfFc8j652jK4kJSVFkrrUeRQeHq5hw4YpOTlZeXl5GjNmjF599dUuew4FVaCEh4crOTlZRUVFvnUtLS0qKiqSy+UK4MzMde7cOZ04cUKxsbGBnopxBg8eLKfT6Xc+eb1e7d+/n/PpGr777judOnWqy5xTlmUpOztbmzdv1q5duzR48GC/7cnJyerevbvfeVReXq7Kysoucx61doyupKysTJK6zHl0JS0tLWpsbOy651Cg36Xb3jZu3GjZbDZr/fr11tGjR6358+dbkZGRlsfjCfTUjPDEE09Yu3fvtioqKqx//etfVlpamtWvXz+rtrY20FMLiLNnz1pfffWV9dVXX1mSrJdeesn66quvrH//+9+WZVnWqlWrrMjISGvr1q3WwYMHrenTp1uDBw+2fvjhhwDPvONc6xidPXvWevLJJ62SkhKroqLC2rlzp/Xzn//cuvXWW62GhoZAT71DLFiwwHI4HNbu3but6upq33L+/HnfmMcee8waOHCgtWvXLuvAgQOWy+WyXC5XAGfdsVo7RsePH7dWrlxpHThwwKqoqLC2bt1qDRkyxJowYUKAZ95xli5dahUXF1sVFRXWwYMHraVLl1ohISHWJ598YllW1zyHgi5QLMuyXn/9dWvgwIFWeHi4NW7cOGvfvn2BnpIx7r//fis2NtYKDw+3fvazn1n333+/dfz48UBPK2A+/fRTS9Jly5w5cyzL+t9HjZ955hkrJibGstlsVmpqqlVeXh7YSXewax2j8+fPW5MnT7b69+9vde/e3UpISLDmzZvXpf5BcKVjI8lat26db8wPP/xg/eEPf7BuueUWq1evXtZvf/tbq7q6OnCT7mCtHaPKykprwoQJVlRUlGWz2axhw4ZZixcvturq6gI78Q70+9//3kpISLDCw8Ot/v37W6mpqb44sayueQ6FWJZlddz1GgAAgNYF1XtQAABAcCBQAACAcQgUAABgHAIFAAAY
h0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGOf/Ab6LlksYuhZ8AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist(mkbatch(2048)[1].to(torch.float32).cpu(), bins=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1157.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,  570.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,  210.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,   86.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,   20.,    0.,    0.,\n",
       "           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,\n",
       "           5.]),\n",
       " array([1.      , 1.078125, 1.15625 , 1.234375, 1.3125  , 1.390625,\n",
       "        1.46875 , 1.546875, 1.625   , 1.703125, 1.78125 , 1.859375,\n",
       "        1.9375  , 2.015625, 2.09375 , 2.171875, 2.25    , 2.328125,\n",
       "        2.40625 , 2.484375, 2.5625  , 2.640625, 2.71875 , 2.796875,\n",
       "        2.875   , 2.953125, 3.03125 , 3.109375, 3.1875  , 3.265625,\n",
       "        3.34375 , 3.421875, 3.5     , 3.578125, 3.65625 , 3.734375,\n",
       "        3.8125  , 3.890625, 3.96875 , 4.046875, 4.125   , 4.203125,\n",
       "        4.28125 , 4.359375, 4.4375  , 4.515625, 4.59375 , 4.671875,\n",
       "        4.75    , 4.828125, 4.90625 , 4.984375, 5.0625  , 5.140625,\n",
       "        5.21875 , 5.296875, 5.375   , 5.453125, 5.53125 , 5.609375,\n",
       "        5.6875  , 5.765625, 5.84375 , 5.921875, 6.      ]),\n",
       " <BarContainer object of 64 artists>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGeCAYAAACKDztsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAilklEQVR4nO3df1BVdf7H8ReI/FDhIrbcy13R2N1WJU1NjG5aW8mISu46sZVFxqojOy1YZj+USu2HibrbL1rDdFp1Jh2rndHKiiIt3ZQQcSl/ZbZZ0o8LO4vcGzQiyvn+8R3PdEs3bS9cPrfnY+bMdM/53Hve9zYNzw73XiIsy7IEAABgkMhQDwAAAHCuCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcaJCPUBHaW9v15dffqn4+HhFRESEehwAAHAWLMvS119/LbfbrcjI/3KdxTpHW7duta655horJSXFkmRt2LDBPnb8+HHrnnvusQYPHmz16NHDSklJsaZMmWJ98cUXAY/xn//8x7rpppus+Ph4y+FwWNOmTbO+/vrrgDXvv/++NXr0aCsmJsbq27evtWTJknOas66uzpLExsbGxsbGZuBWV1f3X3/On/MVmJaWFg0dOlTTpk3TtddeG3Dsm2++0e7duzVv3jwNHTpUR48e1e23367f/va32rVrl70uLy9PX331lSoqKtTW1qapU6eqoKBA69atkyT5/X6NHTtWWVlZWr58ufbs2aNp06YpMTFRBQUFZzVnfHy8JKmurk4JCQnn+jQBAEAI+P1+paam2j/HzyTCsn78H3OMiIjQhg0bNGnSpDOuqa6u1iWXXKLPPvtM/fr104EDB5Senq7q6mplZGRIksrLyzVhwgR9/vnncrvdKisr03333Sev16vo6GhJ0ty5c7Vx40Z9+OGHZzWb3++Xw+GQz+cjYAAAMMTZ/vzu8Dfx+nw+RUREKDExUZJUWVmpxMREO14kKSsrS5GRkaqqqrLXXHHFFXa8SFJ2drYOHjyoo0ePnvY8ra2t8vv9ARsAAAhPHRowx44d05w5c3TjjTfaFeX1epWcnBywLioqSklJSfJ6vfYap9MZsObU7VNrvqukpEQOh8PeUlNTg/10AABAF9FhAdPW1qbrr79elmWprKyso05jKy4uls/ns7e6uroOPycAAAiNDvkY9al4+eyzz7Rly5aA32G5XC41NDQErD9x4oQaGxvlcrnsNfX19QFrTt0+tea7YmJiFBMTE8ynAQAAuqigX4E5FS+HDh3SW2+9pT59+gQc93g8ampqUk1Njb1vy5Ytam9vV2Zmpr1m27Ztamtrs9dUVFRowIAB6t27d7BHBgAAhjnngGlublZtba1qa2slSYcPH1Ztba2OHDmitrY2/f73v9euXbu0du1anTx5Ul6vV16vV8ePH5ckDRo0SOPGjdOMGTO0c+dObd++XUVFRZo8ebLcbrck6aabblJ0dLSmT5+uffv26fnnn9eTTz6p2bNnB++ZAwAAY53zx6jfeecdXXXVVd/bn5+frwceeEBpaWmnvd/bb7+tK6+8UpLU2NiooqIivfLKK4qMjFRubq5KS0vVq1cve/0HH3ygwsJCVVdX67zzztPMmTM1Z86cs56Tj1EDAGCes/35/T99D0xXRsAAAGCeLvM9MAAAAMFGwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA43TInxIId+fPffWMxz5dnNOJkwAA8NPEFRgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAA
YBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxjnngNm2bZsmTpwot9utiIgIbdy4MeC4ZVmaP3++UlJSFBcXp6ysLB06dChgTWNjo/Ly8pSQkKDExERNnz5dzc3NAWs++OADXX755YqNjVVqaqqWLl167s8OAACEpXMOmJaWFg0dOlTLli077fGlS5eqtLRUy5cvV1VVlXr27Kns7GwdO3bMXpOXl6d9+/apoqJCmzZt0rZt21RQUGAf9/v9Gjt2rPr376+amhr9+c9/1gMPPKAVK1b8iKcIAADCTYRlWdaPvnNEhDZs2KBJkyZJ+v+rL263W3feeafuuusuSZLP55PT6dTq1as1efJkHThwQOnp6aqurlZGRoYkqby8XBMmTNDnn38ut9utsrIy3XffffJ6vYqOjpYkzZ07Vxs3btSHH354VrP5/X45HA75fD4lJCT82Kd4WufPffWMxz5dnBPUcwEA8FNytj+/g/oemMOHD8vr9SorK8ve53A4lJmZqcrKSklSZWWlEhMT7XiRpKysLEVGRqqqqspec8UVV9jxIknZ2dk6ePCgjh49etpzt7a2yu/3B2wAACA8BTVgvF6vJMnpdAbsdzqd9jGv16vk5OSA41FRUUpKSgpYc7rH+PY5vqukpEQOh8PeUlNT//cnBAAAuqSw+RRScXGxfD6fvdXV1YV6JAAA0EGCGjAul0uSVF9fH7C/vr7ePuZyudTQ0BBw/MSJE2psbAxYc7rH+PY5vismJkYJCQkBGwAACE9BDZi0tDS5XC5t3rzZ3uf3+1VVVSWPxyNJ8ng8ampqUk1Njb1my5Ytam9vV2Zmpr1m27Ztamtrs9dUVFRowIAB6t27dzBHBgAABjrngGlublZtba1qa2sl/f8bd2tra3XkyBFFRERo1qxZWrhwoV5++WXt2bNHt9xyi9xut/1JpUGDBmncuHGaMWOGdu7cqe3bt6uoqEiTJ0+W2+2WJN10002Kjo7W9OnTtW/fPj3//PN68sknNXv27KA9cQAAYK6oc73Drl27dNVVV9m3T0VFfn6+Vq9erXvuuUctLS0qKChQU1OTRo8erfLycsXGxtr3Wbt2rYqKijRmzBhFRkYqNzdXpaWl9nGHw6E333xThYWFGjFihM477zzNnz8/4LtiAADAT9f/9D0wXRnfAwMAgHlC8j0wAAAAnYGAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxgl6wJw8eVLz5s1TWlqa4uLi9Mtf/lIPP/ywLMuy11iWpfnz5yslJUVxcXHKysrSoUOHAh6nsbFReXl5SkhIUGJioqZPn67m5uZgjwsAAAwU9IBZsmSJysrK9Ne//lUHDhzQkiVLtHTpUj311FP2mqVLl6q0tFTLly9XVVWVevbsqezsbB07dsxek5eXp3379qmiokKbNm3Stm3bVFBQEOxxAQCAgSKsb18aCYJrrrlGTqdTzz77rL0vNzdXcXFxeu6552RZltxut+68807dddddkiSfzyen06nVq1dr8uTJOnDggNLT01VdXa2MjAxJUnl5uSZMmKDPP/9cbrf7B+fw+/1yOBzy+XxKSEgI5lPU
+XNfPeOxTxfnBPVcAAD8lJztz++gX4G57LLLtHnzZn300UeSpPfff1/vvvuuxo8fL0k6fPiwvF6vsrKy7Ps4HA5lZmaqsrJSklRZWanExEQ7XiQpKytLkZGRqqqqOu15W1tb5ff7AzYAABCeooL9gHPnzpXf79fAgQPVrVs3nTx5Uo888ojy8vIkSV6vV5LkdDoD7ud0Ou1jXq9XycnJgYNGRSkpKcle810lJSV68MEHg/10AABAFxT0KzAvvPCC1q5dq3Xr1mn37t1as2aN/vKXv2jNmjXBPlWA4uJi+Xw+e6urq+vQ8wEAgNAJ+hWYu+++W3PnztXkyZMlSUOGDNFnn32mkpIS5efny+VySZLq6+uVkpJi36++vl7Dhg2TJLlcLjU0NAQ87okTJ9TY2Gjf/7tiYmIUExMT7KcDAAC6oKBfgfnmm28UGRn4sN26dVN7e7skKS0tTS6XS5s3b7aP+/1+VVVVyePxSJI8Ho+amppUU1Njr9myZYva29uVmZkZ7JEBAIBhgn4FZuLEiXrkkUfUr18/XXjhhfrnP/+pxx57TNOmTZMkRUREaNasWVq4cKEuuOACpaWlad68eXK73Zo0aZIkadCgQRo3bpxmzJih5cuXq62tTUVFRZo8efJZfQIJAACEt6AHzFNPPaV58+bpT3/6kxoaGuR2u/XHP/5R8+fPt9fcc889amlpUUFBgZqamjR69GiVl5crNjbWXrN27VoVFRVpzJgxioyMVG5urkpLS4M9LgAAMFDQvwemq+B7YAAAME/IvgcGAACgoxEwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOB0SMF988YVuvvlm9enTR3FxcRoyZIh27dplH7csS/Pnz1dKSori4uKUlZWlQ4cOBTxGY2Oj8vLylJCQoMTERE2fPl3Nzc0dMS4AADBM0APm6NGjGjVqlLp3767XX39d+/fv16OPPqrevXvba5YuXarS0lItX75cVVVV6tmzp7Kzs3Xs2DF7TV5envbt26eKigpt2rRJ27ZtU0FBQbDHBQAABoqwLMsK5gPOnTtX27dv1z/+8Y/THrcsS263W3feeafuuusuSZLP55PT6dTq1as1efJkHThwQOnp6aqurlZGRoYkqby8XBMmTNDnn38ut9v9g3P4/X45HA75fD4lJCQE7wlKOn/uq2c89uninKCeCwCAn5Kz/fkd9CswL7/8sjIyMnTdddcpOTlZw4cP18qVK+3jhw8fltfrVVZWlr3P4XAoMzNTlZWVkqTKykolJiba8SJJWVlZioyMVFVV1WnP29raKr/fH7ABAIDwFPSA+eSTT1RWVqYLLrhAb7zxhm699VbddtttWrNmjSTJ6/VKkpxOZ8D9nE6nfczr9So5OTngeFRUlJKSkuw131VSUiKHw2FvqampwX5qAACgiwh6wLS3t+viiy/WokWLNHz4cBUUFGjGjBlavnx5sE8VoLi4WD6fz97q6uo69HwAACB0gh4wKSkpSk9PD9g3aNAgHTlyRJLkcrkkSfX19QFr6uvr7WMul0sNDQ0Bx0+cOKHGxkZ7zXfFxMQoISEhYAMAAOEp6AEzatQoHTx4MGDfRx99pP79+0uS0tLS5HK5tHnzZvu43+9XVVWVPB6PJMnj8aipqUk1NTX2mi1btqi9
vV2ZmZnBHhkAABgmKtgPeMcdd+iyyy7TokWLdP3112vnzp1asWKFVqxYIUmKiIjQrFmztHDhQl1wwQVKS0vTvHnz5Ha7NWnSJEn/f8Vm3Lhx9q+e2traVFRUpMmTJ5/VJ5AAAEB4C3rAjBw5Uhs2bFBxcbEeeughpaWl6YknnlBeXp695p577lFLS4sKCgrU1NSk0aNHq7y8XLGxsfaatWvXqqioSGPGjFFkZKRyc3NVWloa7HEBAICBgv49MF0F3wMDAIB5QvY9MAAAAB2NgAEAAMYJ+ntggB+DX8sBAM4FV2AAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgnA4PmMWLFysiIkKzZs2y9x07dkyFhYXq06ePevXqpdzcXNXX1wfc78iRI8rJyVGPHj2UnJysu+++WydOnOjocQEAgAE6NGCqq6v1zDPP6KKLLgrYf8cdd+iVV17Riy++qK1bt+rLL7/Utddeax8/efKkcnJydPz4ce3YsUNr1qzR6tWrNX/+/I4cFwAAGKLDAqa5uVl5eXlauXKlevfube/3+Xx69tln9dhjj+nqq6/WiBEjtGrVKu3YsUPvvfeeJOnNN9/U/v379dxzz2nYsGEaP368Hn74YS1btkzHjx/vqJEBAIAhOixgCgsLlZOTo6ysrID9NTU1amtrC9g/cOBA9evXT5WVlZKkyspKDRkyRE6n016TnZ0tv9+vffv2ddTIAADAEFEd8aDr16/X7t27VV1d/b1jXq9X0dHRSkxMDNjvdDrl9XrtNd+Ol1PHTx07ndbWVrW2ttq3/X7///IUAABAFxb0KzB1dXW6/fbbtXbtWsXGxgb74c+opKREDofD3lJTUzvt3AAAoHMFPWBqamrU0NCgiy++WFFRUYqKitLWrVtVWlqqqKgoOZ1OHT9+XE1NTQH3q6+vl8vlkiS5XK7vfSrp1O1Ta76ruLhYPp/P3urq6oL91AAAQBcR9IAZM2aM9uzZo9raWnvLyMhQXl6e/c/du3fX5s2b7fscPHhQR44ckcfjkSR5PB7t2bNHDQ0N9pqKigolJCQoPT39tOeNiYlRQkJCwAYAAMJT0N8DEx8fr8GDBwfs69mzp/r06WPvnz59umbPnq2kpCQlJCRo5syZ8ng8uvTSSyVJY8eOVXp6uqZMmaKlS5fK6/Xq/vvvV2FhoWJiYoI9MgAAMEyHvIn3hzz++OOKjIxUbm6uWltblZ2draeffto+3q1bN23atEm33nqrPB6Pevbsqfz8fD300EOhGBcAAHQxnRIw77zzTsDt2NhYLVu2TMuWLTvjffr376/XXnutgycDAAAm4m8hAQAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAA
MA4BAwAAjEPAAAAA4xAwAADAOEEPmJKSEo0cOVLx8fFKTk7WpEmTdPDgwYA1x44dU2Fhofr06aNevXopNzdX9fX1AWuOHDminJwc9ejRQ8nJybr77rt14sSJYI8LAAAMFBXsB9y6dasKCws1cuRInThxQvfee6/Gjh2r/fv3q2fPnpKkO+64Q6+++qpefPFFORwOFRUV6dprr9X27dslSSdPnlROTo5cLpd27Nihr776Srfccou6d++uRYsWBXtk4Cfn/LmvnvHYp4tzOnESAPhxgh4w5eXlAbdXr16t5ORk1dTU6IorrpDP59Ozzz6rdevW6eqrr5YkrVq1SoMGDdJ7772nSy+9VG+++ab279+vt956S06nU8OGDdPDDz+sOXPm6IEHHlB0dHSwxwYAAAbp8PfA+Hw+SVJSUpIkqaamRm1tbcrKyrLXDBw4UP369VNlZaUkqbKyUkOGDJHT6bTXZGdny+/3a9++fac9T2trq/x+f8AGAADCU4cGTHt7u2bNmqVRo0Zp8ODBkiSv16vo6GglJiYGrHU6nfJ6vfaab8fLqeOnjp1OSUmJHA6HvaWmpgb52QAAgK6iQwOmsLBQe/fu1fr16zvyNJKk4uJi+Xw+e6urq+vwcwIAgNAI+ntgTikqKtKmTZu0bds29e3b197vcrl0/PhxNTU1BVyFqa+vl8vlstfs3Lkz4PFOfUrp1JrviomJUUxMTJCfBQAA6IqCfgXGsiwVFRVpw4YN2rJli9LS0gKOjxgxQt27d9fmzZvtfQcPHtSRI0fk8XgkSR6PR3v27FFDQ4O9pqKiQgkJCUpPTw/2yAAAwDBBvwJTWFiodevW6aWXXlJ8fLz9nhWHw6G4uDg5HA5Nnz5ds2fPVlJSkhISEjRz5kx5PB5deumlkqSxY8cqPT1dU6ZM0dKlS+X1enX//fersLCQqywAACD4AVNWViZJuvLKKwP2r1q1Sn/4wx8kSY8//rgiIyOVm5ur1tZWZWdn6+mnn7bXduvWTZs2bdKtt94qj8ejnj17Kj8/Xw899FCwxwUAAAYKesBYlvWDa2JjY7Vs2TItW7bsjGv69++v1157LZijAQCAMMHfQgIAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGCcq1AMAQDg7f+6rZzz26eKcTpwECC9cgQEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgnKhQDwAAQLCcP/fVMx77dHFOJ06CjsYVGAAAYJwuHTDLli3T+eefr9jYWGVmZmrnzp2hHgkAAHQBXTZgnn/+ec2ePVsLFizQ7t27NXToUGVnZ6uhoSHUowEAgBDrsgHz2GOPacaMGZo6darS09O1fPly9ejRQ3/7299CPRoAAAixLvkm3uPHj6umpkbFxcX2vsjISGVlZamysvK092ltbVVra6t92+fzSZL8fn/Q52tv/eaMxzrifD8FvKadi9e78/Bady5e7+AavOCNMx7b+2B2h5zz1L8ny7L++0KrC/riiy8sSdaOHTsC9t99993WJZdcctr7LFiwwJLExsbGxsbGFgZbXV3df22FLnkF5scoLi7W7Nmz7dvt7e1qbGxUnz59FBEREbTz+P1+paamqq6uTgkJCUF7XHwfr3Xn4HXuHLzOnYPXuXN05OtsWZa+/vprud3u/7quSwbMeeedp27duqm+vj5gf319vVwu12nvExMTo5iYmIB9iYmJHTWi
EhIS+I+jk/Badw5e587B69w5eJ07R0e9zg6H4wfXdMk38UZHR2vEiBHavHmzva+9vV2bN2+Wx+MJ4WQAAKAr6JJXYCRp9uzZys/PV0ZGhi655BI98cQTamlp0dSpU0M9GgAACLEuGzA33HCD/v3vf2v+/Pnyer0aNmyYysvL5XQ6QzpXTEyMFixY8L1fVyH4eK07B69z5+B17hy8zp2jK7zOEZb1Q59TAgAA6Fq65HtgAAAA/hsCBgAAGIeAAQAAxiFgAACAcQiYc7Bt2zZNnDhRbrdbERER2rhxY6hHCjslJSUaOXKk4uPjlZycrEmTJungwYOhHivslJWV6aKLLrK/hMrj8ej1118P9Vhhb/HixYqIiNCsWbNCPUrYeeCBBxQRERGwDRw4MNRjhaUvvvhCN998s/r06aO4uDgNGTJEu3bt6vQ5CJhz0NLSoqFDh2rZsmWhHiVsbd26VYWFhXrvvfdUUVGhtrY2jR07Vi0tLaEeLaz07dtXixcvVk1NjXbt2qWrr75av/vd77Rv375Qjxa2qqur9cwzz+iiiy4K9Shh68ILL9RXX31lb++++26oRwo7R48e1ahRo9S9e3e9/vrr2r9/vx599FH17t2702fpst8D0xWNHz9e48ePD/UYYa28vDzg9urVq5WcnKyamhpdccUVIZoq/EycODHg9iOPPKKysjK99957uvDCC0M0Vfhqbm5WXl6eVq5cqYULF4Z6nLAVFRV1xj83g+BYsmSJUlNTtWrVKntfWlpaSGbhCgy6NJ/PJ0lKSkoK8STh6+TJk1q/fr1aWlr4Ux0dpLCwUDk5OcrKygr1KGHt0KFDcrvd+sUvfqG8vDwdOXIk1COFnZdfflkZGRm67rrrlJycrOHDh2vlypUhmYUrMOiy2tvbNWvWLI0aNUqDBw8O9ThhZ8+ePfJ4PDp27Jh69eqlDRs2KD09PdRjhZ3169dr9+7dqq6uDvUoYS0zM1OrV6/WgAED9NVXX+nBBx/U5Zdfrr179yo+Pj7U44WNTz75RGVlZZo9e7buvfdeVVdX67bbblN0dLTy8/M7dRYCBl1WYWGh9u7dy++xO8iAAQNUW1srn8+nv//978rPz9fWrVuJmCCqq6vT7bffroqKCsXGxoZ6nLD27V/vX3TRRcrMzFT//v31wgsvaPr06SGcLLy0t7crIyNDixYtkiQNHz5ce/fu1fLlyzs9YPgVErqkoqIibdq0SW+//bb69u0b6nHCUnR0tH71q19pxIgRKikp0dChQ/Xkk0+GeqywUlNTo4aGBl188cWKiopSVFSUtm7dqtLSUkVFRenkyZOhHjFsJSYm6te//rU+/vjjUI8SVlJSUr73PzmDBg0Kya/ruAKDLsWyLM2cOVMbNmzQO++8E7I3h/0Utbe3q7W1NdRjhJUxY8Zoz549AfumTp2qgQMHas6cOerWrVuIJgt/zc3N+te//qUpU6aEepSwMmrUqO99tcVHH32k/v37d/osBMw5aG5uDqj5w4cPq7a2VklJSerXr18IJwsfhYWFWrdunV566SXFx8fL6/VKkhwOh+Li4kI8XfgoLi7W+PHj1a9fP3399ddat26d3nnnHb3xxhuhHi2sxMfHf+/9Wz179lSfPn14X1eQ3XXXXZo4caL69++vL7/8UgsWLFC3bt104403hnq0sHLHHXfosssu06JFi3T99ddr586dWrFihVasWNH5w1g4a2+//bYl6Xtbfn5+qEcLG6d7fSVZq1atCvVoYWXatGlW//79rejoaOtnP/uZNWbMGOvNN98M9Vg/Cb/5zW+s22+/PdRjhJ0bbrjBSklJsaKjo62f//zn1g033GB9/PHHoR4rLL3yyivW4MGDrZiYGGvgwIHWihUrQjJHhGVZVudnEwAAwI/Hm3gBAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADG+T+/FGJF
6kiv7QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist(mktunebatch(2048)[1].to(torch.float32).cpu(), bins=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q3Cg_8UQep8g"
   },
   "source": [
    "# Step 2: Define Transformer Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "tLOWhg_CeWzH"
   },
   "outputs": [],
   "source": [
    "class TransformerModel(nn.Module):\n",
    "    def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, device, dropout):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(input_dim, model_dim//2, dtype=torch.bfloat16)\n",
    "        # seq_len is odd\n",
    "        self.fancy_encoding = torch.repeat_interleave(torch.rand((1, seq_len // 2 + 1, model_dim // 2), device=device, dtype=torch.bfloat16), 2, dim=1)\n",
    "        # cut off last element since the target vertex is not repeated\n",
    "        self.fancy_encoding = self.fancy_encoding[:, :seq_len, :]\n",
    "        \n",
    "        self.model_dim = model_dim\n",
    "        self.seq_len = seq_len\n",
    "        self.device = device\n",
    "\n",
    "        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,\n",
    "                                                   dim_feedforward=model_dim*4,\n",
    "                                                   dropout=dropout, batch_first=True, dtype=torch.bfloat16)\n",
    "        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
    "\n",
    "        self.fc_out = nn.Linear(model_dim*seq_len, output_dim, dtype=torch.bfloat16)\n",
    "\n",
    "    def forward(self, src, key_padding_mask):\n",
    "        batch_size, src_len = src.size(0), src.size(1)\n",
    "        embed = self.embedding(src)\n",
    "        src = torch.cat((embed * sqrt(self.model_dim), self.fancy_encoding.repeat((batch_size, 1, 1))), dim=2)\n",
    "\n",
    "        output = self.transformer_encoder(src, src_key_padding_mask=key_padding_mask)\n",
    "        output[key_padding_mask] = 0 # Hack to stop no_grad problem\n",
    "        flat_output = torch.flatten(output, start_dim=1, end_dim=2)\n",
    "        output = self.fc_out(flat_output)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bpIeg86S-hBb"
   },
   "source": [
    "# Step 3: Make Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kWXvJRDYgFVP",
    "outputId": "c13adb9d-6565-43b5-8437-20cef3dc0d16"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training data: 839M\n",
      "Trainable parameters in the model: 804K\n"
     ]
    }
   ],
   "source": [
    "# PARAMS\n",
    "VOCAB_SIZE = 1 + MAX_VTXS # one more than the max number of vertices\n",
    "MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
    "NEPOCHS = 100\n",
    "BSZ = 8196 # Batch size\n",
    "BPE = 1024 # Batches per epoch\n",
    "LR = 0.003\n",
    "WD = 0.002\n",
    "NHEADS = 4\n",
    "NLAYERS = 16\n",
    "DROPOUT = 0.2\n",
    "model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
    "                         output_dim=1, num_heads=NHEADS,\n",
    "                         num_layers=NLAYERS, seq_len=SEQ_LEN,\n",
    "                         dropout=DROPOUT, device=device).to(device)\n",
    "model = torch.compile(model)\n",
    "\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)\n",
    "\n",
    "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"Training data: {NEPOCHS*BSZ*BPE//10**6}M\")\n",
    "print(f\"Trainable parameters in the model: {trainable_params//1000}K\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f8Zn33m7CxL5"
   },
   "source": [
    "# Step 4: Train the Model for the first task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate():\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc7a660a4efe4342bf4af23cecdea1b1",
       "version_major": 2,
       "version_minor": 0
      },
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApCklEQVR4nO3df1TUdb7H8RegM1gKWggKjpLuqq0/oMVENNfcg1F6ad1TV9KukNdK03VNTm3aD7GtxDQ77ipKuVntLZPilreTRJcot6vRuqHsqc0fa/hrK1AyQTFBmM/9o9PsToCBwYzO5/k4Z/7gw+fLvMfvGs/9zg+DjDFGAAAAsEawvwcAAACAbxGAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgABwkQgKCtKvfvUrf48BIAAQgAAkSc8995yCgoIUFBSkbdu2Nfm+MUYul0tBQUH6t3/7N6/vnTp1SllZWRo6dKguvfRSXX755YqPj9f8+fP1+eefe/YtWbLEcx/N3SoqKjr8cZ7LuWabPXu2X2cDgPbUyd8DALiwhIaGauPGjbrmmmu81v/0pz/pH//4h5xOp9f62bNn9bOf/Ux79uxRRkaG5s2bp1OnTulvf/ubNm7cqF/+8peKjo72OmbdunXq2rVrk/vu3r17uz+etpowYYLS09ObrA8cONAP0wBAxyAAAXiZOHGiXnnlFf3+979Xp07//E/Exo0blZCQoKqqKq/9mzdv1q5du/Tiiy9q2rRpXt87c+aM6uvrm9zHzTffrIiIiI55AD/QwIED9R//8R/+HgMAOhRPAQPwMnXqVH355ZcqKiryrNXX1ys/P79J4EnSp59+KkkaM2ZMk++FhoYqLCysXeYaOnSoxo8f32Td7XYrJiZGN998s2dt06ZNSkhIULdu3RQWFqZhw4bpd7/7XbvMIUnXXnuthg4dqtLSUo0ePVpdunTRFVdcodzc3CZ7jx49qpkzZyoqKkqhoaGKi4vT888/3+zj+N3vfqdhw4YpNDRUPXv21PXXX68PP/ywyd7Nmzdr6NChcjqdGjJkiAoLC72+f/LkSd19992KjY2V0+lUZGSkJkyYoJ07d7bbnwGAixsBCMBLbGyskpKS9NJLL3nW3nzzTVVXV+uWW25psr9fv36SpD/+8Y8yxrTqPo4fP66qqiqv24kTJ855TFpamt57770mrxPctm2bPv/8c89sRUVFmjp1qnr06KHHH39cy5Yt07XXXqvt27e3arYzZ840ma2qqqrJlcyvvvpKEydOVEJCgpYvX64+ffrorrvu0oYNGzx7vv76a1177bX6r//6L916661asWKFwsPDddtttzUJ0pkzZ+ruu++Wy+XS448/roULFyo0NFQffPBBk8c7Z84c3XLLLVq+fLnOnDmjm266SV9++aVnz+zZs7Vu3TrddNNNWrt2re655x516dJFu3fvbtWfAQALGAAwxjz77LNGkvnLX/5i1qxZY7p162ZOnz5tjDHm3//938348eONMcb069fPTJo0yXPc6dOnzaBBg4wk069fP3PbbbeZZ555xlRWVja5j6ysLCOp2dugQYPOOd/evXuNJLN69Wqv9Tlz5piuXbt6Zp0/f74JCwszDQ0Nbf4zaGk2Seall17y7Bs3bpyRZFauXOlZq6urM/Hx8SYyMtLU19cbY4xZtWqVkWReeOEFz776+nqTlJRkunbtampqaowxxrzzzjtGkvn1r3/dZCa32+01n8PhMPv37/es/fWvf23y5xIeHm7mzp3b5scPwB5cAQTQxJQpU/T111/rjTfe0MmTJ/XGG280+/SvJHXp0kV//vOfde+990r65t3EM2fOVO/evTVv3jzV1dU1Oea///u/VVRU5HV79tlnzznTwIEDFR8fr7y8PM9aY2Oj8vPzlZqaqi5dukj65o0ktbW1Xk9ht8UvfvGLJrMVFRU1efq5
U6dOmjVrludrh8OhWbNm6ejRoyotLZUkFRQUqFevXpo6dapnX+fOnfXrX/9ap06d0p/+9CfPn0dQUJCysrKazBMUFOT1dXJysgYMGOD5evjw4QoLC1N5eblnrXv37vrzn//s9Q5sAPhXvAkEQBM9e/ZUcnKyNm7cqNOnT6uxsdHrNXbfFR4eruXLl2v58uU6dOiQiouL9cQTT2jNmjUKDw/Xo48+6rX/Zz/72Xm9CSQtLU3333+/PvvsM8XExGjr1q06evSo0tLSPHvmzJmjl19+WTfccINiYmJ03XXXacqUKbr++utbdR99+vRRcnLy9+6Ljo7WpZde6rX27TuFDx48qFGjRunQoUP68Y9/rOBg7/+vfeWVV0qSDh06JOmb11FGR0frsssu+9777du3b5O1Hj166KuvvvJ8vXz5cmVkZMjlcikhIUETJ05Uenq6+vfv/70/H4AduAIIoFnTpk3Tm2++qdzcXN1www2t/oiWfv366T//8z+1fft2de/eXS+++GK7zZSWliZjjF555RVJ0ssvv6zw8HCvuIuMjFRZWZlef/113XjjjXr33Xd1ww03KCMjo93m8KeQkJBm182/vP5yypQpKi8v1+rVqxUdHa0VK1ZoyJAhevPNN301JoALHAEIoFm//OUvFRwcrA8++KDFp3/PpUePHhowYIC++OKLdpvpiiuu0MiRI5WXl6eGhga9+uqrmjx5cpPPJnQ4HEpNTdXatWv16aefatasWfrjH/+o/fv3t9ssn3/+uWpra73W9u3bJ+mbN9JI38Tw3//+d7ndbq99e/bs8XxfkgYMGKDPP/9cx48fb7f5evfurTlz5mjz5s06cOCALr/8cj322GPt9vMBXNwIQADN6tq1q9atW6clS5YoNTW1xX1//etfm3w2oPTN05uffPKJBg0a1K5zpaWl6YMPPtCGDRtUVVXl9fSvJK93w0pScHCwhg8fLknNvh7xfDU0NOipp57yfF1fX6+nnnpKPXv2VEJCgqRvPlOxoqLC63WLDQ0NWr16tbp27apx48ZJkm666SYZY/Twww83uR/TyndWf6uxsVHV1dVea5GRkYqOjm7Xxw/g4sZrAAG0qDVPmxYVFSkrK0s33nijRo0apa5du6q8vFwbNmxQXV2dlixZ0uSY/Pz8Zv8lkAkTJigqKuqc9zdlyhTdc889uueee3TZZZc1eb3e7bffruPHj+vnP/+5+vTpo0OHDmn16tWKj4/3vPbuXPbt26cXXnihyXpUVJQmTJjg+To6OlqPP/64Dh48qIEDByovL09lZWV6+umn1blzZ0nSnXfeqaeeekq33XabSktLFRsbq/z8fG3fvl2rVq1St27dJEnjx4/X9OnT9fvf/15///vfdf3118vtduv//u//NH78+Db9+78nT55Unz59dPPNNysuLk5du3bV22+/rb/85S9auXJlq38OgADn3zchA7hQ/OvHwJzLdz8Gpry83CxevNiMGjXKREZGmk6dOpmePXuaSZMmmXfeecfr2HN9DIwk8+6777Zq1jFjxhhJ5vbbb2/yvfz8fHPdddeZyMhI43A4TN++fc2sWbPMF1988b0/91yzjRs3zrNv3LhxZsiQIebDDz80SUlJJjQ01PTr18+sWbOmyc+srKw0M2bMMBEREcbhcJhhw4aZZ599tsm+hoYGs2LFCjN48GDjcDhMz549zQ033GBKS0u95mvu41369etnMjIyjDHffBzNvffea+Li4ky3bt3MpZdeauLi4szatWu/9/EDsEeQMW18fgEALHfttdeqqqpKH3/8sb9HAYDzwmsAAQAALEMAAgAAWIYABAAAsAyvAQQAALAMVwABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAM
AQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWCZgAfO+995Samqro6GgFBQVp8+bN33vM1q1b9dOf/lROp1M/+tGP9Nxzz3X4nAAAAP4WMAFYW1uruLg45eTktGr/gQMHNGnSJI0fP15lZWW6++67dfvtt+utt97q4EkBAAD8K8gYY/w9RHsLCgrSa6+9psmTJ7e457777tOWLVv08ccfe9ZuueUWnThxQoWFhT6YEgAAwD8C5gpgW5WUlCg5OdlrLSUlRSUlJX6aCAAAwDc6+XsAf6moqFBUVJTXWlRUlGpqavT111+rS5cuTY6pq6tTXV2d52u3263jx4/r8ssvV1BQUIfPDAAAfjhjjE6ePKno6GgFB9t5LczaADwf2dnZevjhh/09BgAAaAdHjhxRnz59/D2GX1gbgL169VJlZaXXWmVlpcLCwpq9+idJixYtUmZmpufr6upq9e3bV0eOHFFYWFiHzgsAANpHTU2NXC6XunXr5u9R/MbaAExKSlJBQYHXWlFRkZKSklo8xul0yul0NlkPCwsjAAEAuMjY/PKtgHni+9SpUyorK1NZWZmkbz7mpaysTIcPH5b0zdW79PR0z/7Zs2ervLxcv/nNb7Rnzx6tXbtWL7/8shYsWOCP8QEAAHwmYALwww8/1FVXXaWrrrpKkpSZmamrrrpKixcvliR98cUXnhiUpCuuuEJbtmxRUVGR4uLitHLlSv3hD39QSkqKX+YHAADwlYD8HEBfqampUXh4uKqrq3kKGACAiwS/vwPoCiAAAABahwAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYJqADMyclRbGysQkNDlZiYqB07dpxz/6pVqzRo0CB16dJFLpdLCxYs0JkzZ3w0LQAAgH8ETADm5eUpMzNTWVlZ2rlzp+Li4pSSkqKjR482u3/jxo1auHChsrKytHv3bj3zzDPKy8vT/fff7+PJAQAAfCtgAvDJJ5/UHXfcoRkzZugnP/mJcnNzdckll2jDhg3N7n///fc1ZswYTZs2TbGxsbruuus0derU771qCAAAcLELiACsr69XaWmpkpOTPWvBwcFKTk5WSUlJs8eMHj1apaWlnuArLy9XQUGBJk6c6JOZAQAA/KWTvwdoD1VVVWpsbFRUVJTXelRUlPbs2dPsMdOmTVNVVZWuueYaGWPU0NCg2bNnn/Mp4Lq6OtXV1Xm+rqmpaZ8HAAAA4EMBcQXwfGzdulVLly7V2rVrtXPnTr366qvasmWLHnnkkRaPyc7OVnh4uOfmcrl8ODEAAED7CDLGGH8P8UPV19frkksuUX5+viZPnuxZz8jI0IkTJ/Q///M/TY4ZO3asRo0apRUrVnjWXnjhBd155506deqUgoObtnFzVwBdLpeqq6sVFhbWvg8KAAB0iJqaGoWHh1v9+zsgrgA6HA4lJCSouLjYs+Z2u1VcXKykpKRmjzl9+nSTyAsJCZEktdTETqdTYWFhXjcAAICLTUC8BlCSMjMzlZGRoREjRmjkyJFatWqVamtr
NWPGDElSenq6YmJilJ2dLUlKTU3Vk08+qauuukqJiYnav3+/HnroIaWmpnpCEAAAIBAFTACmpaXp2LFjWrx4sSoqKhQfH6/CwkLPG0MOHz7sdcXvwQcfVFBQkB588EF99tln6tmzp1JTU/XYY4/56yEAAAD4REC8BtBfeA0BAAAXH35/B8hrAAEAANB6BCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsExABWBOTo5iY2MVGhqqxMRE7dix45z7T5w4oblz56p3795yOp0aOHCgCgoKfDQtAACAf3Ty9wDtJS8vT5mZmcrNzVViYqJWrVqllJQU7d27V5GRkU3219fXa8KECYqMjFR+fr5iYmJ06NAhde/e3ffDAwAA+FCQMcb4e4j2kJiYqKuvvlpr1qyRJLndbrlcLs2bN08LFy5ssj83N1crVqzQnj171Llz5/O6z5qaGoWHh6u6ulphYWE/aH4AAOAb/P4OkKeA6+vrVVpaquTkZM9acHCwkpOTVVJS0uwxr7/+upKSkjR37lxFRUVp6NChWrp0qRobG1u8n7q6OtXU1HjdAAAALjYBEYBVVVVqbGxUVFSU13pUVJQqKiqaPaa8vFz5+flqbGxUQUGBHnroIa1cuVKPPvpoi/eTnZ2t8PBwz83lcrXr4wAAAPCFgAjA8+F2uxUZGamnn35aCQkJSktL0wMPPKDc3NwWj1m0aJGqq6s9tyNHjvhwYgAAgPYREG8CiYiIUEhIiCorK73WKysr1atXr2aP6d27tzp37qyQkBDP2pVXXqmKigrV19fL4XA0OcbpdMrpdLbv8AAAAD4WEFcAHQ6HEhISVFxc7Flzu90qLi5WUlJSs8eMGTNG+/fvl9vt9qzt27dPvXv3bjb+AAAAAkVABKAkZWZmav369Xr++ee1e/du3XXXXaqtrdWMGTMkSenp6Vq0aJFn/1133aXjx49r/vz52rdvn7Zs2aKlS5dq7ty5/noIAAAAPhEQTwFLUlpamo4dO6bFixeroqJC8fHxKiws9Lwx5PDhwwoO/mfvulwuvfXWW1qwYIGGDx+umJgYzZ8/X/fdd5+/HgIAAIBPBMznAPoDnyMEAMDFh9/fAfQUMAAAAFqHAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAA
AMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAlgmoAMzJyVFsbKxCQ0OVmJioHTt2tOq4TZs2KSgoSJMnT+7YAQEAAC4AAROAeXl5yszMVFZWlnbu3Km4uDilpKTo6NGj5zzu4MGDuueeezR27FgfTQoAAOBfAROATz75pO644w7NmDFDP/nJT5Sbm6tLLrlEGzZsaPGYxsZG3XrrrXr44YfVv39/H04LAADgPwERgPX19SotLVVycrJnLTg4WMnJySopKWnxuN/+9reKjIzUzJkzW3U/dXV1qqmp8boBAABcbAIiAKuqqtTY2KioqCiv9aioKFVUVDR7zLZt2/TMM89o/fr1rb6f7OxshYeHe24ul+sHzQ0AAOAPARGAbXXy5ElNnz5d69evV0RERKuPW7Rokaqrqz23I0eOdOCUAAAAHaOTvwdoDxEREQoJCVFlZaXXemVlpXr16tVk/6effqqDBw8qNTXVs+Z2uyVJnTp10t69ezVgwIAmxzmdTjmdznaeHgAAwLcC4gqgw+FQQkKCiouLPWtut1vFxcVKSkpqsn/w4MH66KOPVFZW5rndeOONGj9+vMrKynhqFwAABLSAuAIoSZmZmcrIyNCIESM0cuRIrVq1SrW1tZoxY4YkKT09XTExMcrOzlZoaKiGDh3qdXz37t0lqck6AABAoAmYAExLS9OxY8e0ePFiVVRUKD4+XoWFhZ43hhw+fFjBwQFxwRMAAOAHCTLGGH8PcbGqqalReHi4qqurFRYW5u9xAABAK/D7O0BeAwgAAIDWIwABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUCKgBzcnIUGxur0NBQJSYmaseOHS3uXb9+vcaOHasePXqoR48eSk5OPud+AACAQBEwAZiXl6fMzExlZWVp586diouLU0pKio4ePdrs/q1bt2rq1Kl69913VVJSIpfLpeuuu06fffaZjycHAADwrSBjjPH3EO0hMTFRV199tdasWSNJcrvdcrlcmjdvnhYuXPi9xzc2NqpHjx5as2aN0tPTW3WfNTU1Cg8PV3V1tcLCwn7Q/AAAwDf4/R0gVwDr6+tVWlqq5ORkz1pwcLCSk5NVUlLSqp9x+vRpnT17VpdddlmLe+rq6lRTU+N1AwAAuNgERABWVVWpsbFRUVFRXutRUVGqqKho1c+47777FB0d7RWR35Wdna3w8HDPzeVy/aC5AQAA/CEgAvCHWrZsmTZt2qTXXntNoaGhLe5btGiRqqurPbcjR474cEoAAID20cnfA7SHiIgIhYSEqLKy0mu9srJSvXr1OuexTzzxhJYtW6a3335bw4cPP+dep9Mpp9P5g+cFAADwp4C4AuhwOJSQkKDi4mLPmtvtVnFxsZKSklo8bvny5XrkkUdUWFioESNG+GJUAAAAvwuIK4CSlJmZqYyMDI0YMUIjR47UqlWrVFtbqxkzZkiS0tPTFRMTo+zsbEnS448/
rsWLF2vjxo2KjY31vFawa9eu6tq1q98eBwAAQEcLmABMS0vTsWPHtHjxYlVUVCg+Pl6FhYWeN4YcPnxYwcH/vOC5bt061dfX6+abb/b6OVlZWVqyZIkvRwcAAPCpgPkcQH/gc4QAALj48Ps7QF4DCAAAgNYjAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYABAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQIqAHNychQbG6vQ0FAlJiZqx44d59z/yiuvaPDgwQoNDdWwYcNUUFDgo0kBAAD8J2ACMC8vT5mZmcrKytLOnTsVFxenlJQUHT16tNn977//vqZOnaqZM2dq165dmjx5siZPnqyPP/7Yx5MDAAD4VpAxxvh7iPaQmJioq6++WmvWrJEkud1uuVwuzZs3TwsXLmyyPy0tTbW1tXrjjTc8a6NGjVJ8fLxyc3NbdZ81NTUKDw9XdXW1wsLC2ueBAACADsXvb6mTvwdoD/X19SotLdWiRYs8a8HBwUpOTlZJSUmzx5SUlCgzM9NrLSUlRZs3b27xfurq6lRXV+f5urq6WtI3/0MCAAAXh29/bwfINbDzEhABWFVVpcbGRkVFRXmtR0VFac+ePc0eU1FR0ez+ioqKFu8nOztbDz/8cJN1l8t1HlMDAAB/+vLLLxUeHu7vMfwiIALQVxYtWuR11fDEiRPq16+fDh8+bO3/gC4UNTU1crlcOnLkiLWX8y8UnIsLC+fjwsG5uHBUV1erb9++uuyyy/w9it8ERABGREQoJCRElZWVXuuVlZXq1atXs8f06tWrTfslyel0yul0NlkPDw/nL/MFIiwsjHNxgeBcXFg4HxcOzsWFIzg4YN4L22YB8cgdDocSEhJUXFzsWXO73SouLlZSUlKzxyQlJXntl6SioqIW9wMAAASKgLgCKEmZmZnKyMjQiBEjNHLkSK1atUq1tbWaMWOGJCk9PV0xMTHKzs6WJM2fP1/jxo3TypUrNWnSJG3atEkffvihnn76aX8+DAAAgA4XMAGYlpamY8eOafHixaqoqFB8fLwKCws9b/Q4fPiw16Xe0aNHa+PGjXrwwQd1//3368c//rE2b96soUOHtvo+nU6nsrKymn1aGL7FubhwcC4uLJyPCwfn4sLBuQigzwEEAABA6wTEawABAADQegQgAACAZQhAAAAAyxCAAAAAliEAv0dOTo5iY2MVGhqqxMRE7dix45z7X3nlFQ0ePFihoaEaNmyYCgoKfDRp4GvLuVi/fr3Gjh2rHj16qEePHkpOTv7ec4fWa+vfi29t2rRJQUFBmjx5cscOaJG2nosTJ05o7ty56t27t5xOpwYOHMh/p9pRW8/HqlWrNGjQIHXp0kUul0sLFizQmTNnfDRt4HrvvfeUmpqq6OhoBQUFafPmzd97zNatW/XTn/5UTqdTP/rRj/Tcc891+Jx+ZdCiTZs2GYfDYTZs2GD+9re/mTvuuMN0797dVFZWNrt/+/btJiQkxCxfvtx88skn5sEHHzSdO3c2H330kY8nDzxtPRfTpk0zOTk5ZteuXWb37t3mtttuM+Hh
4eYf//iHjycPPG09F986cOCAiYmJMWPHjjW/+MUvfDNsgGvruairqzMjRowwEydONNu2bTMHDhwwW7duNWVlZT6ePDC19Xy8+OKLxul0mhdffNEcOHDAvPXWW6Z3795mwYIFPp488BQUFJgHHnjAvPrqq0aSee211865v7y83FxyySUmMzPTfPLJJ2b16tUmJCTEFBYW+mZgPyAAz2HkyJFm7ty5nq8bGxtNdHS0yc7Obnb/lClTzKRJk7zWEhMTzaxZszp0Thu09Vx8V0NDg+nWrZt5/vnnO2pEa5zPuWhoaDCjR482f/jDH0xGRgYB2E7aei7WrVtn+vfvb+rr6301olXaej7mzp1rfv7zn3utZWZmmjFjxnTonLZpTQD+5je/MUOGDPFaS0tLMykpKR04mX/xFHAL6uvrVVpaquTkZM9acHCwkpOTVVJS0uwxJSUlXvslKSUlpcX9aJ3zORffdfr0aZ09e9bqf/i7PZzvufjtb3+ryMhIzZw50xdjWuF8zsXrr7+upKQkzZ07V1FRURo6dKiWLl2qxsZGX40dsM7nfIwePVqlpaWep4nLy8tVUFCgiRMn+mRm/JONv78D5l8CaW9VVVVqbGz0/Esi34qKitKePXuaPaaioqLZ/RUVFR02pw3O51x813333afo6Ogmf8HRNudzLrZt26ZnnnlGZWVlPpjQHudzLsrLy/XOO+/o1ltvVUFBgfbv3685c+bo7NmzysrK8sXYAet8zse0adNUVVWla665RsYYNTQ0aPbs2br//vt9MTL+RUu/v2tqavT111+rS5cufpqs43AFEAFv2bJl2rRpk1577TWFhob6exyrnDx5UtOnT9f69esVERHh73Gs53a7FRkZqaeffloJCQlKS0vTAw88oNzcXH+PZqWtW7dq6dKlWrt2rXbu3KlXX31VW7Zs0SOPPOLv0WABrgC2ICIiQiEhIaqsrPRar6ysVK9evZo9plevXm3aj9Y5n3PxrSeeeELLli3T22+/reHDh3fkmFZo67n49NNPdfDgQaWmpnrW3G63JKlTp07au3evBgwY0LFDB6jz+XvRu3dvde7cWSEhIZ61K6+8UhUVFaqvr5fD4ejQmQPZ+ZyPhx56SNOnT9ftt98uSRo2bJhqa2t155136oEHHvD69+vRsVr6/R0WFhaQV/8krgC2yOFwKCEhQcXFxZ41t9ut4uJiJSUlNXtMUlKS135JKioqanE/Wud8zoUkLV++XI888ogKCws1YsQIX4wa8Np6LgYPHqyPPvpIZWVlntuNN96o8ePHq6ysTC6Xy5fjB5Tz+XsxZswY7d+/3xPhkrRv3z717t2b+PuBzud8nD59uknkfRvnxpiOGxZNWPn729/vQrmQbdq0yTidTvPcc8+ZTz75xNx5552me/fupqKiwhhjzPTp083ChQs9+7dv3246depknnjiCbN7926TlZXFx8C0k7aei2XLlhmHw2Hy8/PNF1984bmdPHnSXw8hYLT1XHwX7wJuP209F4cPHzbdunUzv/rVr8zevXvNG2+8YSIjI82jjz7qr4cQUNp6PrKysky3bt3MSy+9ZMrLy83//u//mgEDBpgpU6b46yEEjJMnT5pdu3aZXbt2GUnmySefNLt27TKHDh0yxhizcOFCM336dM/+bz8G5t577zW7d+82OTk5fAyM7VavXm369u1rHA6HGTlypPnggw883xs3bpzJyMjw2v/yyy+bgQMHGofDYYYMGWK2bNni44kDV1vORb9+/YykJresrCzfDx6A2vr34l8RgO2rrefi/fffN4mJicbpdJr+/fubxx57zDQ0NPh46sDVlvNx9uxZs2TJEjNgwAATGhpqXC6XmTNnjvnqq698P3iAeffdd5v9HfDtn39GRoYZN25ck2Pi4+ONw+Ew/fv3N88++6zP5/alIGO4zgwAAGATXgMIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWIYA
BAAAsAwBCAAAYBkCEAAAwDIEIAAAgGUIQAAAAMsQgAAAAJYhAAEAACxDAAIAAFiGAAQAALAMAQgAAGAZAhAAAMAyBCAAAIBlCEAAAADLEIAAAACWIQABAAAsQwACAABYhgAEAACwDAEIAABgGQIQAADAMgQgAACAZQhAAAAAyxCAAAAAliEAAQAALEMAAgAAWOb/ATyhabrGscARAAAAAElFTkSuQmCC",
      "text/html": [
       "\n",
       "            <div style=\"display: inline-block;\">\n",
       "                <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
       "                    Figure\n",
       "                </div>\n",
       "                <img src='' width=640.0/>\n",
       "            </div>\n",
       "        "
      ],
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# This has to be in a separate cell for some weird event loop reasons\n",
    "%matplotlib widget\n",
    "fig,ax = plt.subplots()\n",
    "fig.suptitle('MSE vs Epochs')\n",
    "plt.show()\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 486
    },
    "id": "pvTfzGmCeXU4",
    "outputId": "0d3a20f3-23be-4c19-9eb6-46bfe11a48b1"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                                                                                                                              | 0/1024 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "train_err = []\n",
    "test_err = []\n",
    "\n",
    "# clear loss file\n",
    "open('loss', 'w').close()\n",
    "\n",
    "for epoch in range(NEPOCHS):\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    for i in tqdm(range(BPE)):\n",
    "        batch_src, batch_labels, batch_padding_mask = mkbatch(BSZ)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "        train_loss += loss.item() / BPE\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    test_loss = evaluate()\n",
    "    \n",
    "    test_err.append(test_loss)\n",
    "    train_err.append(train_loss)\n",
    "    with open('loss', 'a') as f:\n",
    "        f.write(f\"{train_loss} {test_loss}\\n\")\n",
    "    print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
    "    ax.plot(train_err, label='Train', color='blue')\n",
    "    ax.plot(test_err, label='Test', color='red')\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel('MSE')\n",
    "    fig.canvas.draw()\n",
    "\n",
    "    if epoch % 10 == 9:\n",
    "        torch.save(model.state_dict(), f\"model_weights_{epoch}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "177.23538208007812"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LoGEmM5lH7_A"
   },
   "outputs": [],
   "source": [
    "batch_src, batch_labels, batch_padding_mask = next(iter(train_loader))\n",
    "output = model(batch_src, batch_padding_mask)\n",
    "batch_src[0], batch_labels[0], output[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(output.detach().cpu().numpy().flatten(),bins=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(label.detach().cpu().numpy().flatten(),bins=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(batch_labels.detach().cpu().numpy().flatten(),output.detach().cpu().numpy().flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_src2, batch_labels2, batch_padding_mask2 = next(iter(test_loader))\n",
    "output2 = model(batch_src2, batch_padding_mask2)\n",
    "loss = criterion(output2.squeeze(1), batch_labels2)\n",
    "batch_src2[0], batch_labels2[0], output2[0], loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(batch_labels2.detach().cpu().numpy().flatten(),output2.detach().cpu().numpy().flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LC6Xv3YfC0Rm"
   },
   "source": [
    "# Step 5: Fine Tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_TUNE_EPOCHS = 100\n",
    "TUNE_LR = 0.003\n",
    "TUNE_WD = 0.002\n",
    "\n",
    "tune_criterion = nn.MSELoss()\n",
    "tune_optimizer = torch.optim.Adam(model.parameters(), lr=TUNE_LR, weight_decay=TUNE_WD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tuneevaluate():\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    with torch.no_grad():\n",
    "        batch_src, batch_labels, batch_padding_mask = mktunebatch(BSZ)\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This has to be in a separate cell for some weird event loop reasons\n",
    "%matplotlib widget\n",
    "fig,ax = plt.subplots()\n",
    "fig.suptitle('MSE vs Epochs')\n",
    "plt.show()\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'N_TUNE_EPOCHS' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 4\u001b[0m\n\u001b[1;32m      1\u001b[0m tune_train_err \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m      2\u001b[0m tune_test_err \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[43mN_TUNE_EPOCHS\u001b[49m):\n\u001b[1;32m      5\u001b[0m     model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m      6\u001b[0m     train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'N_TUNE_EPOCHS' is not defined"
     ]
    }
   ],
   "source": [
    "tune_train_err = []\n",
    "tune_test_err = []\n",
    "\n",
    "# clear loss file\n",
    "open('tune_loss', 'w').close()\n",
    "\n",
    "for epoch in range(N_TUNE_EPOCHS):\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    for i in tqdm(range(BPE)):\n",
    "        batch_src, batch_labels, batch_padding_mask = mktunebatch(BSZ)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(batch_src, batch_padding_mask)\n",
    "        loss = criterion(output.squeeze(1), batch_labels)\n",
    "        train_loss += loss.item()/BPE\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    test_loss = tuneevaluate()\n",
    "    \n",
    "    tune_test_err.append(test_loss)\n",
    "    tune_train_err.append(train_loss)\n",
    "    with open('tune_loss', 'a') as f:\n",
    "        f.write(f\"{train_loss} {test_loss}\\n\")\n",
    "    ax.plot(tune_train_err, label='Train', color='blue')\n",
    "    ax.plot(tune_test_err, label='Test', color='red')\n",
    "    ax.set_xlabel('Epochs')\n",
    "    ax.set_ylabel('MSE')\n",
    "    fig.canvas.draw()\n",
    "    print(f\"Epoch {epoch + 1}/{NEPOCHS} \\t Train Err: {train_loss:.4f} \\t Test Err: {test_loss:.4f}\")\n",
    "\n",
    "    if epoch % 10 == 9:\n",
    "        torch.save(model.state_dict(), f\"tune_model_weights_{epoch}.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtTLXn4zC1z_"
   },
   "source": [
    "# Step 6: Test generalization"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}