1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
|
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "LPphBnKR-aWF"
},
"source": [
"# Step 0: Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ge5QvElvhCOw",
"outputId": "c7cdaefa-d6dc-44ad-c258-e4fb2aca97a5"
},
"outputs": [],
"source": [
"# using tqdm.auto glitches out collaborative editing\n",
"from tqdm import tqdm\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt, mpld3\n",
"\n",
"from math import sqrt\n",
"from collections import deque\n",
"import os\n",
"import random\n",
"from concurrent.futures import ProcessPoolExecutor\n",
"import pickle\n",
"\n",
"# Seed Python's `random` (used by the graph generator below) and torch for reproducibility.\n",
"random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"# The assert makes a GPU mandatory, so the CPU branch never survives past this cell.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_state": "idle",
"metadata": {
"id": "lylOX2POPwFL"
},
"outputs": [],
"source": [
"# VTXS numbers here are inclusive\n",
"MIN_VTXS = 3\n",
"MAX_VTXS = 15\n",
"MAX_TUNE_VTXS = 7\n",
"# random_graph draws AVG_DEG * n endpoints, i.e. about AVG_DEG * n / 2 edges (average degree AVG_DEG)\n",
"AVG_DEG = 2\n",
"# Each example is 2*SEQ_LEN-1 tokens: room for SEQ_LEN-1 edges (two tokens each),\n",
"# then one final token holding the target vertex (31 tokens / 15 edges when MAX_VTXS = 15).\n",
"SEQ_LEN = MAX_VTXS + 1\n",
"PAD_TOKEN = 0\n",
"EPOCH_SIZE = 2**20\n",
"# vertices are labelled 1,2,...,MAX_VTXS\n",
"# we also have a padding token which is 0."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gKt-yIpDebF1"
},
"source": [
"# Step 1: Generate synthetic data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# original task data\n",
"# the data will be edge lists\n",
"# like this: [1 3 1 5 2 4 0 0 0 0 2]\n",
"# this represents edges (1,3), (1,5), (2,4)\n",
"# (the zeros are just padding tokens)\n",
"# the final token (2 here) is the target vertex whose distance from vtx 1 we want\n",
"\n",
"# the label is the shortest distance from vtx 1 to the target vertex\n",
"# or MAX_VTXS if no path exists\n",
"\n",
"def random_graph(n):\n",
"    \"\"\"Sample a random undirected graph on vertices 1..n.\n",
"\n",
"    Draws AVG_DEG * n endpoints uniformly from 1..n and pairs them up, giving\n",
"    AVG_DEG * n / 2 candidate edges (average degree ~AVG_DEG). Self-loops are\n",
"    dropped; an edge sampled twice appears twice in the token list but only\n",
"    once in the adjacency sets.\n",
"\n",
"    Returns (edge_list, adjacencies):\n",
"      edge_list   -- flat tokens [u1, v1, u2, v2, ...] with u < v in each pair,\n",
"                     padded with PAD_TOKEN to length 2*SEQ_LEN-1; the last slot\n",
"                     is left for the caller to overwrite with the target vertex.\n",
"      adjacencies -- list of n+1 sets; adjacencies[v] holds the neighbours of v\n",
"                     (index 0 is unused).\n",
"\n",
"    Raises ValueError if the sampled edges do not fit before the target slot.\n",
"    \"\"\"\n",
"    capacity = 2 * SEQ_LEN - 2  # tokens available for edges; the final slot holds the target\n",
"    edge_list = []\n",
"    adjacencies = [set() for _ in range(n + 1)]\n",
"    indices = [random.randint(1, n) for _ in range(AVG_DEG * n)]\n",
"    # Stop one short so an odd number of endpoints cannot index past the end.\n",
"    for i in range(0, len(indices) - 1, 2):\n",
"        u = indices[i]\n",
"        v = indices[i + 1]\n",
"        if u != v:\n",
"            edge_list += [u, v] if u < v else [v, u]\n",
"            adjacencies[u].add(v)\n",
"            adjacencies[v].add(u)\n",
"    if len(edge_list) > capacity:\n",
"        # Otherwise no padding is added and the caller's target token would overwrite an edge.\n",
"        raise ValueError(f\"random_graph({n}): {len(edge_list) // 2} edges exceed the {capacity // 2} that fit with SEQ_LEN={SEQ_LEN}\")\n",
"    edge_list += [PAD_TOKEN] * (2 * SEQ_LEN - 1 - len(edge_list))\n",
"    return edge_list, adjacencies\n",
"\n",
"def SSSP(G, target=2, distonly=True):\n",
"    \"\"\"Breadth-first search from vertex 1 over the adjacency structure G.\n",
"\n",
"    G[v] is an iterable of the neighbours of vertex v (as built by random_graph).\n",
"    If distonly, return the hop distance from vertex 1 to `target`, or MAX_VTXS\n",
"    if `target` is unreachable. Otherwise return a shortest path as the vertex\n",
"    list [1, ..., target], or [] if `target` is unreachable.\n",
"    \"\"\"\n",
"    if target == 1:\n",
"        # The search below only checks newly discovered vertices, so the source\n",
"        # itself would otherwise be reported as unreachable.\n",
"        return 0 if distonly else [1]\n",
"    dist = [MAX_VTXS] * len(G)  # MAX_VTXS doubles as the \"not yet visited\" marker\n",
"    parent = [-1] * len(G)\n",
"    dist[1] = 0\n",
"    frontier = deque([1])\n",
"    while frontier:\n",
"        vtx = frontier.popleft()\n",
"        for x in G[vtx]:\n",
"            if dist[x] == MAX_VTXS:\n",
"                parent[x] = vtx\n",
"                dist[x] = 1 + dist[vtx]\n",
"                frontier.append(x)\n",
"                if x == target:\n",
"                    if distonly:\n",
"                        return dist[x]\n",
"                    # Walk parent pointers back to vertex 1 (whose parent is -1).\n",
"                    path = [x]\n",
"                    while parent[x] != -1:\n",
"                        x = parent[x]\n",
"                        path.append(x)\n",
"                    return list(reversed(path))\n",
"    return MAX_VTXS if distonly else []\n",
"\n",
"def mkbatch(size, large=True, target=2, largetarget=True):\n",
"    \"\"\"Generate a batch of `size` random graph examples.\n",
"\n",
"    Args:\n",
"      size: number of examples.\n",
"      large: if True, graphs have MIN_VTXS..MAX_VTXS vertices, otherwise\n",
"        MIN_VTXS..MAX_TUNE_VTXS.\n",
"      target: how the target vertex (the final token) is chosen:\n",
"        2        -- always vertex 2; graphs are rejection-sampled to thin out\n",
"                    short distances (see the acceptance rule below).\n",
"        \"onpath\" -- a random vertex other than 1 on a shortest 1->2 path;\n",
"                    graphs where 2 is unreachable are resampled.\n",
"        \"any\"    -- a uniformly random label in 2..(MAX_VTXS or MAX_TUNE_VTXS);\n",
"                    it may exceed n, in which case it is unreachable.\n",
"      largetarget: whether the target may go up to MAX_VTXS or only up to\n",
"        MAX_TUNE_VTXS (applies to \"onpath\" and \"any\").\n",
"\n",
"    Returns (data, labels, padding):\n",
"      data    -- int64 tensor of shape (size, 2*SEQ_LEN-1): edge tokens, then the target.\n",
"      labels  -- float32 tensor of shape (size,): distance from vertex 1 to the\n",
"                 target, MAX_VTXS if unreachable.\n",
"      padding -- bool tensor, True where data == PAD_TOKEN.\n",
"\n",
"    Raises:\n",
"      ValueError: if `target` is not 2, \"onpath\" or \"any\".\n",
"    \"\"\"\n",
"    if target not in (2, \"onpath\", \"any\"):\n",
"        # An unknown mode would never break out of the retry loop below.\n",
"        raise ValueError(f\"unknown target mode: {target!r}\")\n",
"    max_n = MAX_VTXS if large else MAX_TUNE_VTXS\n",
"    max_target = MAX_VTXS if largetarget else MAX_TUNE_VTXS\n",
"    graphs = []\n",
"    dists = []\n",
"\n",
"    for _ in range(size):\n",
"        n = random.randint(MIN_VTXS, max_n)\n",
"        while True:\n",
"            edge_list, adj_list = random_graph(n)\n",
"            if target == 2:\n",
"                # Train data\n",
"                dist = SSSP(adj_list)\n",
"                # Filter out some short-distance graphs: always keep\n",
"                # 6 <= dist < MAX_VTXS, keep dist < 6 with probability\n",
"                # 1 / (2 * (6 - dist)), and keep unreachable graphs\n",
"                # (dist == MAX_VTXS) with probability 1 / MAX_VTXS.\n",
"                if (6 <= dist < MAX_VTXS\n",
"                        or (dist < 6 and 2 * random.random() < 1 / (6 - dist))\n",
"                        or (dist == MAX_VTXS and random.random() < 1 / MAX_VTXS)):\n",
"                    edge_list[-1] = 2\n",
"                    break\n",
"            elif target == \"onpath\":\n",
"                path = SSSP(adj_list, 2, False)\n",
"                if len(path) < 1:\n",
"                    continue  # vertex 2 unreachable: resample the graph\n",
"                # path[0] is vertex 1, so path[k] lies at distance k from it.\n",
"                dist = random.randrange(1, len(path))\n",
"                if largetarget or path[dist] <= MAX_TUNE_VTXS:\n",
"                    edge_list[-1] = path[dist]\n",
"                    break\n",
"            else:  # target == \"any\"\n",
"                edge_list[-1] = random.randint(2, max_target)\n",
"                dist = SSSP(adj_list, edge_list[-1])\n",
"                break\n",
"        graphs.append(edge_list)\n",
"        dists.append(dist)\n",
"\n",
"    # The explicit reshape keeps data 2-D even when size == 0.\n",
"    data = torch.tensor(graphs, dtype=torch.long).reshape(len(graphs), 2 * SEQ_LEN - 1)\n",
"    labels = torch.tensor(dists, dtype=torch.float32)\n",
"    padding = data == PAD_TOKEN\n",
"    return data, labels, padding\n",
"\n",
"def savebatch(size, idx, outdir='data-new'):\n",
"    \"\"\"Generate mkbatch(size) with default settings and pickle it.\n",
"\n",
"    Writes a dict with keys \"data\", \"labels\" and \"padding\" (the three tensors\n",
"    returned by mkbatch) to <outdir>/<idx>.pickle, creating `outdir` if missing.\n",
"    \"\"\"\n",
"    data, labels, padding = mkbatch(size)\n",
"    everything = {\n",
"        \"data\": data,\n",
"        \"labels\": labels,\n",
"        \"padding\": padding,\n",
"    }\n",
"\n",
"    # Create the directory up front: under ProcessPoolExecutor.submit a\n",
"    # FileNotFoundError would only surface through a future that may never be checked.\n",
"    os.makedirs(outdir, exist_ok=True)\n",
"    with open(os.path.join(outdir, f'{idx}.pickle'), 'wb') as file:\n",
"        pickle.dump(everything, file)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
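"# Only need to run this once to generate the training data.\n",
"# RESTART THE KERNEL BEFORE RUNNING AND ONLY RUN THE CELLS ABOVE\n",
"# Generation is slow pure Python, so it is spread across worker processes.\n",
"# Note: exceptions raised inside savebatch are stored on the futures returned by\n",
"# executor.submit, which are discarded here, so worker failures are silent.\n",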
"\n",
"# with ProcessPoolExecutor() as executor:\n",
"# for i in range(64):\n",
"# executor.submit(savebatch, EPOCH_SIZE, i)\n",
"# executor.shutdown()"
]
},
{
"cell_type": "code",
"execution_count": 95,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[1, 3, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 2],\n",
" [1, 5, 3, 4, 1, 3, 4, 5, 3, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 2],\n",
" [2, 9, 1, 3, 4, 8, 8, 9, 3, 9, 5, 7, 6, 7, 1, 6, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 2]]),\n",
" tensor([ 1., 15., 3.]),\n",
" tensor([[False, False, False, False, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" False],\n",
" [False, False, False, False, False, False, False, False, False, False,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" False],\n",
" [False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" False]]))"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Peek at a tiny batch: (edge tokens + target, distance labels, padding mask)\n",
"mkbatch(size=3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqu0lEQVR4nO3de3BUZZ7/8U8upAmX7kg03WQJkFlcIYgXwIUe1F00Sw/TuroEZ3AQWUAtqEZNsgsxu5hx0DGIowgqZPBCmBpYhSpxJCnAGCCsQ7gYNg6gRFyjYSZ2x11NGhhJQtK/P36Ts7SCkpCkecL7VXWq7PN8zznfpyukP56cczoqFAqFBAAAYJDoSDcAAADQXgQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxYiPdQFdpbW1VbW2t+vfvr6ioqEi3AwAAzkMoFNLx48eVnJys6Ohzn2fpsQGmtrZWKSkpkW4DAAB0wLFjxzRo0KBzjvfYANO/f3/pL2+A3W6PdDsAAOA8BINBpaSkWJ/j59JjA0zbn43sdjsBBgAAw3zf5R9cxAsAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgnNhIN2CioY8UR+S4ny7xRuS4AABcbDgDAwAAjEOAAQAAxiHAAAAA4xBgAACAcdoVYIYOHaqoqKhvLT6fT5J06tQp+Xw+JSYmql+/fsrIyFAgEAjbR01Njbxer/r06aOkpCQtWLBAp0+fDqvZuXOnRo8eLZvNpmHDhqmwsLAz5goAAHqIdgWY/fv36/PPP7eWkpISSdJdd90lScrKytLmzZu1ceNGlZWVqba2VlOmTLG2b2lpkdfrVVNTk3bv3q21a9eqsLBQeXl5Vk11dbW8Xq8mTpyoyspKZWZm6r777tO2bds6b9YAAMBoUaFQKNTRjTMzM1VUVKSjR48qGAzqiiuu0Pr16zV16lRJ0pEjRzRixAiVl5dr/Pjx2rJli2677TbV1tbK6XRKkgoKCpSTk6MvvvhCcXFxysnJUXFxsQ4dOmQdZ9q0aaqvr9fWrVvPu7dgMCiHw6GGhgbZ7faOTvGsuI0aAICucb6f3x2+BqapqUm//e1vNXv2bEVFRamiokLNzc1KT0+3aoYPH67BgwervLxcklReXq5Ro0ZZ4UWSPB6PgsGgDh8+bNWcuY+2mrZ9nEtjY6OCwWDYAgAAeqYOB5g333xT9fX1+ud//mdJkt/vV1xcnBISEsLqnE6n/H6/VXNmeGkbbxv7rppgMKivv/76nP3k5+fL4XBYS0pKSkenBgAALnIdDjCvvPKKJk+erOTk5M7tqINyc3PV0NBgLceOHYt0SwAAoIt06KsEPvvsM73zzjt64403rHUul0tNTU2qr68POwsTCATkcrmsmn379oXtq+0upTNrvnnnUiAQkN1uV3x8/Dl7stlsstlsHZkOAAAwTIfOwKxZs0ZJSUnyev/votIxY8aoV69eKi0ttdZVVVWppqZGbrdbkuR2u3Xw4EHV1dVZNSUlJbLb7UpLS7NqztxHW03bPgAAANodYFpbW7VmzRrNnDlTsbH/dwLH4XBozpw5ys7O1o4dO1RRUaFZs2bJ7XZr/PjxkqRJkyYpLS1NM2bM0Pvvv69t27Zp0aJF8vl81tmTuXPn6pNPPtHChQt15MgRrVy5Uhs2bFBWVlZnzhsAABis3X9Ceuedd1RTU6PZs2d/a2zZsmWKjo5WRkaGGhsb5fF4tHLlSms8JiZGRUVFmjdvntxut/r27auZM2dq8eLFVk1qaqqKi4uVlZWl5cuXa9CgQXr55Zfl8XguZJ4AAKAHuaDnwFzMeA4MAADm6fLnwAAAAEQKAQYAABiHAAMAAIzToefAAACAi8Olel0mZ2AAAIBxCDAAAMA4BBgAAGAcAg
wAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDh8lYBBIvW4aF0Ej4wGAOBMnIEBAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYp90B5k9/+pPuueceJSYmKj4+XqNGjdJ7771njYdCIeXl5WngwIGKj49Xenq6jh49GraPL7/8UtOnT5fdbldCQoLmzJmjEydOhNX84Q9/0E033aTevXsrJSVFS5cuvZB5AgCAHqRdAearr77ShAkT1KtXL23ZskUffPCBnnnmGV122WVWzdKlS7VixQoVFBRo79696tu3rzwej06dOmXVTJ8+XYcPH1ZJSYmKioq0a9cuPfDAA9Z4MBjUpEmTNGTIEFVUVOjpp5/WY489ptWrV3fWvAEAgMFi21P81FNPKSUlRWvWrLHWpaamWv8dCoX03HPPadGiRbrjjjskSb/5zW/kdDr15ptvatq0afrwww+1detW7d+/X2PHjpUkPf/88/rxj3+sX/3qV0pOTta6devU1NSkV199VXFxcRo5cqQqKyv17LPPhgUdAABwaWrXGZi33npLY8eO1V133aWkpCRdf/31eumll6zx6upq+f1+paenW+scDofGjRun8vJySVJ5ebkSEhKs8CJJ6enpio6O1t69e62am2++WXFxcVaNx+NRVVWVvvrqq7P21tjYqGAwGLYAAICeqV0B5pNPPtGqVat05ZVXatu2bZo3b54eeughrV27VpLk9/slSU6nM2w7p9Npjfn9fiUlJYWNx8bGasCAAWE1Z9vHmcf4pvz8fDkcDmtJSUlpz9QAAIBB2hVgWltbNXr0aD355JO6/vrr9cADD+j+++9XQUFB13V4nnJzc9XQ0GAtx44di3RLAACgi7QrwAwcOFBpaWlh60aMGKGamhpJksvlkiQFAoGwmkAgYI25XC7V1dWFjZ8+fVpffvllWM3Z9nHmMb7JZrPJbreHLQAAoGdqV4CZMGGCqqqqwtZ99NFHGjJkiPSXC3pdLpdKS0ut8WAwqL1798rtdkuS3G636uvrVVFRYdVs375dra2tGjdunFWza9cuNTc3WzUlJSW66qqrwu54AgAAl6Z2BZisrCzt2bNHTz75pD7++GOtX79eq1evls/nkyRFRUUpMzNTTzzxhN566y0dPHhQ9957r5KTk3XnnXdKfzlj86Mf/Uj333+/9u3bp9///veaP3++pk2bpuTkZEnSz372M8XFxWnOnDk6fPiwXn/9dS1fvlzZ2dld8R4AAADDtOs26htuuEGbNm1Sbm6uFi9erNTUVD333HOaPn26VbNw4UKdPHlSDzzwgOrr63XjjTdq69at6t27t1Wzbt06zZ8/X7feequio6OVkZGhFStWWOMOh0Nvv/22fD6fxowZo8svv1x5eXncQg0AACRJUaFQKBTpJrpCMBiUw+FQQ0NDp18PM/SR4k7dnwk+XeKNdAsAgLOI1GdSV30unO/nN9+FBAAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGCcdgWYxx57TFFRUWHL8OHDrfFTp07J5/MpMTFR/fr1U0ZGhgKBQNg+ampq5PV61adPHyUlJWnBggU6ff
p0WM3OnTs1evRo2Ww2DRs2TIWFhRc6TwAA0IO0+wzMyJEj9fnnn1vLu+++a41lZWVp8+bN2rhxo8rKylRbW6spU6ZY4y0tLfJ6vWpqatLu3bu1du1aFRYWKi8vz6qprq6W1+vVxIkTVVlZqczMTN13333atm1bZ8wXAAD0ALHt3iA2Vi6X61vrGxoa9Morr2j9+vW65ZZbJElr1qzRiBEjtGfPHo0fP15vv/22PvjgA73zzjtyOp267rrr9PjjjysnJ0ePPfaY4uLiVFBQoNTUVD3zzDOSpBEjRujdd9/VsmXL5PF4OmPOAADAcO0+A3P06FElJyfrBz/4gaZPn66amhpJUkVFhZqbm5Wenm7VDh8+XIMHD1Z5ebkkqby8XKNGjZLT6bRqPB6PgsGgDh8+bNWcuY+2mrZ9nEtjY6OCwWDYAgAAeqZ2BZhx48apsLBQW7du1apVq1RdXa2bbrpJx48fl9/vV1xcnBISEsK2cTqd8vv9kiS/3x8WXtrG28a+qyYYDOrrr78+Z2/5+flyOBzWkpKS0p6pAQAAg7TrT0iTJ0+2/vuaa67RuHHjNGTIEG3YsEHx8fFd0d95y83NVXZ2tvU6GAwSYgAA6KEu6DbqhIQE/c3f/I0+/vhjuVwuNTU1qb6+PqwmEAhY18y4XK5v3ZXU9vr7aux2+3eGJJvNJrvdHrYAAICe6YICzIkTJ/Tf//3fGjhwoMaMGaNevXqptLTUGq+qqlJNTY3cbrckye126+DBg6qrq7NqSkpKZLfblZaWZtWcuY+2mrZ9AAAAtCvA/Ou//qvKysr06aefavfu3fqnf/onxcTE6O6775bD4dCcOXOUnZ2tHTt2qKKiQrNmzZLb7db48eMlSZMmTVJaWppmzJih999/X9u2bdOiRYvk8/lks9kkSXPnztUnn3yihQsX6siRI1q5cqU2bNigrKysrnkHAACAcdp1Dcwf//hH3X333frf//1fXXHFFbrxxhu1Z88eXXHFFZKkZcuWKTo6WhkZGWpsbJTH49HKlSut7WNiYlRUVKR58+bJ7Xarb9++mjlzphYvXmzVpKamqri4WFlZWVq+fLkGDRqkl19+mVuoAQCAJSoUCoUi3URXCAaDcjgcamho6PTrYYY+Utyp+zPBp0u8kW4BAHAWkfpM6qrPhfP9/Oa7kAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgnHZ9mSMuXT3tuzYAAGbjDAwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABjnggLMkiVLFBUVpczMTGvdqVOn5PP5lJiYqH79+ikjI0OBQCBsu5qaGnm9XvXp00dJSUlasGCBTp8+HVazc+dOjR49WjabTcOGDVNhYeGFtAoAAHqQDgeY/fv369e//rWuueaasPVZWVnavHmzNm7cqLKyMtXW1mrKlCnWeEtLi7xer5qamrR7926tXbtWhYWFysvLs2qqq6vl9Xo1ceJEVVZWKjMzU/fdd5+2bdvW0XYBAEAP0qEAc+LECU2fPl0vvfSSLrvsMmt9Q0ODXnnlFT377LO65ZZbNGbMGK1Zs0a7d+/Wnj17JElvv/22PvjgA/32t7/Vddddp8mTJ+vxxx/Xiy++qKamJklSQUGBUlNT9cwzz2jEiBGaP3++pk6dqmXLlnXWvAEAgME6FGB8Pp+8Xq/S09PD1ldUVKi5uTls/fDhwzV48GCVl5dLksrLyzVq1Cg5nU6rxuPxKBgM6vDhw1bNN/ft8XisfZxNY2OjgsFg2AIAAHqm2PZu8Nprr+nAgQ
Pav3//t8b8fr/i4uKUkJAQtt7pdMrv91s1Z4aXtvG2se+qCQaD+vrrrxUfH/+tY+fn5+sXv/hFe6cDAAAM1K4zMMeOHdPDDz+sdevWqXfv3l3XVQfk5uaqoaHBWo4dOxbplgAAQBdpV4CpqKhQXV2dRo8erdjYWMXGxqqsrEwrVqxQbGysnE6nmpqaVF9fH7ZdIBCQy+WSJLlcrm/dldT2+vtq7Hb7Wc++SJLNZpPdbg9bAABAz9SuAHPrrbfq4MGDqqystJaxY8dq+vTp1n/36tVLpaWl1jZVVVWqqamR2+2WJLndbh08eFB1dXVWTUlJiex2u9LS0qyaM/fRVtO2DwAAcGlr1zUw/fv319VXXx22rm/fvkpMTLTWz5kzR9nZ2RowYIDsdrsefPBBud1ujR8/XpI0adIkpaWlacaMGVq6dKn8fr8WLVokn88nm80mSZo7d65eeOEFLVy4ULNnz9b27du1YcMGFRcXd97MAQCAsdp9Ee/3WbZsmaKjo5WRkaHGxkZ5PB6tXLnSGo+JiVFRUZHmzZsnt9utvn37aubMmVq8eLFVk5qaquLiYmVlZWn58uUaNGiQXn75ZXk8ns5uFwAAGCgqFAqFIt1EVwgGg3I4HGpoaOj062GGPsKZoO7y6RJvpFsAgItapD6Tuur38/l+fvNdSAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGaVeAWbVqla655hrZ7XbZ7Xa53W5t2bLFGj916pR8Pp8SExPVr18/ZWRkKBAIhO2jpqZGXq9Xffr0UVJSkhYsWKDTp0+H1ezcuVOjR4+WzWbTsGHDVFhYeKHzBAAAPUi7AsygQYO0ZMkSVVRU6L333tMtt9yiO+64Q4cPH5YkZWVlafPmzdq4caPKyspUW1urKVOmWNu3tLTI6/WqqalJu3fv1tq1a1VYWKi8vDyrprq6Wl6vVxMnTlRlZaUyMzN13333adu2bZ05bwAAYLCoUCgUupAdDBgwQE8//bSmTp2qK664QuvXr9fUqVMlSUeOHNGIESNUXl6u8ePHa8uWLbrttttUW1srp9MpSSooKFBOTo6++OILxcXFKScnR8XFxTp06JB1jGnTpqm+vl5bt249776CwaAcDocaGhpkt9svZIrfMvSR4k7dH87t0yXeSLcAABe1SH0mddXv5/P9/I7t6AFaWlq0ceNGnTx5Um63WxUVFWpublZ6erpVM3z4cA0ePNgKMOXl5Ro1apQVXiTJ4/Fo3rx5Onz4sK6//nqVl5eH7aOtJjMz8zv7aWxsVGNjo/U6GAx2dGq4iPS0f5gAgM7R7ot4Dx48qH79+slms2nu3LnatGmT0tLS5Pf7FRcXp4SEhLB6p9Mpv98vSfL7/WHhpW28bey7aoLBoL7++utz9pWfny+Hw2EtKSkp7Z0aAAAwRLsDzFVXXaXKykrt3btX8+bN08yZM/XBBx90TXftkJubq4aGBms5duxYpFsCAABdpN1/QoqLi9OwYcMkSWPGjNH+/fu1fPly/fSnP1VTU5Pq6+vDzsIEAgG5XC5Jksvl0r59+8L213aX0pk137xzKRAIyG63Kz4+/px92Ww22Wy29k4HAAAY6IKfA9Pa2qrGxkaNGTNGvXr1UmlpqTVWVVWlmpoaud1uSZLb7dbBgwdVV1dn1ZSUlMhutystLc2qOXMfbTVt+wAAAGjXGZjc3FxNnjxZgwcP1vHjx7V+/Xrt3LlT27Ztk8Ph0Jw5c5
Sdna0BAwbIbrfrwQcflNvt1vjx4yVJkyZNUlpammbMmKGlS5fK7/dr0aJF8vl81tmTuXPn6oUXXtDChQs1e/Zsbd++XRs2bFBxMXf+AACA/69dAaaurk733nuvPv/8czkcDl1zzTXatm2b/uEf/kGStGzZMkVHRysjI0ONjY3yeDxauXKltX1MTIyKioo0b948ud1u9e3bVzNnztTixYutmtTUVBUXFysrK0vLly/XoEGD9PLLL8vj8XTmvAEAgMEu+DkwFyueA4MLwW3UAEzR0x43cb6f33wXEgAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA48RGugHgYjT0keKIHPfTJd6IHBcATMMZGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOO0KMPn5+brhhhvUv39/JSUl6c4771RVVVVYzalTp+Tz+ZSYmKh+/fopIyNDgUAgrKampkZer1d9+vRRUlKSFixYoNOnT4fV7Ny5U6NHj5bNZtOwYcNUWFh4IfMEAAA9SLsCTFlZmXw+n/bs2aOSkhI1Nzdr0qRJOnnypFWTlZWlzZs3a+PGjSorK1Ntba2mTJlijbe0tMjr9aqpqUm7d+/W2rVrVVhYqLy8PKumurpaXq9XEydOVGVlpTIzM3Xfffdp27ZtnTVvAABgsKhQKBTq6MZffPGFkpKSVFZWpptvvlkNDQ264oortH79ek2dOlWSdOTIEY0YMULl5eUaP368tmzZottuu021tbVyOp2SpIKCAuXk5OiLL75QXFyccnJyVFxcrEOHDlnHmjZtmurr67V169bz6i0YDMrhcKihoUF2u72jUzyrSD1mHj0fXyUAoL162lefnO/n9wVdA9PQ0CBJGjBggCSpoqJCzc3NSk9Pt2qGDx+uwYMHq7y8XJJUXl6uUaNGWeFFkjwej4LBoA4fPmzVnLmPtpq2fZxNY2OjgsFg2AIAAHqmDgeY1tZWZWZmasKECbr66qslSX6/X3FxcUpISAirdTqd8vv9Vs2Z4aVtvG3su2qCwaC+/vrrs/aTn58vh8NhLSkpKR2dGgAAuMh1OMD4fD4dOnRIr732Wud21EG5ublqaGiwlmPHjkW6JQAA0EViO7LR/PnzVVRUpF27dmnQoEHWepfLpaamJtXX14edhQkEAnK5XFbNvn37wvbXdpfSmTXfvHMpEAjIbrcrPj7+rD3ZbDbZbLaOTAcAABimXWdgQqGQ5s+fr02bNmn79u1KTU0NGx8zZox69eql0tJSa11VVZVqamrkdrslSW63WwcPHlRdXZ1VU1JSIrvdrrS0NKvmzH201bTtAwAAXNradQbG5/Np/fr1+t3vfqf+/ftb16w4HA7Fx8fL4XBozpw5ys7O1oABA2S32/Xggw/K7XZr/PjxkqRJkyYpLS1NM2bM0NKlS+X3+7Vo0SL5fD7rDMrcuXP1wgsvaOHChZo9e7a2b9+uDRs2qLiYu38AAEA7z8CsWrVKDQ0N+vu//3sNHDjQWl5//XWrZtmyZbrtttuUkZGhm2++WS6XS2+88YY1HhMTo6KiIsXExMjtduuee+7Rvffeq8WLF1s1qampKi4uVklJia699lo988wzevnll+XxeDpr3gAAwGAX9ByYixnPgYGJeA4MgPbiOTAAAACGIMAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOO068scAXStnvZIcADoKpyBAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDg
EGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcdodYHbt2qXbb79dycnJioqK0ptvvhk2HgqFlJeXp4EDByo+Pl7p6ek6evRoWM2XX36p6dOny263KyEhQXPmzNGJEyfCav7whz/opptuUu/evZWSkqKlS5d2dI4AAKCHaXeAOXnypK699lq9+OKLZx1funSpVqxYoYKCAu3du1d9+/aVx+PRqVOnrJrp06fr8OHDKikpUVFRkXbt2qUHHnjAGg8Gg5o0aZKGDBmiiooKPf3003rssce0evXqjs4TAAD0ILHt3WDy5MmaPHnyWcdCoZCee+45LVq0SHfccYck6Te/+Y2cTqfefPNNTZs2TR9++KG2bt2q/fv3a+zYsZKk559/Xj/+8Y/1q1/9SsnJyVq3bp2ampr06quvKi4uTiNHjlRlZaWeffbZsKADAAAuTZ16DUx1dbX8fr/S09OtdQ6HQ+PGjVN5ebkkqby8XAkJCVZ4kaT09HRFR0dr7969Vs3NN9+suLg4q8bj8aiqqkpfffXVWY/d2NioYDAYtgAAgJ6pUwOM3++XJDmdzrD1TqfTGvP7/UpKSgobj42N1YABA8JqzraPM4/xTfn5+XI4HNaSkpLSiTMDAAAXkx5zF1Jubq4aGhqs5dixY5FuCQAAdJFODTAul0uSFAgEwtYHAgFrzOVyqa6uLmz89OnT+vLLL8NqzraPM4/xTTabTXa7PWwBAAA9U6cGmNTUVLlcLpWWllrrgsGg9u7dK7fbLUlyu92qr69XRUWFVbN9+3a1trZq3LhxVs2uXbvU3Nxs1ZSUlOiqq67SZZdd1pktAwAAA7U7wJw4cUKVlZWqrKyU/nLhbmVlpWpqahQVFaXMzEw98cQTeuutt3Tw4EHde++9Sk5O1p133ilJGjFihH70ox/p/vvv1759+/T73/9e8+fP17Rp05ScnCxJ+tnPfqa4uDjNmTNHhw8f1uuvv67ly5crOzu7s+cPAAAM1O7bqN977z1NnDjRet0WKmbOnKnCwkItXLhQJ0+e1AMPPKD6+nrdeOON2rp1q3r37m1ts27dOs2fP1+33nqroqOjlZGRoRUrVljjDodDb7/9tnw+n8aMGaPLL79ceXl53EINAAAkSVGhUCgU6Sa6QjAYlMPhUENDQ6dfDzP0keJO3R8QaZ8u8Ua6BQAdFKnPpK76vXG+n9895i4kAABw6SDAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADAOAQYAABiHAAMAAIxDgAEAAMYhwAAAAOPERroBAJE39JHiiBz30yXeiBwXgPk4AwMAAIxDgAEAAMYhwAAAAOMQYAAAgHEIMAAAwDgEGAAAYBwCDAAAMA4BBgAAGIcAAwAAjEOAAQAAxiHAAAAA4xBgAACAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMExvpBgBcuoY+UhyR4366xBuR4wLoPJyBAQAAxrmoA8yLL76ooUOHqnfv3ho3bpz27dsX6ZYAAMBF4KINMK+//rqys7P185//XAcOHNC1114rj8ejurq6SLcGAAAi7KINMM8++6zuv/9+zZo1S2lpaSooKFCfPn306quvRro1AAAQYRflRbxNTU2qqKhQbm6utS46Olrp6ekqLy8/6zaNjY1qbGy0Xjc0NEiSgsFgp/fX2vjnTt8ngO4zOGtjRI576BeeiBwXPVukPpO64vP1zP2GQqHvrLsoA8z//M//qKWlRU6nM2y90+nUkSNHzrpNfn6+fvGLX3xrfUpKSpf1CQDt4Xgu0h0Anaerf56PHz8uh8Nxzv
GLMsB0RG5urrKzs63Xra2t+vLLL5WYmKioqKiI9tbZgsGgUlJSdOzYMdnt9ki30+2Y/6U9f/EeXPLzF+9Bj55/KBTS8ePHlZyc/J11F2WAufzyyxUTE6NAIBC2PhAIyOVynXUbm80mm80Wti4hIaFL+4w0u93e435w24P5X9rzF+/BJT9/8R702Pl/15mXNhflRbxxcXEaM2aMSktLrXWtra0qLS2V2+2OaG8AACDyLsozMJKUnZ2tmTNnauzYsfrbv/1bPffcczp58qRmzZoV6dYAAECEXbQB5qc//am++OIL5eXlye/367rrrtPWrVu/dWHvpchms+nnP//5t/5kdqlg/pf2/MV7cMnPX7wHl/z8JSkq9H33KQEAAFxkLsprYAAAAL4LAQYAABiHAAMAAIxDgAEAAMYhwBgiPz9fN9xwg/r376+kpCTdeeedqqqqinRbEbNkyRJFRUUpMzMz0q10qz/96U+65557lJiYqPj4eI0aNUrvvfdepNvqFi0tLXr00UeVmpqq+Ph4/fVf/7Uef/zx7/2+FJPt2rVLt99+u5KTkxUVFaU333wzbDwUCikvL08DBw5UfHy80tPTdfTo0Yj129m+a/7Nzc3KycnRqFGj1LdvXyUnJ+vee+9VbW1tRHvubN/3M3CmuXPnKioqSs89d2l8ZwUBxhBlZWXy+Xzas2ePSkpK1NzcrEmTJunkyZORbq3b7d+/X7/+9a91zTXXRLqVbvXVV19pwoQJ6tWrl7Zs2aIPPvhAzzzzjC677LJIt9YtnnrqKa1atUovvPCCPvzwQz311FNaunSpnn/++Ui31mVOnjypa6+9Vi+++OJZx5cuXaoVK1aooKBAe/fuVd++feXxeHTq1Klu77UrfNf8//znP+vAgQN69NFHdeDAAb3xxhuqqqrSP/7jP0ak167yfT8DbTZt2qQ9e/Z87+P3e5QQjFRXVxeSFCorK4t0K93q+PHjoSuvvDJUUlIS+ru/+7vQww8/HOmWuk1OTk7oxhtvjHQbEeP1ekOzZ88OWzdlypTQ9OnTI9ZTd5IU2rRpk/W6tbU15HK5Qk8//bS1rr6+PmSz2UL/8R//EaEuu8435382+/btC0kKffbZZ93WV3c613vwxz/+MfRXf/VXoUOHDoWGDBkSWrZsWUT6626cgTFUQ0ODJGnAgAGRbqVb+Xw+eb1epaenR7qVbvfWW29p7Nixuuuuu5SUlKTrr79eL730UqTb6jY//OEPVVpaqo8++kiS9P777+vdd9/V5MmTI91aRFRXV8vv94f9W3A4HBo3bpzKy8sj2lukNDQ0KCoqqsd/D96ZWltbNWPGDC1YsEAjR46MdDvd6qJ9Ei/OrbW1VZmZmZowYYKuvvrqSLfTbV577TUdOHBA+/fvj3QrEfHJJ59o1apVys7O1r/9279p//79euihhxQXF6eZM2dGur0u98gjjygYDGr48OGKiYlRS0uLfvnLX2r69OmRbi0i/H6/JH3r6eROp9Mau5ScOnVKOTk5uvvuu3vklxuey1NPPaXY2Fg99NBDkW6l2xFgDOTz+XTo0CG9++67kW6l2xw7dkwPP/ywSkpK1Lt370i3ExGtra0aO3asnnzySUnS9ddfr0OHDqmgoOCSCDAbNmzQunXrtH79eo0cOVKVlZXKzMxUcnLyJTF/nFtzc7N+8pOfKBQKadWqVZFup9tUVFRo+fLlOnDggKKioiLdTrfjT0iGmT9/voqKirRjxw4NGjQo0u10m4qKCtXV1Wn06NGKjY1VbGysysrKtGLFCsXGxqqlpSXSLXa5gQMHKi0tLWzdiBEjVFNTE7GeutOCBQv0yCOPaNq0aRo1apRmzJihrKws5efnR7q1iHC5XJKkQCAQtj4QCFhjl4K28PLZZ5+ppKTkkjr78p//+Z+qq6vT4MGDrd+Ln332mf7lX/5FQ4cOjXR7XY4zMIYIhUJ68MEHtWnTJu3cuVOpqamRbqlb3XrrrTp48GDYulmzZm
n48OHKyclRTExMxHrrLhMmTPjWrfMfffSRhgwZErGeutOf//xnRUeH/z9XTEyMWltbI9ZTJKWmpsrlcqm0tFTXXXedJCkYDGrv3r2aN29epNvrFm3h5ejRo9qxY4cSExMj3VK3mjFjxreuB/R4PJoxY4ZmzZoVsb66CwHGED6fT+vXr9fvfvc79e/f3/obt8PhUHx8fKTb63L9+/f/1vU+ffv2VWJi4iVzHVBWVpZ++MMf6sknn9RPfvIT7du3T6tXr9bq1asj3Vq3uP322/XLX/5SgwcP1siRI/Vf//VfevbZZzV79uxIt9ZlTpw4oY8//th6XV1drcrKSg0YMECDBw9WZmamnnjiCV155ZVKTU3Vo48+quTkZN15550R7buzfNf8Bw4cqKlTp+rAgQMqKipSS0uL9XtxwIABiouLi2Dnnef7fga+Gdp69eoll8ulq666KgLddrNI3waF8yPprMuaNWsi3VrEXGq3UYdCodDmzZtDV199dchms4WGDx8eWr16daRb6jbBYDD08MMPhwYPHhzq3bt36Ac/+EHo3//930ONjY2Rbq3L7Nix46z/7mfOnBkK/eVW6kcffTTkdDpDNpstdOutt4aqqqoi3Xan+a75V1dXn/P34o4dOyLdeqf5vp+Bb7qUbqOOCvXkx1gCAIAeiYt4AQCAcQgwAADAOAQYAABgHAIMAAAwDgEGAAAYhwADAACMQ4ABAADGIcAAAADjEGAAAIBxCDAAAMA4BBgAAGAcAgwAADDO/wPP0WtNPThbKQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Histogram of training labels (distance from vertex 1 to vertex 2 under mkbatch defaults).\n",
"train_labels = mkbatch(2**15)[1].cpu()\n",
"plt.hist(train_labels, bins=MAX_VTXS)\n",
"with open(\"train-dist.html\", \"w\") as f:\n",
"    mpld3.save_html(plt.gcf(), f)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([2.4747e+04, 6.5340e+03, 1.2840e+03, 1.8300e+02, 2.0000e+01]),\n",
" array([1., 2., 3., 4., 5., 6.]),\n",
" <BarContainer object of 5 artists>)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAGdCAYAAADwjmIIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkD0lEQVR4nO3dfVSUdf7/8Rc3DZgxQ94AspJabiJ5GxpOllvJcVRyY3N31VwXjerUATelvKE8aG2F2fGkHU3X2iPtObGpndUKEiVMWBPvMFaxdLP0YKuDlskIW6gwvz9+ca3zFSsUmfj4fJxznW3m+sw177n27PJsmLkI8Hq9XgEAABgm0N8DAAAAXAlEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjBft7AH9qaGjQ0aNHFRYWpoCAAH+PAwAAfgKv16vTp08rOjpagYEXf7/mqo6co0ePKiYmxt9jAACAS3DkyBF17dr1ovuv6sgJCwuTvj9Jdrvd3+MAAICfwOPxKCYmxvo5fjFXdeQ0/orKbrcTOQAAtDE/9lETPngMAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEjNipzs7GwNHjxYYWFhioiIUHJysg4cOOCz5q677lJAQIDP9uijj/qsqaysVFJSkq699lpFRERoxowZOnfunM+azZs369Zbb1VISIh69uypnJycC+ZZunSpunfvrtDQUCUkJGjHjh3Ne/UAAMBYzYqc4uJipaWladu2bSosLNTZs2c1YsQI1dbW+qx7+OGHdezYMWtbsGCBta++vl5JSUk6c+aMtm7dqjfeeEM5OTnKysqy1hw6dEhJSUm6++67VV5ermnTpumhhx7Shg0brDWrVq1SRkaG5s6dq927d6t///5yuVw6fvz45Z0RAABghACv1+u91AefOHFCERERKi4u1rBhw6Tv38kZMGCAFi1a1ORj1q9fr3vvvVdHjx5VZGSkJGn58uWaNWuWTpw4IZvNplmzZik/P18VFRXW48aPH69Tp06poKBAkpSQkKDBgwdryZIl0vd/bDMmJkZTp07V7Nmzf9L8Ho9HDodD1dXVXPEYAIA24qf+/L6sz+RUV1dLkjp06OBz/5tvvqlOnTqpT58+yszM1H//+19rX2lpqfr27WsFjiS5XC55PB7t27fPWpOYmOhzTJfLpdLSUknSmTNnVFZW5rMmMDBQiYmJ1pqm1NXVyePx+GwAAMBMl/y3qxoaGjRt2jQNHTpUffr0se5/4IEH1K1bN0VHR2vPnj2aNWuWDhw4oH/84x+SJLfb7RM4kqzbbrf7B9d4PB59++23+uabb1RfX9/kmv3791905uzsbD3zzDOX+pIBAEAbcsmRk5aWpoqKCm3ZssXn/kceecT65759+6pLly4aPny4Pv/8c910002XN+1lyszMVEZGhnW78a+YAgAA81xS5KSnpysvL08lJSXq2rXrD65NSEiQJB08eFA33XSToqKiLvgWVFVVlSQpKirK+s/G+85fY7fb1a5dOwUFBSkoKKjJNY3HaEpISIhCQkKa+WoBAEBb1KzI8Xq9mjp1qtauXavNmzerR48eP/qY8vJySVKXLl0kSU6nU88//7yOHz+uiIgISVJhYaHsdrvi4uKsNe+//77PcQoLC+V0OiVJNptN8fHxKioqUnJysvT9r8+KioqUnp7enJd0xXSfne/vEa4Kh+cn+XsEAMDPVLMiJy0tTbm5uXrnnXcUFhZmfYbG4XCoXbt2+vzzz5Wbm6vRo0erY8eO2rNnj6ZPn65hw4apX79+kqQRI0YoLi5OkyZN0oIFC+R2uzVnzhylpaVZ77I8+uijWrJkiWbOnKkHH3xQmzZt0urVq5Wf/79wyMjIUEpKigYNGqTbbrtNixYtUm1traZMmdKyZwgAALRJzYqcZcuWSd9/Tfx8K1eu1OTJk2Wz2fTBBx9YwRETE6OxY8
dqzpw51tqgoCDl5eXpsccek9PpVPv27ZWSkqJnn33WWtOjRw/l5+dr+vTpWrx4sbp27arXX39dLpfLWjNu3DidOHFCWVlZcrvdGjBggAoKCi74MDIAALg6XdZ1ctq6K3mdHH5d1Tr4dRUAXH1a5To5AAAAP1dEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACM1K3Kys7M1ePBghYWFKSIiQsnJyTpw4IDPmu+++05paWnq2LGjrrvuOo0dO1ZVVVU+ayorK5WUlKRrr71WERERmjFjhs6dO+ezZvPmzbr11lsVEhKinj17Kicn54J5li5dqu7duys0NFQJCQnasWNH8149AAAwVrMip7i4WGlpadq2bZsKCwt19uxZjRgxQrW1tdaa6dOn67333tOaNWtUXFyso0eP6v7777f219fXKykpSWfOnNHWrVv1xhtvKCcnR1lZWdaaQ4cOKSkpSXfffbfKy8s1bdo0PfTQQ9qwYYO1ZtWqVcrIyNDcuXO1e/du9e/fXy6XS8ePH7/8swIAANq8AK/X673UB584cUIREREqLi7WsGHDVF1drc6dOys3N1e//e1vJUn79+9X7969VVpaqiFDhmj9+vW69957dfToUUVGRkqSli9frlmzZunEiROy2WyaNWuW8vPzVVFRYT3X+PHjderUKRUUFEiSEhISNHjwYC1ZskSS1NDQoJiYGE2dOlWzZ8/+SfN7PB45HA5VV1fLbrdf6mloUvfZ+S16PDTt8Pwkf48AAGhlP/Xn92V9Jqe6ulqS1KFDB0lSWVmZzp49q8TERGtNbGysbrjhBpWWlkqSSktL1bdvXytwJMnlcsnj8Wjfvn3WmvOP0bim8RhnzpxRWVmZz5rAwEAlJiZaa5pSV1cnj8fjswEAADNdcuQ0NDRo2rRpGjp0qPr06SNJcrvdstlsCg8P91kbGRkpt9ttrTk/cBr3N+77oTUej0fffvutvvrqK9XX1ze5pvEYTcnOzpbD4bC2mJiYS335AADgZ+6SIyctLU0VFRV66623WnaiKygzM1PV1dXWduTIEX+PBAAArpDgS3lQenq68vLyVFJSoq5du1r3R0VF6cyZMzp16pTPuzlVVVWKioqy1vzfb0E1fvvq/DX/9xtZVVVVstvtateunYKCghQUFNTkmsZjNCUkJEQhISGX8pIBAEAb06x3crxer9LT07V27Vpt2rRJPXr08NkfHx+va665RkVFRdZ9Bw4cUGVlpZxOpyTJ6XRq7969Pt+CKiwslN1uV1xcnLXm/GM0rmk8hs1mU3x8vM+ahoYGFRUVWWsAAMDVrVnv5KSlpSk3N1fvvPOOwsLCrM+/OBwOtWvXTg6HQ6mpqcrIyFCHDh1kt9s1depUOZ1ODRkyRJI0YsQIxcXFadKkSVqwYIHcbrfmzJmjtLQ0612WRx99VEuWLNHMmTP14IMPatOmTVq9erXy8//3jaWMjAylpKRo0KBBuu2227Ro0SLV1tZqypQpLXuGAABAm9SsyFm2bJkk6a677vK5f+XKlZo8ebIk6eWXX1ZgYKDGjh2ruro6uVwuvfrqq9baoKAg5eXl6bHHHp
PT6VT79u2VkpKiZ5991lrTo0cP5efna/r06Vq8eLG6du2q119/XS6Xy1ozbtw4nThxQllZWXK73RowYIAKCgou+DAyAAC4Ol3WdXLaOq6T0/ZxnRwAuPq0ynVyAAAAfq6IHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEZqduSUlJRozJgxio6OVkBAgNatW+ezf/LkyQoICPDZRo4c6bPm5MmTmjhxoux2u8LDw5WamqqamhqfNXv27NGdd96p0NBQxcTEaMGCBRfMsmbNGsXGxio0NFR9+/bV+++/39yXAwAADNXsyKmtrVX//v21dOnSi64ZOXKkjh07Zm1///vfffZPnDhR+/btU2FhofLy8lRSUqJHHnnE2u/xeDRixAh169ZNZWVleumllzRv3jytWLHCWrN161ZNmDBBqamp+vjjj5WcnKzk5GRVVFQ09yUBAAADBXi9Xu8lPzggQGvXrlVycrJ13+TJk3Xq1KkL3uFp9OmnnyouLk47d+7UoEGDJEkFBQUaPXq0vvzyS0VHR2vZsmV6+umn5Xa7ZbPZJEmzZ8/WunXrtH//fknSuHHjVFtbq7y8POvYQ4YM0YABA7R8+fKfNL/H45HD4VB1dbXsdvulnoYmdZ+d36LHQ9MOz0/y9wgAgFb2U39+X5HP5GzevFkRERHq1auXHnvsMX399dfWvtLSUoWHh1uBI0mJiYkKDAzU9u3brTXDhg2zAkeSXC6XDhw4oG+++cZak5iY6PO8LpdLpaWlF52rrq5OHo/HZwMAAGZq8cgZOXKk/va3v6moqEgvvviiiouLNWrUKNXX10uS3G63IiIifB4THBysDh06yO12W2siIyN91jTe/rE1jfubkp2dLYfDYW0xMTEt9KoBAMDPTXBLH3D8+PHWP/ft21f9+vXTTTfdpM2bN2v48OEt/XTNkpmZqYyMDOu2x+MhdAAAMNQV/wr5jTfeqE6dOungwYOSpKioKB0/ftxnzblz53Ty5ElFRUVZa6qqqnzWNN7+sTWN+5sSEhIiu93uswEAADNd8cj58ssv9fXXX6tLly6SJKfTqVOnTqmsrMxas2nTJjU0NCghIcFaU1JSorNnz1prCgsL1atXL11//fXWmqKiIp/nKiwslNPpvNIvCQAAtAHNjpyamhqVl5ervLxcknTo0CGVl5ersrJSNTU1mjFjhrZt26bDhw+rqKhI9913n3r27CmXyyVJ6t27t0aOHKmHH35YO3bs0EcffaT09HSNHz9e0dHRkqQHHnhANptNqamp2rdvn1atWqXFixf7/Krp8ccfV0FBgRYuXKj9+/dr3rx52rVrl9LT01vu7AAAgDar2ZGza9cuDRw4UAMHDpQkZWRkaODAgcrKylJQUJD27NmjX//617r55puVmpqq+Ph4/fOf/1RISIh1jDfffFOxsbEaPny4Ro8erTvuuMPnGjgOh0MbN27UoUOHFB8fryeeeEJZWVk+19K5/fbblZubqxUrVqh///56++23tW7dOvXp0+fyzwoAAGjzLus6OW0d18lp+7hODgBcffx6nRwAAAB/I3IAAICRiBwAAGAkIgcAABiJyA
EAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYqdmRU1JSojFjxig6OloBAQFat26dz36v16usrCx16dJF7dq1U2Jioj777DOfNSdPntTEiRNlt9sVHh6u1NRU1dTU+KzZs2eP7rzzToWGhiomJkYLFiy4YJY1a9YoNjZWoaGh6tu3r95///3mvhwAAGCoZkdObW2t+vfvr6VLlza5f8GCBXrllVe0fPlybd++Xe3bt5fL5dJ3331nrZk4caL27dunwsJC5eXlqaSkRI888oi13+PxaMSIEerWrZvKysr00ksvad68eVqxYoW1ZuvWrZowYYJSU1P18ccfKzk5WcnJyaqoqGj+WQAAAMYJ8Hq93kt+cECA1q5dq+TkZOn7d3Gio6P1xBNP6Mknn5QkVVdXKzIyUjk5ORo/frw+/fRTxcXFaefOnRo0aJAkqaCgQKNHj9aXX36p6OhoLVu2TE8//bTcbrdsNpskafbs2Vq3bp32798vSRo3bpxqa2uVl5dnzTNkyBANGDBAy5cv/0nzezweORwOVVdXy263X+ppaFL32fktejw07fD8JH+PAABoZT/153eLfibn0KFDcrvdSkxMtO5zOBxKSEhQaWmpJKm0tFTh4eFW4EhSYmKiAgMDtX37dmvNsGHDrMCRJJfLpQMHDuibb76x1pz/PI1rGp+nKXV1dfJ4PD4bAAAwU4tGjtvtliRFRkb63B8ZGWntc7vdioiI8NkfHBysDh06+Kxp6hjnP8fF1jTub0p2drYcDoe1xcTEXMarBQAAP2dX1berMjMzVV1dbW1Hjhzx90gAAOAKadHIiYqKkiRVVVX53F9VVWXti4qK0vHjx332nzt3TidPnvRZ09Qxzn+Oi61p3N+UkJAQ2e12nw0AAJipRSOnR48eioqKUlFRkXWfx+PR9u3b5XQ6JUlOp1OnTp1SWVmZtWbTpk1qaGhQQkKCtaakpERnz5611hQWFqpXr166/vrrrTXnP0/jmsbnAQAAV7dmR05NTY3Ky8tVXl4uff9h4/LyclVWViogIEDTpk3Tc889p3fffVd79+7VH//4R0VHR1vfwOrdu7dGjhyphx9+WDt27NBHH32k9PR0jR8/XtHR0ZKkBx54QDabTampqdq3b59WrVqlxYsXKyMjw5rj8ccfV0FBgRYuXKj9+/dr3rx52rVrl9LT01vu7AAAgDYruLkP2LVrl+6++27rdmN4pKSkKCcnRzNnzlRtba0eeeQRnTp1SnfccYcKCgoUGhpqPebNN99Uenq6hg8frsDAQI0dO1avvPKKtd/hcGjjxo1KS0tTfHy8OnXqpKysLJ9r6dx+++3Kzc3VnDlz9NRTT+mXv/yl1q1bpz59+lzO+QAAAIa4rOvktHVcJ6ft4zo5AHD18ct1cgAAAH4uiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyA
EAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYKRgfw8AXI7us/P9PcJV4/D8JH+PAADNwjs5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBILR458+bNU0BAgM8WGxtr7f/uu++Ulpamjh076rrrrtPYsWNVVVXlc4zKykolJSXp2muvVUREhGbMmKFz5875rNm8ebNuvfVWhYSEqGfPnsrJyWnplwIAANqwK/JOzi233KJjx45Z25YtW6x906dP13vvvac1a9aouLhYR48e1f3332/tr6+vV1JSks6cOaOtW7fqjTfeUE5OjrKysqw1hw4dUlJSku6++26Vl5dr2rRpeuihh7Rhw4Yr8XIAAEAbFHxFDhocrKioqAvur66u1l//+lfl5ubqnnvukSStXLlSvXv31rZt2zRkyBBt3LhRn3zyiT744ANFRkZqwIAB+vOf/6xZs2Zp3rx5stlsWr58uXr06KGFCxdKknr37q0tW7bo5ZdflsvluhIvCQAAtDFX5J2czz77TNHR0brxxhs1ceJEVVZWSpLKysp09uxZJSYmWmtjY2N1ww03qLS0VJJUWlqqvn37KjIy0lrjcrnk8Xi0b98+a835x2hc03iMi6mrq5PH4/HZAACAmVo8chISEpSTk6OCggItW7ZMhw4d0p133qnTp0/L7XbLZrMpPDzc5zGRkZFyu92SJLfb7RM4jfsb9/3QGo/Ho2+//fais2VnZ8vhcFhbTExMi71uAADw89Liv64aNWqU9c/9+vVTQkKCunXrptWrV6tdu3Yt/XTNkpmZqYyMDOu2x+MhdAAAMNQV/wp5eHi4br75Zh08eFBRUVE6c+aMTp065bOmqqrK+gxPVFTUBd+2arz9Y2vsdvsPhlRISIjsdrvPBgAAzHTFI6empkaff/65unTpovj4eF1zzTUqKiqy9h84cECVlZVyOp2SJKfTqb179+r48ePWmsLCQtntdsXFxVlrzj9G45rGYwAAALR45Dz55JMqLi7W4cOHtXXrVv3mN79RUFCQJkyYIIfDodTUVGVkZOjDDz9UWVmZpkyZIqfTqSFDhkiSRowYobi4OE2aNEn/+te/tGHDBs2ZM0dpaWkKCQmRJD366KP64osvNHPmTO3fv1+vvvqqVq9erenTp7f0ywEAAG1Ui38m58svv9SECRP09ddfq3Pnzrrjjju0bds2de7cWZL08ssvKzAwUGPHjlVdXZ1cLpdeffVV6/FBQUHKy8vTY489JqfTqfbt2yslJUXPPvustaZHjx7Kz8/X9OnTtXjxYnXt2lWvv/46Xx8HAACWAK/X6/X3EP7i8XjkcDhUXV3d4p/P6T47v0WPB/jb4flJ/h4BAKRm/Pzmb1cBAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASMH+HgBA29B9dr6/R7gqHJ6f5O8RAGPwTg4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAA
AjETkAAMBIRA4AADBSm4+cpUuXqnv37goNDVVCQoJ27Njh75EAAMDPQJuOnFWrVikjI0Nz587V7t271b9/f7lcLh0/ftzfowEAAD8L8Hq9Xn8PcakSEhI0ePBgLVmyRJLU0NCgmJgYTZ06VbNnz/7Rx3s8HjkcDlVXV8tut7fobN1n57fo8QAALefw/CR/j4DL8FN/fge36lQt6MyZMyorK1NmZqZ1X2BgoBITE1VaWtrkY+rq6lRXV2fdrq6ulr4/WS2toe6/LX5MAEDLuBL/v4/W0/jf34+9T9NmI+err75SfX29IiMjfe6PjIzU/v37m3xMdna2nnnmmQvuj4mJuWJzAgB+fhyL/D0BWsLp06flcDguur/NRs6lyMzMVEZGhnW7oaFBJ0+eVMeOHRUQENBiz+PxeBQTE6MjR460+K/B8D+c59bDuW4dnOfWwXluHVfyPHu9Xp0+fVrR0dE/uK7NRk6nTp0UFBSkqqoqn/urqqoUFRXV5GNCQkIUEhLic194ePgVm9Fut/M/oFbAeW49nOvWwXluHZzn1nGlzvMPvYPTqM1+u8pmsyk+Pl5FRUXWfQ0NDSoqKpLT6fTrbAAAwP/a7Ds5kpSRkaGUlBQNGjRIt912mxYtWqTa2lpNmTLF36MBAAA/a9ORM27cOJ04cUJZWVlyu90aMGCACgoKLvgwcmsLCQnR3LlzL/jVGFoW57n1cK5bB+e5dXCeW8fP4Ty36evkAAAAXEyb/UwOAADADyFyAACAkYgcAABgJCIHAAAYichpQSUlJRozZoyio6MVEBCgdevW+XskI2VnZ2vw4MEKCwtTRESEkpOTdeDAAX+PZZxly5apX79+1oW8nE6n1q9f7++xjDd//nwFBARo2rRp/h7FOPPmzVNAQIDPFhsb6++xjPSf//xHf/jDH9SxY0e1a9dOffv21a5du1p9DiKnBdXW1qp///5aunSpv0cxWnFxsdLS0rRt2zYVFhbq7NmzGjFihGpra/09mlG6du2q+fPnq6ysTLt27dI999yj++67T/v27fP3aMbauXOn/vKXv6hfv37+HsVYt9xyi44dO2ZtW7Zs8fdIxvnmm280dOhQXXPNNVq/fr0++eQTLVy4UNdff32rz9Kmr5PzczNq1CiNGjXK32MYr6CgwOd2Tk6OIiIiVFZWpmHDhvltLtOMGTPG5/bzzz+vZcuWadu2bbrlllv8NpepampqNHHiRL322mt67rnn/D2OsYKDgy/6p3/QMl588UXFxMRo5cqV1n09evTwyyy8k4M2r7q6WpLUoUMHf49irPr6er311luqra3lz6ZcIWlpaUpKSlJiYqK/RzHaZ599pujoaN14442aOHGiKisr/T2Scd59910NGjRIv/vd7xQREaGBAwfqtdde88ssvJODNq2hoUHTpk3T0KFD1adPH3+PY5y9e/fK6XTqu+++03XXXae1a9cqLi7O32MZ56233tLu3bu1c+dOf49itISEBOXk5KhXr146duyYnnnmGd15552qqKhQWFiYv8czxhdffKFly5YpIyNDTz31lHbu3Kk//elPstlsSklJadVZiBy0aWlpaaqoqOD36ldIr169VF5erurqar399ttKSUlRcXExodOCjhw5oscff1yFhYUKDQ319zhGO//jBP369VNCQoK6deum1atXKzU11a+zmaShoUGDBg3SCy+8IEkaOHCgKioqtHz58laPHH5dhTYrPT1deXl5+vDDD9W1a1d/j2Mkm82mnj17Kj4+XtnZ2erfv78WL17s77GMUlZWpuPHj+vWW29VcHCwgoODVVxcrFdeeUXBwcGqr6/394jGCg8P180336yDBw/6exSjdOnS5YJ/Eerdu7dffjXIOzloc7xer6ZOnaq1a9dq8+bNfvtA29WooaFBdXV1/h7DKMOHD9fevXt97psyZYpiY2M1a9YsBQUF+W0209XU1Ojzzz/XpEmT/D2KUYYOHXrBZT3+/e9/q1u3bq0+C5
HTgmpqanz+jeDQoUMqLy9Xhw4ddMMNN/h1NpOkpaUpNzdX77zzjsLCwuR2uyVJDodD7dq18/d4xsjMzNSoUaN0ww036PTp08rNzdXmzZu1YcMGf49mlLCwsAs+T9a+fXt17NiRz5m1sCeffFJjxoxRt27ddPToUc2dO1dBQUGaMGGCv0czyvTp03X77bfrhRde0O9//3vt2LFDK1as0IoVK1p/GC9azIcffuiVdMGWkpLi79GM0tQ5luRduXKlv0czyoMPPujt1q2b12azeTt37uwdPny4d+PGjf4e66rwq1/9yvv444/7ewzjjBs3ztulSxevzWbz/uIXv/COGzfOe/DgQX+PZaT33nvP26dPH29ISIg3NjbWu2LFCr/MEeD9/z80AAAAjMIHjwEAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEb6f8M11LhmDDKeAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.hist(mkbatch(2**15, large=False, target=\"onpath\")[1].cpu(), bins=5)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([2.0379e+04, 7.8380e+03, 3.0080e+03, 1.0690e+03, 3.4600e+02,\n",
" 9.8000e+01, 2.5000e+01, 5.0000e+00]),\n",
" array([1., 2., 3., 4., 5., 6., 7., 8., 9.]),\n",
" <BarContainer object of 8 artists>)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAGdCAYAAADwjmIIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAwsElEQVR4nO3de3BUZZ7/8U8T7AQwHa65DSEEcAiXAHIxtBeUJZOAGcasjCKioEQZ3KBAFCHKYICdCQMFyBS3ZVXCrCDIlkQFBUIQGIaoEIwQHDKCMNGFDs4IaYgaSHJ+/2zOz15ADXZs8vh+VZ2qnOf59unv01r0p06fPu2wLMsSAACAYZoEugEAAICGQMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABipaaAbCKTa2lqdPHlSoaGhcjgcgW4HAAB8D5Zl6dy5c4qOjlaTJlc+X/OTDjknT55UTExMoNsAAABX4dNPP1X79u2vOP+TDjmhoaHS/75ILpcr0O0AAIDvwev1KiYmxn4fv5KfdMip+4jK5XIRcgAAaGS+61ITLjwGAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMFLTQDdgqo7TNwe6hQZ3Ym5qoFsAAOCKOJMDAACMRMgBAABGqlfIycnJ0YABAxQaGqrw8HClpaWptLTUp+brr79WRkaG2rRpo+uvv14jRoxQeXm5T01ZWZlSU1PVvHlzhYeHa+rUqaqurvap2blzp/r27avg4GB16dJFubm5l/SzdOlSdezYUSEhIUpMTNT7779fv9UDAABj1Svk7Nq1SxkZGXr33XeVn5+vixcvKjk5WZWVlXbNlClT9Oabb2rDhg3atWuXTp48qbvvvtuer6mpUWpqqi5cuKC9e/dq9erVys3N1cyZM+2a48ePKzU1VYMHD1ZxcbEmT56sRx55RFu3brVr1q9fr8zMTD333HM6cOCAevfurZSUFJ0+ffqHvyoAAKDRc1iWZV3tgz///HOFh4dr165dGjRokCoqKtSuXTutXbtWv/71ryVJR44cUbdu3VRYWKiBAwfq7bff1i9/+UudPHlSERERkqQVK1Zo2rRp+vzzz+V0OjVt2jRt3rxZJSUl9nPdd999Onv2rLZs2SJJSkxM1IABA7RkyRJJUm1trWJiYvT4449r+vTp36t/r9ersLAwVVRUyOVyXe3LcFlceAwAQMP4vu/fP+ianIqKCklS69atJUlFRUW6ePGikpKS7Jr4+Hh16NBBhYWFkqTCwkIlJCTYAUeSUlJS5PV6dfjwYbvmm8eoq6k7xoULF1RUVORT06RJEyUlJdk1l1NVVSWv1+uzAQAAM111yKmtrdXkyZN1yy23qGfPnpIkj8cjp9Opli1b+tRGRETI4/HYNd8MOHXzdXPfVuP1evXVV1/pH//4h2pqai5bU3eMy8nJyVFYWJi9xcTEXO3yAQDANe6qQ05GRoZKSkq0bt06/3bUgLKyslRRUWFvn376aaBbAgAADeSqbgY4ceJEbdq0Sbt371b79u3t8cjISF24cEFnz571OZtTXl6uyMhIu+b/fguq7ttX36z5v9/IKi8vl8vlUrNmzRQUFKSgoKDL1tQd43KCg4MVHBx8NUsGAACNTL3O5FiWpYkTJ2rjxo3asWOH4uLifOb79eun6667TgUFBfZYaWmpysrK5Ha7JUlut1uHDh3y+RZUfn6+XC6Xunfvbtd88xh1NXXHcDqd6tevn09NbW2tCgoK7BoAAPDTVq8zORkZGVq7dq1ef/11hYaG2te/hIWFqVmzZgoLC1N6eroyMzPVunVruVwuPf7443K73Ro4cKAkKTk5Wd27d9eDDz6oefPmyePxaMaMGcrIyLDPskyYMEFLlizR008/rXHjxmnHjh169dVXtXnz//
/GUmZmpsaOHav+/fvrpptu0vPPP6/Kyko9/PDD/n2FAABAo1SvkLN8+XJJ0h133OEzvmrVKj300EOSpEWLFqlJkyYaMWKEqqqqlJKSomXLltm1QUFB2rRpkx577DG53W61aNFCY8eO1ezZs+2auLg4bd68WVOmTNHixYvVvn17vfDCC0pJSbFrRo4cqc8//1wzZ86Ux+NRnz59tGXLlksuRgYAAD9NP+g+OY0d98n5YbhPDgAgEH6U++QAAABcqwg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAj1Tvk7N69W8OHD1d0dLQcDofy8vJ85h0Ox2W3+fPn2zUdO3a8ZH7u3Lk+xzl48KBuu+02hYSEKCYmRvPmzbuklw0bNig+Pl4hISFKSEjQW2+9Vd/lAAAAQ9U75FRWVqp3795aunTpZedPnTrls7300ktyOBwaMWKET93s2bN96h5//HF7zuv1Kjk5WbGxsSoqKtL8+fOVnZ2tlStX2jV79+7VqFGjlJ6erg8++EBpaWlKS0tTSUlJfZcEAAAM1LS+Dxg2bJiGDRt2xfnIyEif/ddff12DBw9Wp06dfMZDQ0Mvqa2zZs0aXbhwQS+99JKcTqd69Oih4uJiLVy4UOPHj5ckLV68WEOHDtXUqVMlSXPmzFF+fr6WLFmiFStW1HdZAADAMA16TU55ebk2b96s9PT0S+bmzp2rNm3a6MYbb9T8+fNVXV1tzxUWFmrQoEFyOp32WEpKikpLS3XmzBm7JikpyeeYKSkpKiwsvGI/VVVV8nq9PhsAADBTvc/k1Mfq1asVGhqqu+++22f8iSeeUN++fdW6dWvt3btXWVlZOnXqlBYuXChJ8ng8iouL83lMRESEPdeqVSt5PB577Js1Ho/niv3k5ORo1qxZflwhAAC4VjVoyHnppZc0evRohYSE+IxnZmbaf/fq1UtOp1O/+c1vlJOTo+Dg4AbrJysry+e5vV6vYmJiGuz5AABA4DRYyPnzn/+s0tJSrV+//jtrExMTVV1drRMnTqhr166KjIxUeXm5T03dft11PFequdJ1PpIUHBzcoCEKAABcOxrsmpwXX3xR/fr1U+/evb+ztri4WE2aNFF4eLgkye12a/fu3bp48aJdk5+fr65du6pVq1Z2TUFBgc9x8vPz5Xa7/b4WAADQ+NQ75Jw/f17FxcUqLi6WJB0/flzFxcUqKyuza7xerzZs2KBHHnnkkscXFhbq+eef14cffqhPPvlEa9as0ZQpU/TAAw/YAeb++++X0+lUenq6Dh8+rPXr12vx4sU+HzVNmjRJW7Zs0YIFC3TkyBFlZ2dr//79mjhx4tW+FgAAwCD1/rhq//79Gjx4sL1fFzzGjh2r3NxcSdK6detkWZZGjRp1yeODg4O1bt06ZWdnq6qqSnFxcZoyZYpPgAkLC9O2bduUkZGhfv36qW3btpo5c6b99XFJuvnmm7V27VrNmDFDzzzzjG644Qbl5eWpZ8+e9X8VAACAcRyWZVmBbiJQvF6vwsLCVFFRIZfL5ddjd5y+2a/HuxadmJsa6BYAAD9B3/f9m9+uAgAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAA
AwEiEHAAAYiZADAACMVO+Qs3v3bg0fPlzR0dFyOBzKy8vzmX/ooYfkcDh8tqFDh/rUfPHFFxo9erRcLpdatmyp9PR0nT9/3qfm4MGDuu222xQSEqKYmBjNmzfvkl42bNig+Ph4hYSEKCEhQW+99VZ9lwMAAAxV75BTWVmp3r17a+nSpVesGTp0qE6dOmVvr7zyis/86NGjdfjwYeXn52vTpk3avXu3xo8fb897vV4lJycrNjZWRUVFmj9/vrKzs7Vy5Uq7Zu/evRo1apTS09P1wQcfKC0tTWlpaSopKanvkgAAgIEclmVZV/1gh0MbN25UWlqaPfbQQw/p7Nmzl5zhqfPXv/5V3bt31759+9S/f39J0pYtW3TnnXfqs88+U3R0tJYvX65nn31WHo9HTqdTkjR9+nTl5eXpyJEjkqSRI0eqsrJSmzZtso89cOBA9enTRytWrPhe/Xu9XoWFhamiokIul+tqX4bL6jh9s1+Pdy06MTc10C0AAH6Cvu/7d4Nck7Nz506Fh4era9eueuyxx/TPf/7TnissLFTLli3tgCNJSUlJatKkid577z27ZtCgQXbAkaSUlBSVlpbqzJkzdk1SUpLP86akpKiwsPCKfVVVVcnr9fpsAADATH4POUOHDtWf/vQnFRQU6A9/+IN27dqlYcOGqaamRpLk8XgUHh7u85imTZuqdevW8ng8dk1ERIRPTd3+d9XUzV9OTk6OwsLC7C0mJsZPqwYAANeapv4+4H333Wf/nZCQoF69eqlz587auXOnhgwZ4u+nq5esrCxlZmba+16vl6ADAIChGvwr5J06dVLbtm119OhRSVJkZKROnz7tU1NdXa0vvvhCkZGRdk15eblPTd3+d9XUzV9OcHCwXC6XzwYAAMzU4CHns88+0z//+U9FRUVJktxut86ePauioiK7ZseOHaqtrVViYqJds3v3bl28eNGuyc/PV9euXdWqVSu7pqCgwOe58vPz5Xa7G3pJAACgEah3yDl//ryKi4tVXFwsSTp+/LiKi4tVVlam8+fPa+rUqXr33Xd14sQJFRQU6K677lKXLl2UkpIiSerWrZuGDh2qRx99VO+//77+8pe/aOLEibrvvvsUHR0tSbr//vvldDqVnp6uw4cPa/369Vq8eLHPR02TJk3Sli1btGDBAh05ckTZ2dnav3+/Jk6c6L9XBwAANFr1Djn79+/XjTfeqBtvvFGSlJmZqRtvvFEzZ85UUFCQDh48qF/96lf6+c9/rvT0dPXr109//vOfFRwcbB9jzZo1io+P15AhQ3TnnXfq1ltv9bkHTlhYmLZt26bjx4+rX79+evLJJzVz5kyfe+ncfPPNWrt2rVauXKnevXvrv//7v5WXl6eePXv+8FcFAAA0ej/oPjmNHffJ+WG4Tw4AIBACep8cAACAQCPkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMVO+Qs3v3bg0fPlzR0dFyOBzKy8uz5y5evKhp06YpISFBLVq0UHR0tMaMGaOTJ0/6HKNjx45yOBw+29y5c31qDh48qNtuu00hISGKiYnRvHnzLullw4YNio+PV0hIiBISEvTWW2/VdzkAAMBQ9Q45lZWV6t27t5YuXXrJ3JdffqkDBw7ot7/9rQ4cOKDXXntNpaWl+tWvfnVJ7ezZs3Xq1Cl7e/zxx+05r9er5ORkxcbGqqioSPPnz1d2drZWrlxp1+zdu1ejRo1Senq6PvjgA6WlpSktLU0lJSX1XRIAADBQ0/o+YNiwYRo2bNhl58LCwpSfn+8ztmTJEt1000
0qKytThw4d7PHQ0FBFRkZe9jhr1qzRhQsX9NJLL8npdKpHjx4qLi7WwoULNX78eEnS4sWLNXToUE2dOlWSNGfOHOXn52vJkiVasWJFfZcFAAAM0+DX5FRUVMjhcKhly5Y+43PnzlWbNm104403av78+aqurrbnCgsLNWjQIDmdTnssJSVFpaWlOnPmjF2TlJTkc8yUlBQVFhZesZeqqip5vV6fDQAAmKneZ3Lq4+uvv9a0adM0atQouVwue/yJJ55Q37591bp1a+3du1dZWVk6deqUFi5cKEnyeDyKi4vzOVZERIQ916pVK3k8HnvsmzUej+eK/eTk5GjWrFl+XiUAALgWNVjIuXjxou69915ZlqXly5f7zGVmZtp/9+rVS06nU7/5zW+Uk5Oj4ODghmpJWVlZPs/t9XoVExPTYM8HAAACp0FCTl3A+fvf/64dO3b4nMW5nMTERFVXV+vEiRPq2rWrIiMjVV5e7lNTt193Hc+Vaq50nY8kBQcHN2iIAgAA1w6/X5NTF3A+/vhjbd++XW3atPnOxxQXF6tJkyYKDw+XJLndbu3evVsXL160a/Lz89W1a1e1atXKrikoKPA5Tn5+vtxut7+XBAAAGqF6n8k5f/68jh49au8fP35cxcXFat26taKiovTrX/9aBw4c0KZNm1RTU2NfI9O6dWs5nU4VFhbqvffe0+DBgxUaGqrCwkJNmTJFDzzwgB1g7r//fs2aNUvp6emaNm2aSkpKtHjxYi1atMh+3kmTJun222/XggULlJqaqnXr1mn//v0+XzMHAAA/XQ7Lsqz6PGDnzp0aPHjwJeNjx45Vdnb2JRcM13nnnXd0xx136MCBA/q3f/s3HTlyRFVVVYqLi9ODDz6ozMxMn4+SDh48qIyMDO3bt09t27bV448/rmnTpvkcc8OGDZoxY4ZOnDihG264QfPmzdOdd975vdfi9XoVFhamioqK7/xIrb46Tt/s1+Ndi07MTQ10CwCAn6Dv+/5d75BjEkLOD0PIAQAEwvd9/+a3qwAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAI9U75OzevVvDhw9XdHS0HA6H8vLyfOYty9LMmTMVFRWlZs2aKSkpSR9//LFPzRdffKHRo0fL5XKpZcuWSk9P1/nz531qDh48qNtuu00hISGKiYnRvHnzLullw4YNio+PV0hIiBISEvTWW2/VdzkAAMBQ9Q45lZWV6t27t5YuXXrZ+Xnz5umPf/yjVqxYoffee08tWrRQSkqKvv76a7tm9OjROnz4sPLz87Vp0ybt3r1b48ePt+e9Xq+Sk5MVGxuroqIizZ8/X9nZ2Vq5cqVds3fvXo0aNUrp6en64IMPlJaWprS0NJWUlNT/VQAAAMZxWJZlXfWDHQ5t3LhRaWlp0v+exYmOjtaTTz6pp556SpJUUVGhiIgI5ebm6r777tNf//pXde/eXfv27VP//v0lSVu2bNGdd96pzz77TNHR0Vq+fLmeffZZeTweOZ1OSdL06dOVl5enI0eOSJJGjhypyspKbdq0ye5n4MCB6tOnj1asWPG9+vd6vQoLC1NFRYVcLtfVvgyX1XH6Zr8e71p0Ym5qoFsAAPwEfd/3b79ek3P8+HF5PB4lJSXZY2FhYUpMTFRhYaEkqbCwUC1btrQDjiQlJSWpSZMmeu+99+yaQYMG2QFHklJSUlRaWqozZ87YNd98nrqauucBAAA/bU39eTCPxyNJioiI8BmPiIiw5zwej8LDw32baNpUrVu39qmJi4u75Bh1c61atZLH4/nW57
mcqqoqVVVV2fter/cqVwoAAK51P6lvV+Xk5CgsLMzeYmJiAt0SAABoIH4NOZGRkZKk8vJyn/Hy8nJ7LjIyUqdPn/aZr66u1hdffOFTc7ljfPM5rlRTN385WVlZqqiosLdPP/30B6wWAABcy/wacuLi4hQZGamCggJ7zOv16r333pPb7ZYkud1unT17VkVFRXbNjh07VFtbq8TERLtm9+7dunjxol2Tn5+vrl27qlWrVnbNN5+nrqbueS4nODhYLpfLZwMAAGaqd8g5f/68iouLVVxcLP3vxcbFxcUqKyuTw+HQ5MmT9e///u964403dOjQIY0ZM0bR0dH2N7C6deumoUOH6tFHH9X777+vv/zlL5o4caLuu+8+RUdHS5Luv/9+OZ1Opaen6/Dhw1q/fr0WL16szMxMu49JkyZpy5YtWrBggY4cOaLs7Gzt379fEydO9N+rAwAAGq16X3i8f/9+DR482N6vCx5jx45Vbm6unn76aVVWVmr8+PE6e/asbr31Vm3ZskUhISH2Y9asWaOJEydqyJAhatKkiUaMGKE//vGP9nxYWJi2bdumjIwM9evXT23bttXMmTN97qVz8803a+3atZoxY4aeeeYZ3XDDDcrLy1PPnj1/yOsBAAAM8YPuk9PYcZ+cH4b75AAAAiEg98kBAAC4VhByAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACM1DTQDaDx6jh9c6BbaFAn5qYGugUAwA/AmRwAAGAkv4ecjh07yuFwXLJlZGRIku64445L5iZMmOBzjLKyMqWmpqp58+YKDw/X1KlTVV1d7VOzc+dO9e3bV8HBwerSpYtyc3P9vRQAANCI+f3jqn379qmmpsbeLykp0S9+8Qvdc8899tijjz6q2bNn2/vNmze3/66pqVFqaqoiIyO1d+9enTp1SmPGjNF1112n3//+95Kk48ePKzU1VRMmTNCaNWtUUFCgRx55RFFRUUpJSfH3kgAAQCPk95DTrl07n/25c+eqc+fOuv322+2x5s2bKzIy8rKP37Ztmz766CNt375dERER6tOnj+bMmaNp06YpOztbTqdTK1asUFxcnBYsWCBJ6tatm/bs2aNFixYRcgAAgNTQ1+RcuHBBL7/8ssaNGyeHw2GPr1mzRm3btlXPnj2VlZWlL7/80p4rLCxUQkKCIiIi7LGUlBR5vV4dPnzYrklKSvJ5rpSUFBUWFn5rP1VVVfJ6vT4bAAAwU4N+uyovL09nz57VQw89ZI/df//9io2NVXR0tA4ePKhp06aptLRUr732miTJ4/H4BBxJ9r7H4/nWGq/Xq6+++krNmjW7bD85OTmaNWuW39cJAACuPQ0acl588UUNGzZM0dHR9tj48ePtvxMSEhQVFaUhQ4bo2LFj6ty5c0O2o6ysLGVmZtr7Xq9XMTExDfqcAAAgMBos5Pz973/X9u3b7TM0V5KYmChJOnr0qDp37qzIyEi9//77PjXl5eWSZF/HExkZaY99s8blcl3xLI4kBQcHKzg4+KrXBAAAGo8GuyZn1apVCg8PV2rqt99Qrbi4WJIUFRUlSXK73Tp06JBOnz5t1+Tn58vlcql79+52TUFBgc9x8vPz5Xa7G2AlAACgMWqQkFNbW6tVq1Zp7Nixatr0/58sOnbsmObMmaOioiKdOHFCb7zxhsaMGaNBgwapV69ekqTk5GR1795dDz74oD788ENt3bpVM2bMUEZGhn0WZsKECfrkk0/09NNP68iRI1q2bJleffVVTZkypSGWAwAAGqEGCTnbt29XWVmZxo0b5zPudDq1fft2JS
cnKz4+Xk8++aRGjBihN998064JCgrSpk2bFBQUJLfbrQceeEBjxozxua9OXFycNm/erPz8fPXu3VsLFizQCy+8wNfHAQCAzWFZlhXoJgLF6/UqLCxMFRUVcrlcfj226b/r9FPAb1cBwLXp+75/89tVAADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICR/B5ysrOz5XA4fLb4+Hh7/uuvv1ZGRobatGmj66+/XiNGjFB5ebnPMcrKypSamqrmzZsrPDxcU6dOVXV1tU/Nzp071bdvXwUHB6tLly7Kzc3191IAAEAj1iBncnr06KFTp07Z2549e+y5KVOm6M0339SGDRu0a9cunTx5Unfffbc9X1NTo9TUVF24cEF79+7V6tWrlZubq5kzZ9o1x48fV2pqqgYPHqzi4mJNnjxZjzzyiLZu3doQywEAAI1Q0wY5aNOmioyMvGS8oqJCL774otauXat/+Zd/kSStWrVK3bp107vvvquBAwdq27Zt+uijj7R9+3ZFRESoT58+mjNnjqZNm6bs7Gw5nU6tWLFCcXFxWrBggSSpW7du2rNnjxYtWqSUlJSGWBIAAGhkGuRMzscff6zo6Gh16tRJo0ePVllZmSSpqKhIFy9eVFJSkl0bHx+vDh06qLCwUJJUWFiohIQERURE2DUpKSnyer06fPiwXfPNY9TV1B3jSqqqquT1en02AABgJr+HnMTEROXm5mrLli1avny5jh8/rttuu03nzp2Tx+OR0+lUy5YtfR4TEREhj8cjSfJ4PD4Bp26+bu7barxer7766qsr9paTk6OwsDB7i4mJ8du6AQDAtcXvH1cNGzbM/rtXr15KTExUbGysXn31VTVr1szfT1cvWVlZyszMtPe9Xi9BBwAAQzX4V8hbtmypn//85zp69KgiIyN14cIFnT171qemvLzcvoYnMjLykm9b1e1/V43L5frWIBUcHCyXy+WzAQAAMzV4yDl//ryOHTumqKgo9evXT9ddd50KCgrs+dLSUpWVlcntdkuS3G63Dh06pNOnT9s1+fn5crlc6t69u13zzWPU1dQdAwAAwO8h56mnntKuXbt04sQJ7d27V//6r/+qoKAgjRo1SmFhYUpPT1dmZqbeeecdFRUV6eGHH5bb7dbAgQMlScnJyerevbsefPBBffjhh9q6datmzJihjIwMBQcHS5ImTJigTz75RE8//bSOHDmiZcuW6dVXX9WUKVP8vRwAANBI+f2anM8++0yjRo3SP//5T7Vr10633nqr3n33XbVr106StGjRIjVp0kQjRoxQVVWVUlJStGzZMvvxQUFB2rRpkx577DG53W61aNFCY8eO1ezZs+2auLg4bd68WVOmTNHixYvVvn17vfDCC3x9HAAA2ByWZVmBbiJQvF6vwsLCVFFR4ffrczpO3+zX4+HHd2JuaqBbAABcxvd9/+a3qwAAgJEIOQAAwEiEHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIxEyAEAAEYi5AAAACMRcgAAgJEIOQAAwEiEHAAAYKSmgW4AuFZ1nL450C00qBNzUwPdAgA0KM7kAAAAIxFyAACAkQg5AADASIQcAABgJL+HnJycHA0YMEChoaEKDw9XWlqaSktLfWruuOMOORwOn23ChAk+NWVlZUpNTVXz5s0VHh6uqV
Onqrq62qdm586d6tu3r4KDg9WlSxfl5ub6ezkAAKCR8nvI2bVrlzIyMvTuu+8qPz9fFy9eVHJysiorK33qHn30UZ06dcre5s2bZ8/V1NQoNTVVFy5c0N69e7V69Wrl5uZq5syZds3x48eVmpqqwYMHq7i4WJMnT9YjjzyirVu3+ntJAACgEfL7V8i3bNnis5+bm6vw8HAVFRVp0KBB9njz5s0VGRl52WNs27ZNH330kbZv366IiAj16dNHc+bM0bRp05SdnS2n06kVK1YoLi5OCxYskCR169ZNe/bs0aJFi5SSkuLvZQEAgEamwa/JqaiokCS1bt3aZ3zNmjVq27atevbsqaysLH355Zf2XGFhoRISEhQREWGPpaSkyOv16vDhw3ZNUlKSzzFTUlJUWFh4xV6qqqrk9Xp9NgAAYKYGvRlgbW2tJk+erFtuuUU9e/a0x++//37FxsYqOjpaBw8e1LRp01RaWqrXXntNkuTxeHwCjiR73+PxfGuN1+vVV199pWbNml3ST05OjmbNmtUgawUAANeWBg05GRkZKikp0Z49e3zGx48fb/+dkJCgqKgoDRkyRMeOHVPnzp0brJ+srCxlZmba+16vVzExMQ32fAAAIHAa7OOqiRMnatOmTXrnnXfUvn37b61NTEyUJB09elSSFBkZqfLycp+auv2663iuVONyuS57FkeSgoOD5XK5fDYAAGAmv4ccy7I0ceJEbdy4UTt27FBcXNx3Pqa4uFiSFBUVJUlyu906dOiQTp8+bdfk5+fL5XKpe/fudk1BQYHPcfLz8+V2u/28IgAA0Bj5PeRkZGTo5Zdf1tq1axUaGiqPxyOPx6OvvvpKknTs2DHNmTNHRUVFOnHihN544w2NGTNGgwYNUq9evSRJycnJ6t69ux588EF9+OGH2rp1q2bMmKGMjAwFBwdLkiZMmKBPPvlETz/9tI4cOaJly5bp1Vdf1ZQpU/y9JAAA0Aj5PeQsX75cFRUVuuOOOxQVFWVv69evlyQ5nU5t375dycnJio+P15NPPqkRI0bozTfftI8RFBSkTZs2KSgoSG63Ww888IDGjBmj2bNn2zVxcXHavHmz8vPz1bt3by1YsEAvvPACXx8HAACSJIdlWVagmwgUr9ersLAwVVRU+P36nI7TN/v1eIC/nZibGugWAOCqfN/3b367CgAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICRCDkAAMBIhBwAAGAkQg4AADASIQcAABiJkAMAAIzUNNANAAiMjtM3B7qFBndibmqgWwAQQJzJAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACMRMgBAABGIuQAAAAjEXIAAICR+O0qAMYy/fe5+G0u4NtxJgcAABip0YecpUuXqmPHjgoJCVFiYqLef//9QLcEAACuAY065Kxfv16ZmZl67rnndODAAfXu3VspKSk6ffp0oFsDAAAB1qhDzsKFC/Xoo4/q4YcfVvfu3bVixQo1b95cL730UqBbAwAAAdZoLzy+cOGCioqKlJWVZY81adJESUlJKiwsvOxjqqqqVFVVZe9XVFRIkrxer9/7q6360u/HBIBv6jBlQ6BbaFAls1IC3QKuUXXv25ZlfWtdow05//jHP1RTU6OIiAif8YiICB05cuSyj8nJydGsWbMuGY+JiWmwPgEAVyfs+UB3gGvduXPnFBYWdsX5RhtyrkZWVpYyMzPt/draWn3xxRdq06aNHA6H357H6/UqJiZGn376qVwul9+Oey0xfY2sr/EzfY2sr/EzfY0NuT7LsnTu3DlFR0d/a12jDTlt27ZVUFCQysvLfcbLy8sVGRl52ccEBwcrODjYZ6xly5YN1qPL5TLyf9xvMn2NrK/xM32NrK/xM32NDbW+bzuDU6fRXnjsdDrVr18/FRQU2G
O1tbUqKCiQ2+0OaG8AACDwGu2ZHEnKzMzU2LFj1b9/f9100016/vnnVVlZqYcffjjQrQEAgABr1CFn5MiR+vzzzzVz5kx5PB716dNHW7ZsueRi5B9bcHCwnnvuuUs+GjOJ6WtkfY2f6WtkfY2f6Wu8FtbnsL7r+1cAAACNUKO9JgcAAODbEHIAAICRCDkAAMBIhBwAAGAkQo4f7d69W8OHD1d0dLQcDofy8vIC3ZJf5eTkaMCAAQoNDVV4eLjS0tJUWloa6Lb8avny5erVq5d98yq3262333470G01mLlz58rhcGjy5MmBbsUvsrOz5XA4fLb4+PhAt+V3//M//6MHHnhAbdq0UbNmzZSQkKD9+/cHui2/6Nix4yX/DR0OhzIyMgLdml/U1NTot7/9reLi4tSsWTN17txZc+bM+c7fYGpMzp07p8mTJys2NlbNmjXTzTffrH379gWkl0b9FfJrTWVlpXr37q1x48bp7rvvDnQ7frdr1y5lZGRowIABqq6u1jPPPKPk5GR99NFHatGiRaDb84v27dtr7ty5uuGGG2RZllavXq277rpLH3zwgXr06BHo9vxq3759+o//+A/16tUr0K34VY8ePbR9+3Z7v2lTs/6ZO3PmjG655RYNHjxYb7/9ttq1a6ePP/5YrVq1CnRrfrFv3z7V1NTY+yUlJfrFL36he+65J6B9+csf/vAHLV++XKtXr1aPHj20f/9+PfzwwwoLC9MTTzwR6Pb84pFHHlFJSYn+67/+S9HR0Xr55ZeVlJSkjz76SD/72c9+3GYsNAhJ1saNGwPdRoM6ffq0JcnatWtXoFtpUK1atbJeeOGFQLfhV+fOnbNuuOEGKz8/37r99tutSZMmBbolv3juuees3r17B7qNBjVt2jTr1ltvDXQbP5pJkyZZnTt3tmprawPdil+kpqZa48aN8xm7++67rdGjRwesJ3/68ssvraCgIGvTpk0+43379rWeffbZH70fPq7CVauoqJAktW7dOtCtNIiamhqtW7dOlZWVxv1USEZGhlJTU5WUlBToVvzu448/VnR0tDp16qTRo0errKws0C351RtvvKH+/fvrnnvuUXh4uG688Ub953/+Z6DbahAXLlzQyy+/rHHjxvn1R5QD6eabb1ZBQYH+9re/SZI+/PBD7dmzR8OGDQt0a35RXV2tmpoahYSE+Iw3a9ZMe/bs+dH7Mes8Ln40tbW1mjx5sm655Rb17Nkz0O341aFDh+R2u/X111/r+uuv18aNG9W9e/dAt+U369at04EDBwL2GXlDSkxMVG5urrp27apTp05p1qxZuu2221RSUqLQ0NBAt+cXn3zyiZYvX67MzEw988wz2rdvn5544gk5nU6NHTs20O35VV5ens6ePauHHnoo0K34zfTp0+X1ehUfH6+goCDV1NTod7/7nUaPHh3o1vwiNDRUbrdbc+bMUbdu3RQREaFXXnlFhYWF6tKly4/f0I9+7ugnwvSPqyZMmGDFxsZan376aaBb8buqqirr448/tvbv329Nnz7datu2rXX48OFAt+UXZWVlVnh4uPXhhx/aYyZ9XPV/nTlzxnK5XEZ93HjddddZbrfbZ+zxxx+3Bg4cGLCeGkpycrL1y1/+MtBt+NUrr7xitW/f3nrllVesgwcPWn/605+s1q1bW7m5uYFuzW+OHj1qDRo0yJJkBQUFWQMGDLBGjx5txcfH/+i9EHIaiMkhJyMjw2rfvr31ySefBLqVH8WQIUOs8ePHB7oNv9i4caP9D0/dJslyOBxWUFCQVV1dHegW/a5///7W9OnTA92G33To0MFKT0/3GVu2bJkVHR0dsJ4awokTJ6wmTZpYeXl5gW7Fr9q3b28tWbLEZ2zOnDlW165dA9ZTQzl//rx18uRJy7Is695777XuvPPOH70HrsnB92ZZliZOnKiNGzdqx44diouLC3RLP4ra2lpVVVUFug2/GDJkiA4dOqTi4mJ769+/v0aPHq3i4mIFBQUFuk
W/On/+vI4dO6aoqKhAt+I3t9xyyyW3bvjb3/6m2NjYgPXUEFatWqXw8HClpqYGuhW/+vLLL9Wkie9bb1BQkGprawPWU0Np0aKFoqKidObMGW3dulV33XXXj94D1+T40fnz53X06FF7//jx4youLlbr1q3VoUOHgPbmDxkZGVq7dq1ef/11hYaGyuPxSJLCwsLUrFmzQLfnF1lZWRo2bJg6dOigc+fOae3atdq5c6e2bt0a6Nb8IjQ09JJrqFq0aKE2bdoYcW3VU089peHDhys2NlYnT57Uc889p6CgII0aNSrQrfnNlClTdPPNN+v3v/+97r33Xr3//vtauXKlVq5cGejW/Ka2tlarVq3S2LFjjbsFwPDhw/W73/1OHTp0UI8ePfTBBx9o4cKFGjduXKBb85utW7fKsix17dpVR48e1dSpUxUfH6+HH374x2/mRz93ZLB33nnHknTJNnbs2EC35heXW5ska9WqVYFuzW/GjRtnxcbGWk6n02rXrp01ZMgQa9u2bYFuq0GZdE3OyJEjraioKMvpdFo/+9nPrJEjR1pHjx4NdFt+9+abb1o9e/a0goODrfj4eGvlypWBbsmvtm7dakmySktLA92K33m9XmvSpElWhw4drJCQEKtTp07Ws88+a1VVVQW6Nb9Zv3691alTJ8vpdFqRkZFWRkaGdfbs2YD04rBMus0iAADA/+KaHAAAYCRCDgAAMBIhBwAAGImQAwAAjETIAQAARiLkAAAAIxFyAACAkQg5AADASIQcAABgJEIOAAAwEiEHAAAYiZADAACM9P8AMCmCWHBic9EAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.hist(mkbatch(2**15, large=True, target=\"onpath\", largetarget=False)[1].cpu(), bins=8)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q3Cg_8UQep8g"
},
"source": [
"# Step 2: Define Transformer Model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_state": "idle",
"metadata": {
"id": "tLOWhg_CeWzH"
},
"outputs": [],
"source": [
"class TransformerModel(nn.Module):\n",
" def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, dropout):\n",
" super().__init__()\n",
" self.model_dim = model_dim\n",
" self.embedding = nn.Embedding(input_dim, model_dim // 2)\n",
" encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,\n",
" dim_feedforward=model_dim*4,\n",
" dropout=dropout, batch_first=True)\n",
" self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
" self.fc_out = nn.Linear(model_dim, output_dim)\n",
"\n",
" def forward(self, src, key_padding_mask):\n",
" batch_sz = src.size(0)\n",
" src = torch.cat((src, torch.full((batch_sz, 1), MAX_VTXS + 1, device=device)), dim=1)\n",
" embed = torch.cat((self.embedding(src[:,::2]), self.embedding(src[:,1::2])), dim=2)\n",
" output = self.transformer_encoder(embed, src_key_padding_mask=key_padding_mask[:, ::2])\n",
" return self.fc_out(output[:, -1, :])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bpIeg86S-hBb"
},
"source": [
"# Step 3: Make Model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kWXvJRDYgFVP",
"outputId": "c13adb9d-6565-43b5-8437-20cef3dc0d16"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training data: 104857K\n",
"Trainable parameters in the model: 550433\n"
]
}
],
"source": [
"# PARAMS\n",
"VOCAB_SIZE = 1 + MAX_VTXS + 1 # pad plus max number of vertices plus target token\n",
"MODEL_DIM = 64 # Dimension of model (embedding and transformer)\n",
"NEPOCHS = 100\n",
"BSZ = 2**15 # Batch size\n",
"NHEADS = 2\n",
"NLAYERS = 11\n",
"DROPOUT = 0 # 0.2\n",
"model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
" output_dim=1, num_heads=NHEADS,\n",
" num_layers=NLAYERS, seq_len=SEQ_LEN,\n",
" dropout=DROPOUT).to(device)\n",
"model = torch.compile(model)\n",
"\n",
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Training data: {NEPOCHS*EPOCH_SIZE//10**3}K\")\n",
"print(f\"Trainable parameters in the model: {trainable_params}\")\n",
"\n",
"train_err = []\n",
"len1 = []\n",
"len2 = []\n",
"len3 = []\n",
"leninf = []\n",
"epoch = 0\n",
"\n",
"# clear loss file\n",
"open('loss', 'w').close()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# model = TransformerModel(input_dim=VOCAB_SIZE, model_dim=MODEL_DIM,\n",
"# output_dim=1, num_heads=NHEADS,\n",
"# num_layers=NLAYERS, seq_len=SEQ_LEN,\n",
"# dropout=DROPOUT).to(device)\n",
"# model = torch.compile(model)\n",
"# model.load_state_dict(torch.load('model.pth', weights_only=True))\n",
"\n",
"LR = 5e-4\n",
"\n",
"criterion = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LR)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f8Zn33m7CxL5"
},
"source": [
"# Step 4: Train the Model for the first task"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_state": "running",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 486
},
"id": "pvTfzGmCeXU4",
"outputId": "0d3a20f3-23be-4c19-9eb6-46bfe11a48b1"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/torchinductor_sipb/lc/clcqc3ufbzrethiy77dmsu54kurxdmh4eji2f3msm347rhmfpf4j.py:1078: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)\n",
" extern_kernels.mm(reinterpret_tensor(buf1, (524288, 64), (64, 1), 0), reinterpret_tensor(primals_5, (64, 192), (1, 64), 0), out=buf2)\n",
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:42<00:00, 1.33s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/100 \t Train Err: 48.97900426387787 12.91722442075843 7.231296321027912 3.5385852727340534 115.3402452468872\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100 \t Train Err: 41.59842586517334 23.066684544086456 14.473573058843613 7.877466633915901 84.72042870521545\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/100 \t Train Err: 40.415191769599915 29.482473254203796 19.633903205394745 11.77738669514656 73.48623991012573\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/100 \t Train Err: 40.01692616939545 32.29490512609482 22.276952624320984 13.907412678003311 68.378258228302\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/100 \t Train Err: 31.416786193847656 2.3271487059355422 26.761382937431335 18.088589638471603 60.98844397068024\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/100 \t Train Err: 25.921700596809387 0.22728093068872113 19.57061032950878 15.32799781858921 52.15890157222748\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/100 \t Train Err: 17.580547362565994 0.021982330930768512 4.470714939758182 19.619352281093597 36.34695905447006\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/100 \t Train Err: 12.400098770856857 0.013556713653088082 0.5100052966736257 23.03424423933029 23.36328774690628\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8/100 \t Train Err: 9.38240310549736 0.007538945435953792 0.22025173716247082 19.04077085852623 16.662322163581848\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9/100 \t Train Err: 7.412262797355652 0.008171883615432307 0.1555994711816311 14.766773402690887 12.345462799072266\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/100 \t Train Err: 6.162406742572784 0.017439005838241428 0.14521901519037783 10.868462145328522 10.12831449508667\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 11/100 \t Train Err: 5.102278828620911 0.006056587655621115 0.07073448912706226 7.790802486240864 8.321609899401665\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 12/100 \t Train Err: 4.366904504597187 0.005133508995641023 0.059961416525766253 5.993938364088535 6.998718574643135\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 13/100 \t Train Err: 3.831405833363533 0.003056721754546743 0.04472046362934634 4.653905652463436 6.165115922689438\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 14/100 \t Train Err: 3.4279200956225395 0.004735801303468179 0.03825691540259868 3.906827114522457 5.425094351172447\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 15/100 \t Train Err: 3.138390600681305 0.005375595836085267 0.03724290645914152 3.326357141137123 4.948893174529076\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 16/100 \t Train Err: 2.899445064365864 0.006989890585828107 0.041147086303681135 2.920688170939684 4.558141328394413\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 17/100 \t Train Err: 2.6482545658946037 0.006691730450256728 0.03093722724588588 2.4344960935413837 4.143990509212017\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 18/100 \t Train Err: 2.4724042862653732 0.013588363406597637 0.03676938998978585 2.2025532834231853 3.8459226489067078\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 19/100 \t Train Err: 2.2938634902238846 0.007989803227246739 0.031059849599841982 1.9167942628264427 3.582200199365616\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 20/100 \t Train Err: 2.0970346182584763 0.004351407576905331 0.02517502213595435 1.652792677283287 3.244159609079361\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 21/100 \t Train Err: 1.974123526364565 0.003171493917761836 0.02390849226503633 1.4561462700366974 3.0673259645700455\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 22/100 \t Train Err: 1.8447572737932205 0.0030846048357489053 0.021993933914927766 1.3033475428819656 2.8325533121824265\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 23/100 \t Train Err: 1.6970642134547234 0.002785886717902031 0.020618034002836794 1.1004324741661549 2.604609090834856\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 24/100 \t Train Err: 1.5955667905509472 0.003940593182051089 0.02032866739318706 0.976159205660224 2.466688357293606\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 25/100 \t Train Err: 1.4617242440581322 0.002876215428841533 0.018645043048309162 0.8556559775024652 2.229958161711693\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 26/100 \t Train Err: 1.3831324987113476 0.004849277267567231 0.02050616717315279 0.7726096417754889 2.099547818303108\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 27/100 \t Train Err: 1.311477318406105 0.0024031923712755088 0.01927993548451923 0.7051676390692592 1.976257335394621\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 28/100 \t Train Err: 1.207214828580618 0.0035346322583791334 0.018441367952618748 0.6115700239315629 1.8264076933264732\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 29/100 \t Train Err: 1.1089427508413792 0.00329247322952142 0.0165127256186679 0.5199616495519876 1.6691240929067135\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 30/100 \t Train Err: 1.029175629839301 0.002865927770471899 0.01539153911289759 0.4984441949054599 1.5174495466053486\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 31/100 \t Train Err: 0.9538682177662849 0.0023230986480484717 0.01537516585085541 0.4138944335281849 1.406083919107914\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 32/100 \t Train Err: 0.8980342578142881 0.0015845673569856444 0.014839567360468209 0.40191424917429686 1.3131387010216713\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 33/100 \t Train Err: 0.8195682223886251 0.0034502551516197855 0.014685693866340443 0.3501896969974041 1.191128795966506\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 34/100 \t Train Err: 0.7459018900990486 0.0021442237539304188 0.014179074845742434 0.31923429761081934 1.0604591444134712\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 35/100 \t Train Err: 0.7101908139884472 0.0032578711288806517 0.014129422343103215 0.2923557236790657 1.0151417199522257\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 36/100 \t Train Err: 0.6644597053527832 0.0026260755184921436 0.012881889037089422 0.257190125528723 0.9443178754299879\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 37/100 \t Train Err: 0.6029678452759981 0.0022337357859214535 0.013130559556884691 0.232837010640651 0.837149228900671\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 38/100 \t Train Err: 0.5509444028139114 0.0019162936450811685 0.012294809013837948 0.2053068270906806 0.7741311714053154\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 39/100 \t Train Err: 0.5144137293100357 0.0014641922125520068 0.009304212289862335 0.18962949723936617 0.6982439709827304\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.04it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 40/100 \t Train Err: 0.4911631550639868 0.0034183577554358635 0.011847828980535269 0.17764808260835707 0.6830073017627001\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:31<00:00, 1.02it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 41/100 \t Train Err: 0.44534157309681177 0.002719711333156738 0.011579871497815475 0.17093974631279707 0.5986328851431608\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 42/100 \t Train Err: 0.4234167579561472 0.0020350335244074813 0.010087796414154582 0.1561200835276395 0.5679095359519124\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 43/100 \t Train Err: 0.38217254262417555 0.001672563766987878 0.008801718227914535 0.1451707179658115 0.5035076225176454\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"00%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:30<00:00, 1.03it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 44/100 \t Train Err: 0.36583498027175665 0.0006311295665000216 0.0070743790856795385 0.13195386109873652 0.4824865907430649\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 69%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 22/32 [00:21<00:09, 1.06it/s]"
]
}
],
"source": [
"# Main training loop: each epoch trains on the pre-generated file data/{epoch}.pickle.\n",
"# Training resumes from the current value of the global `epoch` (it is not reset here).\n",
"# NOTE(review): pickle.load can execute arbitrary code -- only load data files you generated yourself.\n",
"\n",
"def _masked_loss(output, labels, value):\n",
"    \"\"\"criterion loss restricted to samples whose label equals `value`, as a float.\n",
"\n",
"    Returns NaN when the batch has no such sample (what a mean-reduced loss yields on\n",
"    empty input) without calling the criterion on empty tensors.\n",
"    \"\"\"\n",
"    mask = labels == value\n",
"    if not mask.any():\n",
"        return float(\"nan\")\n",
"    return criterion(output[mask].squeeze(1), labels[mask]).item()\n",
"\n",
"def epoch_err(l):\n",
"    # Mean over the last epoch's per-batch values: -EPOCH_SIZE // BSZ == -ceil(EPOCH_SIZE / BSZ),\n",
"    # i.e. the batch count per epoch (assumes each pickle holds EPOCH_SIZE samples).\n",
"    return sum(l[-EPOCH_SIZE // BSZ:]) / EPOCH_SIZE * BSZ\n",
"\n",
"while epoch < NEPOCHS:\n",
"    model.train()\n",
"    with open(f\"data/{epoch}.pickle\", \"rb\") as f:\n",
"        pickled_stuff = pickle.load(f)\n",
"    data = pickled_stuff[\"data\"].to(device)\n",
"    label = pickled_stuff[\"labels\"].to(device).to(torch.float32)\n",
"    padding = pickled_stuff[\"padding\"].to(device)\n",
"    dataset = TensorDataset(data, label, padding)\n",
"    loader = DataLoader(dataset, batch_size=BSZ)\n",
"    for batch_src, batch_labels, batch_padding_mask in tqdm(loader):\n",
"        optimizer.zero_grad()\n",
"        output = model(batch_src, batch_padding_mask)\n",
"        loss = criterion(output.squeeze(1), batch_labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        train_err.append(loss.item())\n",
"        # Per-label diagnostics (labels 1, 2, 3 and MAX_VTXS). They are never backpropagated,\n",
"        # so compute them under no_grad instead of building unused autograd graphs.\n",
"        with torch.no_grad():\n",
"            len1.append(_masked_loss(output, batch_labels, 1))\n",
"            len2.append(_masked_loss(output, batch_labels, 2))\n",
"            len3.append(_masked_loss(output, batch_labels, 3))\n",
"            leninf.append(_masked_loss(output, batch_labels, MAX_VTXS))\n",
"\n",
"        with open('loss', 'a') as f:\n",
"            f.write(f\"{train_err[-1]} {len1[-1]} {len2[-1]} {len3[-1]} {leninf[-1]}\\n\")\n",
"    print(f\"Epoch {epoch}/{NEPOCHS} \\t Train Err: {epoch_err(train_err)} {epoch_err(len1)} {epoch_err(len2)} {epoch_err(len3)} {epoch_err(leninf)}\")\n",
"\n",
"    epoch += 1\n",
"    # checkpoint every 10 completed epochs\n",
"    if epoch % 10 == 0:\n",
"        torch.save(model.state_dict(), f\"checkpoints/model_weights_{epoch}.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Disabled exploration (entire cell is commented out): plot layer-0 self-attention scores for one example graph.\n",
"# NOTE(review): assuming self_attn is nn.MultiheadAttention, in_proj_weight is applied as x @ W.T (plus\n",
"# in_proj_bias) and scores are scaled by sqrt(head_dim) per head, not sqrt(MODEL_DIM) -- check Q/K below\n",
"# before re-enabling. `mkbatch` is also not defined in the current session.\n",
"# \"\"\"\n",
"# Now let's figure out what it's doing. \n",
"\n",
"# step 1: figure out what people are attending to \n",
"# \"\"\"\n",
"\n",
"# example_graph, answer, padding = mkbatch(1)\n",
"# sentance_embeddings = model.full_embedding(example_graph)[0,:,:][example_graph.flatten() != 0]\n",
"# WQ,WK,WV = torch.split(model.transformer_encoder.layers[0].self_attn.in_proj_weight, (MODEL_DIM, MODEL_DIM, MODEL_DIM))\n",
"\n",
"# Q = sentance_embeddings@WQ\n",
"# K = sentance_embeddings@WK\n",
"\n",
"# raw_scores = Q @ K.T / sqrt(MODEL_DIM)\n",
"# soft = torch.softmax(raw_scores, dim=-1).detach().cpu().to(float).numpy()\n",
"# plt.imshow(soft)\n",
"# plt.show()\n",
"\n",
"# print(example_graph)\n",
"\n",
"# print(Q)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot 1: loss values read from the 'training-loss' file (first 500 entries).\n",
"# Read into a separate name so the in-memory `train_err` recorded by the training loop is not overwritten.\n",
"with open('training-loss') as f:\n",
"    file_train_err = list(map(float, f.read().split()))\n",
"plt.suptitle('Log MSE vs Epochs')\n",
"plt.plot(torch.log(torch.tensor(file_train_err)[:500]), label='Train', color='blue')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Log MSE')\n",
"plt.show()\n",
"\n",
"# Plot 2: per-batch losses kept in memory by the training loop, overall and per label\n",
"# (leninf holds the loss on samples labelled MAX_VTXS).\n",
"plt.suptitle('Log MSE vs Batches')\n",
"plt.plot(torch.log(torch.tensor(train_err)), label='Train', color='blue')\n",
"plt.plot(torch.log(torch.tensor(len1)).to(torch.float16), label='Len 1', color='red')\n",
"plt.plot(torch.log(torch.tensor(len2)).to(torch.float16), label='Len 2', color='green')\n",
"plt.plot(torch.log(torch.tensor(len3)).to(torch.float16), label='Len 3', color='yellow')\n",
"plt.plot(torch.log(torch.tensor(leninf)).to(torch.float16), label=f'Len {MAX_VTXS}', color='magenta')\n",
"plt.xlabel('Batches')\n",
"plt.ylabel('Log MSE')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "LoGEmM5lH7_A"
},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'mkbatch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m batch_src, batch_labels, batch_padding_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28;01mlambda\u001b[39;00m x: x\u001b[38;5;241m.\u001b[39mto(device), \u001b[43mmkbatch\u001b[49m(BSZ))\n\u001b[1;32m 2\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n",
"\u001b[0;31mNameError\u001b[0m: name 'mkbatch' is not defined"
]
}
],
"source": [
"# 2D histogram of model prediction vs. true label on one batch; each sample is weighted by\n",
"# 1 / (number of samples with its label), so every label column has the same total weight.\n",
"# NOTE(review): `mkbatch` is not defined in this session (this cell raised NameError); the training\n",
"# loop now reads data/{epoch}.pickle, so this cell needs a batch source that exists.\n",
"batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mkbatch(BSZ))\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    output = model(batch_src, batch_padding_mask)\n",
"x = batch_labels.detach().to(torch.uint8)\n",
"y = output.detach()\n",
"cnts = torch.bincount(x)\n",
"# Vectorized 1/count lookup: avoids one .item() device sync per sample.\n",
"weights = (1.0 / cnts.to(torch.float64))[x.long()].cpu().numpy().flatten()\n",
"fig, ax = plt.subplots()\n",
"h = ax.hist2d(x.cpu().numpy().flatten(), y.to(torch.float16).cpu().numpy().flatten(), weights=weights, bins=[15,50], norm=mpl.colors.LogNorm())\n",
"fig.colorbar(h[3], ax=ax)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/sipb/.venv/lib64/python3.12/site-packages/torch/nn/modules/transformer.py:871: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)\n",
" return torch._transformer_encoder_layer_fwd(\n",
"/home/sipb/.venv/lib64/python3.12/site-packages/torch/nn/modules/transformer.py:871: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)\n",
" return torch._transformer_encoder_layer_fwd(\n",
"/home/sipb/.venv/lib64/python3.12/site-packages/torch/_inductor/compile_fx.py:167: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"0.0005554668023250997"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def evaluate():\n",
"    \"\"\"Return the criterion loss of the model on one batch from mkbatch(BSZ).\n",
"\n",
"    Puts the model in eval mode (and leaves it there) and runs without gradient tracking.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mkbatch(BSZ))\n",
"        output = model(batch_src, batch_padding_mask)\n",
"        loss = criterion(output.squeeze(1), batch_labels)\n",
"    return loss.item()\n",
"\n",
"evaluate()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LC6Xv3YfC0Rm"
},
"source": [
"# Step 5: Fine Tune"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Fine-tuning hyperparameters.\n",
"N_TUNE_EPOCHS = 100\n",
"TUNE_LR = 1e-5\n",
"TUNE_WD = 0  # weight decay disabled; 1e-5 is the commented-out alternative\n",
"\n",
"# Per-step fine-tuning losses (the fine-tune loop also appends them to the 'tune_loss' file).\n",
"tune_train_err = []\n",
"\n",
"tune_criterion = nn.MSELoss()\n",
"tune_optimizer = torch.optim.Adam(model.parameters(), weight_decay=TUNE_WD, lr=TUNE_LR)\n",
"\n",
"# Start a fresh loss log: opening with 'w' truncates (or creates) the file.\n",
"with open('tune_loss', 'w'):\n",
"    pass"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/100 \t Train Err: 0.4553\n",
"Epoch 1/100 \t Train Err: 0.2545\n",
"Epoch 2/100 \t Train Err: 0.2368\n",
"Epoch 3/100 \t Train Err: 0.2138\n",
"Epoch 4/100 \t Train Err: 0.1898\n",
"Epoch 5/100 \t Train Err: 0.1931\n",
"Epoch 6/100 \t Train Err: 0.1948\n",
"Epoch 7/100 \t Train Err: 0.1847\n",
"Epoch 8/100 \t Train Err: 0.1638\n",
"Epoch 9/100 \t Train Err: 0.1490\n",
"Epoch 10/100 \t Train Err: 0.1457\n",
"Epoch 11/100 \t Train Err: 0.1509\n",
"Epoch 12/100 \t Train Err: 0.1537\n",
"Epoch 13/100 \t Train Err: 0.1472\n",
"Epoch 14/100 \t Train Err: 0.1369\n",
"Epoch 15/100 \t Train Err: 0.1304\n",
"Epoch 16/100 \t Train Err: 0.1257\n",
"Epoch 17/100 \t Train Err: 0.1201\n",
"Epoch 18/100 \t Train Err: 0.1160\n",
"Epoch 19/100 \t Train Err: 0.1064\n",
"Epoch 20/100 \t Train Err: 0.1026\n",
"Epoch 21/100 \t Train Err: 0.0964\n",
"Epoch 22/100 \t Train Err: 0.1165\n",
"Epoch 23/100 \t Train Err: 0.0990\n",
"Epoch 24/100 \t Train Err: 0.0691\n",
"Epoch 25/100 \t Train Err: 0.0633\n",
"Epoch 26/100 \t Train Err: 0.0552\n",
"Epoch 27/100 \t Train Err: 0.0505\n",
"Epoch 28/100 \t Train Err: 0.0429\n",
"Epoch 29/100 \t Train Err: 0.0376\n",
"Epoch 30/100 \t Train Err: 0.0309\n",
"Epoch 31/100 \t Train Err: 0.0252\n",
"Epoch 32/100 \t Train Err: 0.0233\n",
"Epoch 33/100 \t Train Err: 0.0219\n",
"Epoch 34/100 \t Train Err: 0.0196\n",
"Epoch 35/100 \t Train Err: 0.0186\n",
"Epoch 36/100 \t Train Err: 0.0181\n",
"Epoch 37/100 \t Train Err: 0.0189\n",
"Epoch 38/100 \t Train Err: 0.0174\n",
"Epoch 39/100 \t Train Err: 0.0141\n",
"Epoch 40/100 \t Train Err: 0.0138\n",
"Epoch 41/100 \t Train Err: 0.0136\n",
"Epoch 42/100 \t Train Err: 0.0124\n",
"Epoch 43/100 \t Train Err: 0.0121\n",
"Epoch 44/100 \t Train Err: 0.0109\n",
"Epoch 45/100 \t Train Err: 0.0107\n",
"Epoch 46/100 \t Train Err: 0.0105\n",
"Epoch 47/100 \t Train Err: 0.0103\n",
"Epoch 48/100 \t Train Err: 0.0088\n",
"Epoch 49/100 \t Train Err: 0.0089\n",
"Epoch 50/100 \t Train Err: 0.0092\n",
"Epoch 51/100 \t Train Err: 0.0078\n",
"Epoch 52/100 \t Train Err: 0.0075\n",
"Epoch 53/100 \t Train Err: 0.0068\n",
"Epoch 54/100 \t Train Err: 0.0066\n",
"Epoch 55/100 \t Train Err: 0.0065\n",
"Epoch 56/100 \t Train Err: 0.0055\n",
"Epoch 57/100 \t Train Err: 0.0056\n",
"Epoch 58/100 \t Train Err: 0.0052\n",
"Epoch 59/100 \t Train Err: 0.0055\n",
"Epoch 60/100 \t Train Err: 0.0048\n",
"Epoch 61/100 \t Train Err: 0.0050\n",
"Epoch 62/100 \t Train Err: 0.0049\n",
"Epoch 63/100 \t Train Err: 0.0051\n",
"Epoch 64/100 \t Train Err: 0.0047\n",
"Epoch 65/100 \t Train Err: 0.0045\n",
"Epoch 66/100 \t Train Err: 0.0041\n",
"Epoch 67/100 \t Train Err: 0.0038\n",
"Epoch 68/100 \t Train Err: 0.0036\n",
"Epoch 69/100 \t Train Err: 0.0035\n",
"Epoch 70/100 \t Train Err: 0.0035\n",
"Epoch 71/100 \t Train Err: 0.0032\n",
"Epoch 72/100 \t Train Err: 0.0032\n",
"Epoch 73/100 \t Train Err: 0.0030\n",
"Epoch 74/100 \t Train Err: 0.0031\n",
"Epoch 75/100 \t Train Err: 0.0026\n",
"Epoch 76/100 \t Train Err: 0.0030\n",
"Epoch 77/100 \t Train Err: 0.0024\n",
"Epoch 78/100 \t Train Err: 0.0026\n",
"Epoch 79/100 \t Train Err: 0.0028\n",
"Epoch 80/100 \t Train Err: 0.0026\n",
"Epoch 81/100 \t Train Err: 0.0023\n",
"Epoch 82/100 \t Train Err: 0.0022\n",
"Epoch 83/100 \t Train Err: 0.0022\n",
"Epoch 84/100 \t Train Err: 0.0021\n",
"Epoch 85/100 \t Train Err: 0.0023\n",
"Epoch 86/100 \t Train Err: 0.0023\n",
"Epoch 87/100 \t Train Err: 0.0020\n",
"Epoch 88/100 \t Train Err: 0.0020\n",
"Epoch 89/100 \t Train Err: 0.0021\n",
"Epoch 90/100 \t Train Err: 0.0020\n",
"Epoch 91/100 \t Train Err: 0.0019\n",
"Epoch 92/100 \t Train Err: 0.0017\n",
"Epoch 93/100 \t Train Err: 0.0019\n",
"Epoch 94/100 \t Train Err: 0.0019\n",
"Epoch 95/100 \t Train Err: 0.0015\n",
"Epoch 96/100 \t Train Err: 0.0016\n",
"Epoch 97/100 \t Train Err: 0.0017\n",
"Epoch 98/100 \t Train Err: 0.0015\n",
"Epoch 99/100 \t Train Err: 0.0015\n"
]
}
],
"source": [
"# Fine-tuning loop: one batch from mktunebatch(BSZ) per step, trained with the tune_optimizer /\n",
"# tune_criterion from the fine-tune setup cell so TUNE_LR and TUNE_WD actually take effect.\n",
"# Uses its own loop variable so the main training loop's global `epoch` counter is left untouched.\n",
"for tune_epoch in range(N_TUNE_EPOCHS):\n",
"    model.train()\n",
"    batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ))\n",
"    tune_optimizer.zero_grad()\n",
"    output = model(batch_src, batch_padding_mask)\n",
"    loss = tune_criterion(output.squeeze(1), batch_labels)\n",
"    loss.backward()\n",
"    tune_optimizer.step()\n",
"    train_loss = loss.item()\n",
"\n",
"    tune_train_err.append(train_loss)\n",
"    with open('tune_loss', 'a') as f:\n",
"        f.write(f\"{train_loss}\\n\")\n",
"    print(f\"Epoch {tune_epoch}/{N_TUNE_EPOCHS} \\t Train Err: {train_loss:.4f}\")\n",
"\n",
"    # checkpoint every 10 steps, named by the number of completed steps\n",
"    if tune_epoch % 10 == 9:\n",
"        torch.save(model.state_dict(), f\"checkpoints/tune_model_weights_{tune_epoch + 1}.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHgCAYAAABZ+0ykAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABCC0lEQVR4nO3deXRU9f3/8dckYZJAFpaQQCCsoih7g0BABTTIJlaLgpRKxBVQxPLTKlpB7ZeCuBQFBaXiziJWqVLBYsQFyhoMoCKuLIoJIJCELYHM5/fHZSaMAUxgZm5y83ycc8/M3Ll35j23al79bNdljDECAABwiDC7CwAAAAgkwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0ABJHL5dIdd9xhdxlAlUK4ASqxl156SS6XSy6XS8uXLy/1vjFGKSkpcrlcuuKKK/zeO3DggCZMmKDWrVurRo0aqlOnjtq3b68xY8Zo586dvuMeeugh33ecbMvJyQnJbz2V09U2YsQIW2sDYI8IuwsAcPaioqI0Z84cXXTRRX77P/74Y/3444+KjIz023/06FFdcskl+uqrr5SRkaHRo0frwIED+uKLLzRnzhxdffXVSk5O9jtnxowZiomJKfXdNWvWDNKvKrtevXpp2LBhpfafe+65ttQDwF6EG8AB+vXrpwULFujpp59WRETJv9Zz5sxRamqq9uzZ43f8woUL9dlnn+n111/XH//4R7/3jhw5oqKiolLfcc011yghISGIv+LMnXvuufrTn/5kdxkAKgi6pQAHGDJkiH755RctXbrUt6+oqEhvvvlmqfAiSd99950kqVu3bqXei4qKUlxcXEDqat26tXr27Flqv8fjUYMGDXTNNdf49s2bN0+pqamKjY1VXFyc2rRpo6eeeiogdUhSjx491Lp1a2VlZalr166Kjo5W06ZNNXPmzFLH7tq1SzfddJOSkpIUFRWldu3a6eWXXz7p73jqqafUpk0bRUVFqW7duurTp4/WrVtX6tiFCxeqdevWioyMVKtWrbRkyRK/9wsKCnTXXXepSZMmioyMVGJionr16qX169cH7BoAVQXhBnCAJk2aKC0tTXPnzvXtW7x4sfLy8nTdddeVOr5x48aSpFdeeUXGmDJ9x969e7Vnzx6/bf/+/ac9Z/Dgwfrkk09KjctZvny5du7c6att6dKlGjJkiGrVqqVHH31UkydPVo8ePbRixYoy1XbkyJFSte3Zs6dUC9S+ffvUr18/paamasqUKWrYsKFGjhyp2bNn+445fPiwevTooVdffVVDhw7VY489pvj4eN1www2lwtZNN92ku+66SykpKXr00Ud13333KSoqSqtWrSr1e0eNGqXrrrtOU6ZM0ZEjRzRw4ED98ssvvmNGjBihGTNmaODAgXr22Wd19913Kzo6Wps3by7TNQBwAgOg0nrxxReNJLN27Vozffp0Exsbaw4dOmSMMebaa681PXv2NMYY07hxY9O/f3/feYcOHTLnnXeekWQaN25sbrjhBvPCCy+Y3NzcUt8xYcIEI+mk23nnnXfa+rZs2WIkmWnTpvntHzVqlImJifHVOmbMGBMXF2eOHTtW7mtwqtokmblz5/qO6969u5FknnjiCd++wsJC0759e5OYmGiKioqMMcZMnTrVSDKvvfaa77iioiKTlpZmYmJiTH5+vjHGmA8//NBIMnfeeWepmjwej199brfbfPvtt759GzZsKHVd4uPjze23317u3w+gNFpuAIcYNGiQDh8+rEWLFqmgoECLFi06aZeUJEVHR2v16tW65557pOOzrm666SbVr19fo0ePVmFhYalz/vWvf2np0qV+24svvnjams4991y1b99e8+fP9+0rLi7Wm2++qQEDBig6Olo6Pij54MGDft1q5fH73/++VG1Lly4t1SUWERGh2267zffa7Xbrtttu065du5SVlSVJeu+991
SvXj0NGTLEd1y1atV055136sCBA/r4449918PlcmnChAml6nG5XH6v09PT1bx5c9/rtm3bKi4uTt9//71vX82aNbV69Wq/mWoAzgwDigGHqFu3rtLT0zVnzhwdOnRIxcXFfmNafi0+Pl5TpkzRlClTtG3bNmVmZurxxx/X9OnTFR8fr//7v//zO/6SSy45owHFgwcP1v3336+ffvpJDRo00EcffaRdu3Zp8ODBvmNGjRqlN954Q3379lWDBg10+eWXa9CgQerTp0+ZvqNhw4ZKT0//zeOSk5NVo0YNv33eGVVbt25Vly5dtG3bNrVo0UJhYf7/3+/888+XJG3btk06Pm4pOTlZtWvX/s3vbdSoUal9tWrV0r59+3yvp0yZooyMDKWkpCg1NVX9+vXTsGHD1KxZs9/8fAD+aLkBHOSPf/yjFi9erJkzZ6pv375lnqbduHFj3XjjjVqxYoVq1qyp119/PWA1DR48WMYYLViwQJL0xhtvKD4+3i+4JCYmKjs7W++8846uvPJKLVu2TH379lVGRkbA6rBTeHj4SfefON5p0KBB+v777zVt2jQlJyfrscceU6tWrbR48eIQVgo4A+EGcJCrr75aYWFhWrVq1Sm7pE6nVq1aat68uX7++eeA1dS0aVN16tRJ8+fP17Fjx/TWW2/pqquuKrX2jtvt1oABA/Tss8/qu+++02233aZXXnlF3377bcBq2blzpw4ePOi37+uvv5aOD8rW8aD3zTffyOPx+B331Vdf+d6XpObNm2vnzp3au3dvwOqrX7++Ro0apYULF+qHH35QnTp1NHHixIB9PlBVEG4AB4mJidGMGTP00EMPacCAAac8bsOGDaXWvtHxLpcvv/xS5513XkDrGjx4sFatWqXZs2drz549fl1SkvxmDUlSWFiY2rZtK0knHf9zpo4dO6bnnnvO97qoqEjPPfec6tatq9TUVOn4mkE5OTl+44SOHTumadOmKSYmRt27d5ckDRw4UMYYPfzww6W+p6wz0LyKi4uVl5fnty8xMVHJyckB/f1AVcGYG8BhytKVs3TpUk2YMEFXXnmlunTpopiYGH3//feaPXu2CgsL9dBDD5U658033zzpCsW9evVSUlLSab9v0KBBuvvuu3X33Xerdu3apcbH3Hzzzdq7d68uvfRSNWzYUNu2bdO0adPUvn1731iX0/n666/12muvldqflJSkXr16+V4nJyfr0Ucf1datW3Xuuedq/vz5ys7O1vPPP69q1apJkm699VY999xzuuGGG5SVlaUmTZrozTff1IoVKzR16lTFxsZKknr27Knrr79eTz/9tL755hv16dNHHo9Hn376qXr27Fmu+0kVFBSoYcOGuuaaa9SuXTvFxMTogw8+0Nq1a/XEE0+U+XMAHGf3dC0AZ+7EqeCn8+up4N9//70ZP3686dKli0lMTDQRERGmbt26pn///ubDDz/0O/d0U8ElmWXLlpWp1m7duhlJ5uabby713ptvvmkuv/xyk5iYaNxut2nUqJG57bbbzM8///ybn3u62rp37+47rnv37qZVq1Zm3bp1Ji0tzURFRZnGjRub6dOnl/rM3NxcM3z4cJOQkGDcbrdp06aNefHFF0sdd+zYMfPYY4+Zli1bGrfbberWrWv69u1rsrKy/Oo72RTvxo0bm4yMDGOOT0m/5557TLt27UxsbKypUaOGadeunXn22Wd/8/cDKM1lytt+CgCVUI8ePbRnzx59/vnndpcCIMgYcwMAAByFcAMAAByFcAMAAByFMTcAAMBRaLkBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOEm
F3AaHm8Xi0c+dOxcbGyuVy2V0OAAAoA2OMCgoKlJycrLCw07fNVLlws3PnTqWkpNhdBgAAOAM7duxQw4YNT3tMlQs3sbGx0vGLExcXZ3c5AACgDPLz85WSkuL7O346VS7ceLui4uLiCDcAAFQyZRlSwoBiAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKFXuxpnBUlQk5eZKHo/UuLHd1QAAUHXRchMga9ZIjRpJvXrZXQkAAFUb4SZAoqKsxyNH7K4EAICqjXATIN5wU1hodyUAAFRthJsAoeUGAICKgXATIJGR1iPhBgAAexFuAsTbclNUZM2YAgAA9iDcBIg33IhxNwAA2IpwEyAnhhu6pgAAsA/hJkAiIqSw41eTlhsAAOxDuAkQl4sZUwAAVASEmwBixhQAAPYj3AQQLTcAANiPcBNAhBsAAOxHuAkgwg0AAPYj3AQQ95cCAMB+hJsAouUGAAD7EW4CiNlSAADYj3ATQLTcAABgP8JNABFuAACwH+EmgAg3AADYj3ATQIQbAADsR7gJIKaCAwBgP8JNANFyAwCA/Qg3AcRUcAAA7Ee4CSBabgAAsB/hJoAINwAA2I9wE0CEGwAA7Ee4CSBmSwEAYD/CTQDRcgMAgP0INwHEbCkAAOxHuAkgWm4AALAf4SaACDcAANiPcBNAhBsAAOxHuAkgZksBAGA/wk0A0XIDAID9CDcBxGwpAADsR7gJIFpuAACwH+EmgAg3AADYr0KEm2eeeUZNmjRRVFSUOnfurDVr1pTpvHnz5snlcumqq64Keo1l4Q03xcXSsWN2VwMAQNVke7iZP3++xo4dqwkTJmj9+vVq166devfurV27dp32vK1bt+ruu+/WxRdfHLJaf4s33IjWGwAAbGN7uHnyySd1yy23aPjw4brgggs0c+ZMVa9eXbNnzz7lOcXFxRo6dKgefvhhNWvWLKT1no53QLGYDg4AgG1sDTdFRUXKyspSenp6SUFhYUpPT9fKlStPed4jjzyixMRE3XTTTb/5HYWFhcrPz/fbgiU8XIqIsJ7TcgMAgD1sDTd79uxRcXGxkpKS/PYnJSUpJyfnpOcsX75cL7zwgmbNmlWm75g0aZLi4+N9W0pKSkBqPxUGFQMAYC/bu6XKo6CgQNdff71mzZqlhISEMp0zbtw45eXl+bYdO3YEtUbCDQAA9oqw88sTEhIUHh6u3Nxcv/25ubmqV69eqeO/++47bd26VQMGDPDt83g8kqSIiAht2bJFzZs39zsnMjJSkScOhgkywg0AAPayteXG7XYrNTVVmZmZvn0ej0eZmZlKS0srdXzLli21adMmZWdn+7Yrr7xSPXv2VHZ2dtC7nMqCcAMAgL1sbbmRpLFjxyojI0MdO3ZUp06dNHXqVB08eFDDhw+XJA0bNkwNGjTQpEmTFBUVpdatW/udX7NmTUkqtd8u3DwTAAB72R5uBg8erN27d2v8+PHKyclR+/bttWTJEt8g4+3btyssrPIMDaLlBgAAe7mMMcbuIkIpPz9f8fHxysvLU1xcXMA//6KLpBUrpH/9S/rDHwL+8QAAVEnl+ftdeZpEKglabgAAsBfhJsAINwAA2ItwE2CEGwAA7EW4CTBmSwEAYC/CTYDRcgMAgL0INwHmXQyZcAMAgD0INwFGyw0AAPYi3AQY4QYAAHsRbgKMcAMAgL0INwFGuAEAwF6EmwBjKjgAAPYi3AQYs6UAALAX4SbA6JYCAMBehJsAI9wAAGAvwk2AEW4AALAX4SbACDcAANiLcBNgzJYCAMBehJsAY7YUAAD2ItwEGN1SAADYi3ATYIQbAADsRbgJMMINAAD2ItwE2Inhxhi7qwEAoOoh3ASYN9xI0tGjdlYCAEDVRLgJsBPDDV1TAACEHuEmwNzukueEGwAAQo9wE2AuF2vdAABgJ8JNEDBjCgAA+xBugoBwAw
CAfQg3QUC4AQDAPoSbIODmmQAA2IdwEwQMKAYAwD6EmyCgWwoAAPsQboKAcAMAgH0IN0FAuAEAwD6EmyAg3AAAYB/CTRAwWwoAAPsQboKA2VIAANiHcBMEdEsBAGAfwk0QEG4AALAP4SYICDcAANiHcBMEhBsAAOxDuAkCZksBAGAfwk0Q0HIDAIB9CDdBwFRwAADsQ7gJAlpuAACwD+EmCAg3AADYh3ATBIQbAADsQ7gJAsINAAD2IdwEAVPBAQCwD+EmCJgtBQCAfQg3QUC3FAAA9iHcBAHhBgAA+xBugoBwAwCAfQg3QUC4AQDAPoSbIPCGm6NHJY/H7moAAKhaCDdB4J0tJaaDAwAQcoSbIPC23IiuKQAAQo5wEwQREVLY8StLuAEAILQIN0HgcjGoGAAAuxBugoRwAwCAPQg3QUK4AQDAHoSbIPHOmGK2FAAAoUW4CRJabgAAsEeFCDfPPPOMmjRpoqioKHXu3Flr1qw55bFvvfWWOnbsqJo1a6pGjRpq3769Xn311ZDWWxaEGwAA7GF7uJk/f77Gjh2rCRMmaP369WrXrp169+6tXbt2nfT42rVr64EHHtDKlSu1ceNGDR8+XMOHD9f7778f8tpPh3ADAIA9bA83Tz75pG655RYNHz5cF1xwgWbOnKnq1atr9uzZJz2+R48euvrqq3X++eerefPmGjNmjNq2bavly5eHvPbTIdwAAGAPW8NNUVGRsrKylJ6eXlJQWJjS09O1cuXK3zzfGKPMzExt2bJFl1xyyUmPKSwsVH5+vt8WCoQbAADsYWu42bNnj4qLi5WUlOS3PykpSTk5Oac8Ly8vTzExMXK73erfv7+mTZumXr16nfTYSZMmKT4+3relpKQE/HecjDfcMFsKAIDQsr1b6kzExsYqOztba9eu1cSJEzV27Fh99NFHJz123LhxysvL8207duwISY3eqeC03AAAEFoRdn55QkKCwsPDlZub67c/NzdX9erVO+V5YWFhOueccyRJ7du31+bNmzVp0iT16NGj1LGRkZGKPPE23SFCtxQAAPawteXG7XYrNTVVmZmZvn0ej0eZmZlKS0sr8+d4PB4VVrD+H8INAAD2sLXlRpLGjh2rjIwMdezYUZ06ddLUqVN18OBBDR8+XJI0bNgwNWjQQJMmTZKOj6Hp2LGjmjdvrsLCQr333nt69dVXNWPGDJt/iT/CDQAA9rA93AwePFi7d+/W+PHjlZOTo/bt22vJkiW+Qcbbt29XWFhJA9PBgwc1atQo/fjjj4qOjlbLli312muvafDgwTb+itIINwAA2MNljDF2FxFK+fn5io+PV15enuLi4oL2PQ8/LD30kDRypPTss0H7GgAAqoTy/P2ulLOlKgNmSwEAYA/CTZDQLQUAgD0IN0FCuAEAwB6EmyAh3AAAYA/CTZAQbgAAsAfhJkgINwAA2INwEyTe2VIVbOFkAAAcj3ATJLTcAABgD8JNkBBuAACwB+EmSAg3AADYg3ATJIQbAADsQbgJEsINAAD2INwECbOlAACwB+EmSLwtN8XF0rFjdlcDAEDVQbgJEm+4EV1TAACEFOEmSLzdUiLcAAAQUoSbIAkPl6pVs54TbgAACB3CTRAxYwoAgNAj3ASRN9wwYwoAgNAh3ASRd9wNLTcAAIQO4SaI6JYCACD0CDdBRLgBACD0CDdBRLgBACD0CDdBRLgBACD0CDdBRLgBACD0CDdBxM0zAQAIPcJNENWoYT3u2mV3JQAAVB2EmyC65BLr8a237K4EAICqg3ATRNdea91jau1a6dtv7a4GAICqgXATRImJ0mWXWc/nzrW7GgAAqgbCTZD98Y/W49y5kjF2VwMAgPMRboLs6qutWVObN0sbN9pdDQAAzke4CbK4OKl/f+v5nDl2VwMAgPMRbkJgyBDrcd48yeOxuxoAAJyNcBMC/ftLsbHS9u3SypV2VwMAgLMRbkIgOtoaeyO6pgAACLpyhZspU6bo8OHDvtcrVqxQ4Q
n3FigoKNCoUaMCW6FDeLumFiyQjh2zuxoAAJzLZUzZJyiHh4fr559/VmJioiQpLi5O2dnZatasmSQpNzdXycnJKi4uDl7FZyk/P1/x8fHKy8tTXFxcyL736FGpQQNp925pyRKpd++QfTUAAJVeef5+l6vl5tc5qBy5qMqrVs1asVgs6AcAQFAx5iaEvF1Tb79tteQAAIDAI9yEUNeuUp06Un6+tGaN3dUAAOBMEeU94Z///KdiYmIkSceOHdNLL72khIQE6fiAYpxaWJh1r6k33pCWLpW6dbO7IgAAnKdcA4qbNGkil8v1m8f98MMPZ1tX0Ng1oNjrn/+UbrnFCjbLl4f86wEAqJTK8/e7XC03W7duPdvaqrz0dOtx1Sqre8qGfAUAgKMx5ibEmjSRmjeXiouljz+2uxoAAJynXOFm5cqVWrRokd++V155RU2bNlViYqJuvfVWv0X9cHK9elmPH3xgdyUAADhPucLNI488oi+++ML3etOmTbrpppuUnp6u++67T++++64mTZoUjDodxds1tXSp3ZUAAOA85Qo32dnZuuyyy3yv582bp86dO2vWrFkaO3asnn76ab3xxhvBqNNRLr1UcrmkzZuln36yuxoAAJylXOFm3759SkpK8r3++OOP1bdvX9/rCy+8UDt27AhshQ5Uq5bUsaP1nK4pAAACq1zhJikpyTfNu6ioSOvXr1eXLl187xcUFKhatWqBr9KBGHcDAEBwlCvc9OvXT/fdd58+/fRTjRs3TtWrV9fFF1/se3/jxo1q3rx5MOp0HO+4mw8+kLhFFwAAgVOucPO3v/1NERER6t69u2bNmqXnn39ebrfb9/7s2bN1+eWXB6NOx+naVYqOlnJypBPGaAMAgLNUrkX8EhIS9MknnygvL08xMTEKDw/3e3/BggWKjY0NdI2OFBkpXXKJ9P77VutN69Z2VwQAgDOUK9zceOONZTpu9uzZZ1pPlZKeboWbpUulu+6yuxoAAJyhXOHmpZdeUuPGjdWhQweV45ZUOAXvoOKPP5aKiqQTevgAAMAZKle4GTlypObOnasffvhBw4cP15/+9CfVrl07eNU5XJs2Ut260u7d0urV0gljswEAwBkq14DiZ555Rj///LP+8pe/6N1331VKSooGDRqk999/n5acMxAWJnnXRGS1YgAAAqPcN86MjIzUkCFDtHTpUn355Zdq1aqVRo0apSZNmujAgQPBqdLBvOFm2TK7KwEAwBnO6q7gYWFhcrlcMsaouLg4cFVVIZdeaj2uWiUdPGh3NQAAVH7lDjeFhYWaO3euevXqpXPPPVebNm3S9OnTtX37dsXExASnSgdr2lRq3Fg6dkxavtzuagAAqPzKFW5GjRql+vXra/Lkybriiiu0Y8cOLViwQP369VNY2Fk1AlVZLpfUs6f1nK4pAADOnsuUYyRwWFiYGjVqpA4dOsjlcp3yuLfeeitQ9QVcfn6+4uPjlZeXp7i4OLvLkSS9+qo0bJh04YXSmjV2VwMAQMVTnr/f5ZoKPmzYsNOGGpwZb8tNVpaUlyfFx9tdEQAAlVe5F/ELhmeeeUaPPfaYcnJy1K5dO02bNk2dOnU66bGzZs3SK6+8os8//1ySlJqaqr///e+nPL4yaNhQatFC+uYb6ZNPpAED7K4IAIDKy/aBMvPnz9fYsWM1YcIErV+/Xu3atVPv3r21a9eukx7/0UcfaciQIVq2bJlWrlyplJQUXX755frpp59CXnsgeWdNffih3ZUAAFC5lWvMTTB07txZF154oaZPny5J8ng8SklJ0ejRo3Xffff95vnFxcWqVauWpk+frmHDhv3m8RVxzI0kzZ8vXXed1K6dlJ1tdzUAAFQs5fn7bWvLTVFRkbKyspSenl5SUFiY0tPTtXLlyjJ9xqFDh3T06NFT3gaisLBQ+fn5fltF1KOH9bhhg7Rnj93VAABQedkabvbs2aPi4mIlJSX57U9KSlJOTk6ZPuPee+9VcnKyX0A60aRJkx
QfH+/bUlJSAlJ7oCUlSa1bW88/+sjuagAAqLxsH3NzNiZPnqx58+bp7bffVlRU1EmPGTdunPLy8nzbjh07Ql5nWbHeDQAAZ8/WcJOQkKDw8HDl5ub67c/NzVW9evVOe+7jjz+uyZMn67///a/atm17yuMiIyMVFxfnt1VU5RlU/NNP0rZtQS8JAIBKx9Zw43a7lZqaqszMTN8+j8ejzMxMpaWlnfK8KVOm6G9/+5uWLFmijh07hqja4Ove3Vqx+KuvpJ07/d8rLJQyM6V77pHatLGmj597rrRunV3VAgBQMdneLTV27FjNmjVLL7/8sjZv3qyRI0fq4MGDGj58uHR84cBx48b5jn/00Uf14IMPavbs2WrSpIlycnKUk5PjiDuS16oldehgPfeOu9m3T7r/fikhQUpPlx5/XDq+xI+KiqSRIyXuWQoAQIlyLeIXDIMHD9bu3bs1fvx45eTkqH379lqyZIlvkPH27dv97ls1Y8YMFRUV6ZprrvH7nAkTJuihhx4Kef2Bduml0vr10qJF0tat0pQp1qrFklSvntSnj7W1aSOlpVktN7NmSSNG2F05AAAVg+3r3IRaRV3nxmvxYqlfP/99rVtLEydaKxefePeLadOkO++UataUtmyREhNDXi4AACFRada5QWkXXSRFR1vPmzWTXnvNWtTvyiv9g41kdUl16CDt3y/de68t5QIAUOEQbiqY2Fjp/felefOkzZuloUOl8PCTHxsRIc2YYYWel16Sli8PdbUAAFQ8hJsK6OKLpcGDJbf7t4/t3Fm65Rbr+ciR0tGjQS8PAIAKjXDjAH//u1SnjjWL6vgtugAAqLIINw5Qp44VcCSrmwoAgKqMcOMQQ4ZI1apJ33wjff213dUAAGAfwo1DxMaW3Fn8P/+xuxoAAOxDuHGQ/v2tx0WL7K4EAAD7EG4c5IorrMdPPpHy8+2uBgAAexBuHKR5c+m886Rjx6T//tfuagAAsAfhxmG8rTeMuwEAVFWEG4fxjrt57z3J47G7GgAAQo9w4zAXXSTFxUm7dll3DAcAoKoh3DhMtWpS797Wc2ZNAQCqIsKNAzHuBgBQlRFuHKhvX+tO4evXSzt32l0NAAChRbhxoLp1rbuF6/jAYgAAqhLCjUOxWjEAoKoi3DiUd9zNBx9IR47YXQ0AAKFDuHGodu2kBg2kgwdZrRgAULUQbhzK5ZKuu856/txzdlcDAEDoEG4cbMQI63HxYun77+2uBgCA0CDcONg551gL+hkjzZxpdzUAAIQG4cbhbr/denzhBenwYburAQAg+Ag3Dtevn9SokbR3rzR/vt3VAAAQfIQbhwsPLxl78+yzdlcDAEDwEW6qgJtuktxuae1aayuP/fulKVOkyy+XPv00WBUCABA4hJsqIDFRuvZa63lZW2+2bZP+/GcpJUW6915p6VLrnlWrVgW1VAAAzhrhporwDiyeN0/65ZdTH3fggDR8uNS8uTR1qvW6VSspLc1aELBvXyk7O2RlAwBQboSbKqJLF6l9e+tWDC+8cPJjtm+XunWTXnpJKi6WLrvMWiNn0yar5aZbN6ubqlcvafPmUP8CAADKhnBTRbhcJa0348ZJt94q7dxZ8v7q1VKnTtLGjVJSkvTJJ9Z9qfr0sc6tUUP6z3+k1FRpzx4r+Hz3nW0/BwCAU3IZY4zdRYRSfn6+4uPjlZeXp7i4OLvLCanCQun666UFC6zX0dHWuJpzzpFGjrTeb9tWevdda/r4yfzyi9Sjh/T551LTptKGDVJsbEh/BgCgCirP329abqqQyEjpjTesWU9du1qL+v3979KNN1rBZsAAafnyUwcbSapTx+qiatxY+uEHaeLEUP4CAAB+G+GmCrroIivELFwotWxp7ft//096++2ytcLUqydNn249f/JJacuW4NYLAEB50C1VxRUXS7t3W4GlvK64whqH07u3NfDY5QpGhQAA0C2FcggPP7NgI1lTxd1u6f33pXfeCXRlAACcGcINztg550
h33209v+subswJAKgYCDc4K/ffLzVsKG3dat2mAQAAuxFucFZq1JCeeMJ6PnmyFXIAALAT4QZn7dprpZ49rdWPBw2S8vPtrggAUJURbnDWXC5p5kxrDZy1a6X+/a37UAEAYAfCDQLi3HOl//5Xio+31tC56iqrJQcAgFAj3CBgfvc7a72bGjWs+1Jde61UVGR3VQCAqoZwg4BKS5MWLZKioqzHoUNpwQEAhBbhBgHXo4d1K4dq1aQ337Ru9/DDD3ZXBQCoKgg3CIo+faxbM9SpI2VlWV1W775rd1UAgKqAcIOg6dVL+uwzqUsXaf9+6corpfvuk44ds7syAICTEW4QVCkp0scfS2PGWK8ffdSaKn7ggN2VAQCcinCDoHO7rZtsvvGGVL26NWX80kulPXvsrgwA4ESEG4TMtddKH35YstjfRRdJ27bZXRUAwGkINwipzp2tRf5SUqQtW6Ru3aQvvrC7KgCAkxBuEHItW0r/+590wQXSTz9JF18sLVhgd1UAAKcg3MAWDRtKn34qde0q7dtn3XDz2mulXbvsrgwAUNkRbmCb2rWlZcuk8eOliAhrwb9WrQLXirNrl5STE5jPAgBUHoQb2Mrtlh5+WFqzRmrTxppBNWiQNHr02X3ukSPWwoFt2lhr7AAAqg7CDSqEDh2kdeukBx+UwsKk6dOtm3CeqQ8/tMbz7NkjvfVWICsFAFR0hBtUGG639Mgj0l13Wa9vu00qKDizz3rnnZLnc+YEpj4AQOVAuEGF88gjUpMm0o4d0l//Wv7zPR7/cLNsmfTzzwEtEQBQgRFuUOHUqCE995z1fNo0adWq8p2flWWFmZgYKTXVCjtvvBGUUgEAFRDhBhXS5ZdLGRmSMdLNN0tFRWU/99//th779JFuuMF6TtcUAFQdhBtUWE88IdWta61gPHly2c/zdkn9/vfW2jnh4dZsrG+/DVqpAIAKhHCDCqtOHenpp63n//d//uNoTuWHH6RNm6xA06+flJQkpadb782dG9x6AQAVA+EGFdrgwdKVV0pHj1otMQMHSj/+eOrjvQHo4outRQIl6Y9/tB7nzLG6uQAAzmZ7uHnmmWfUpEkTRUVFqXPnzlqzZs0pj/3iiy80cOBANWnSRC6XS1OnTg1prQg9l8tqcbn3Xqs15q23pPPPt1p0iotLH+8db/P735fsu+oqKSpK+uoracOG0NUOALCHreFm/vz5Gjt2rCZMmKD169erXbt26t27t3ad4gZDhw4dUrNmzTR58mTVq1cv5PXCHtWrW2Nu1q+XunSRDhyQxoyRevSQ8vNLjtu3T/rkE+v5lVeW7I+Lk664wnrOwGIAcD5bw82TTz6pW265RcOHD9cFF1ygmTNnqnr16po9e/ZJj7/wwgv12GOP6brrrlNkZGTI64W92raVVqyQZs60Asvy5daMKG/Aee89qzWndWupWTP/c71dU3PnWlPDAQDOZVu4KSoqUlZWltK9oz0lhYWFKT09XStXrgzY9xQWFio/P99vQ+UVFmatXLxsmVSrlrRypdS3r7WSsXe8zYmtNl59+0rx8dZ4neXLQ142ACCEbAs3e/bsUXFxsZKSkvz2JyUlKSeAt3KeNGmS4uPjfVtKSkrAPhv2+d3vpA8+kGrWlP73P6sFx3svqhPH23hFRVmDkSUW9AMAp7N9QHGwjRs3Tnl5eb5tx44ddpeEAPl1wCkokOrXlzp2PPnxffpYj+Vd8RgAULnYFm4SEhIUHh6u3Nxcv/25ubkBHSwcGRmpuLg4vw3OkZoqLV1qBRxJGjDA6ro6mQsvtB43bpQKC0NXIwAgtGwLN263W6mpqcrMzPTt83g8yszMVFpaml1loRLq2FH66CNp5EjpgQdOfVzjxtbCgEePMiUcAJzM1m6psWPHatasWXr55Ze1efNmjRw5UgcPHtTw4cMlScOGDdO4ceN8xxcVFSk7O1vZ2dkqKirSTz/9pOzsbH3LuvpVXrt20rPPSo0anfoYl6uk9Wbdup
CVBgAIsQg7v3zw4MHavXu3xo8fr5ycHLVv315LlizxDTLevn27wk7oY9i5c6c6dOjge/3444/r8ccfV/fu3fXRRx/Z8htQuXTsKC1ZIq1da3clAIBgcRlTtRakz8/PV3x8vPLy8hh/UwW98441m6pVK+nzz+2uBgBQVuX5++342VLAibwzqTZvtlY6BgA4D+EGVUpysrV5PNJnn9ldDQAgGAg3qHK8g4oZdwMAzkS4QZXDjCkAcDbCDaoc77gbWm4AwJkIN6hyvOHm22+lffvsrgYAEGiEG1Q5depIzZpZz7Oy7K4GABBohBtUSXRNAYBzEW5QJTFjCgCci3CDKokZUwDgXIQbVEm/+511I80dO6TcXLurAQAEEuEGVVJsrNSypfWcrikAcBbCDaosuqYAwJkIN6iymDEFAM5EuEGVdeKMKWPsrgYAECiEG1RZ7dpJERHS7t3Se+/ZXQ0AIFAIN6iyoqOlIUOs51dfLf3rX3ZXBAAIBMINqrR//lMaNEg6etR6fPlluysCAJwtwg2qNLdbmjNHuvFGyeORbrhBmj7d7qoAAGeDcIMqLzxcmjVLGjPGej16tDR1qt1VAQDOFOEGkBQWJv3jH9KDD1qv/9//k5Yts7sqAMCZINwAx7lc0sMPSxkZVhfVdddJO3faXRUAoLwIN8AJXC7p2Weltm2lXbusgHP0qN1VAQDKg3AD/Er16tKbb1r3n/r0U+n+++2uCABQHoQb4CRatJBeesl6/vjj0ttv210RAKCsCDfAKfzhD9LYsdbzG26wWnEAABUf4QY4jcmTpYsvlvLzpZ49pSef5D5UAFDREW6A06hWTVq82LpNQ3GxNUV88GCpoMDuygAAp0K4AX5DjRrS669L06ZZN9pcsEDq1En66iu7KwMAnAzhBigDl0u64w7pk0+k5GQr2PTsKf34o92VAQB+jXADlENamvTZZ1Lr1lJOjnU38cOH7a4KAHAiwg1QTomJ0jvvSHXqSOvWSTffzCBjAKhICDfAGWja1FroLyLCuqv4lCl2VwQA8CLcAGeoRw/pqaes5+PGSf/5j90VAQBEuAHOzsiR0m23Wd1SQ4ZY3VQAAHsRboCz4HJJTz8tXXKJtfZN9+7Sv/9td1UAULURboCz5HZL774r9e4tHTpkzaDydlcBAEKPcAMEQFycFXBuvdXqorrrLunOO61VjQEAoUW4AQKkWjVp5sySmVPTpknXXCMdPWp3ZQBQtRBugAByuaR77rFu0RAZKS1cKN10k+Tx2F0ZAFQdhBsgCK65RvrXv6TwcOnVV6W772ahPwAIFcINECT9+0svvmg9/8c/pMmT7a4IAKoGwg0QRNdfLz35pPX8/vulWbPsrggAnI9wAwTZn/9srWAsSSNGSK+/bndFAOBshBsgBCZOtKaJezxWa84zz9hdEQA4F+EGCAGXS5oxQxo92hpYfMcd0t/+xiBjAAgGwg0QImFh1srFEyZYr8ePt7qsmCYOAIFFuAFCyOWSHnqo5PYMTz1l3a5hyRKpqMju6gDAGQg3gA3uvFN65RVrHZx33pH69pXq1pWGDrXWx9m3z+4KAaDychlTtXr98/PzFR8fr7y8PMXFxdldDqq4tWuttXDeflvKyfF/74ILpIsukrp1k3r0kBo1sqtKALBfef5+E26ACsDjkVavlt56S1q0SPrqq9LHdO8u3XCDtfpxTIwdVQKAfQg3p0G4QWWwe7f0v/9Jy5db2+rVJTOrqle3Ak7PnlLLltZWs6bdFQNAcBFuToNwg8poxw7rHlUvvyx9/XXp9+vVk845R0pIkGrXlurUsR737ZO2bi3ZCgqkTp2sbq4ePaQuXaSoKDt+EQCUD+HmNAg3qMyMkVatkt54Q9q0Sdq8Wdq588w/LzLSugfWI49IrVoFslIACCzCzWkQbuA0+fnWGJ2tW6VffpH27i15jI+XmjQp2dxuacUK6aOPpGXLSgYxu1zWyskPP2wdBwAVDeHmNAg3gMUYaeNGa6
Xkf/3L2letmnTzzVKbNlKNGiffYmKk2FiJf30AhBLh5jQIN0Bpa9ZYN/f88MOyn5OUJLVrV7Kdc44VeGJjrQAUEyNFRASzagBVCeHmNAg3wKl98IE0Z46UlycdPGhtBw6UPPe+Li4u2+edc47UsWPJ1qqVFXoiI62uMAAoK8LNaRBugLNjjBVyvvhC2rChZNu+3dpfUCAdPXr6z3C5rCnt0dElW1RUyeOvt6ZNpYsvtmZ3Va8eql8KoCIpz99vGo0BlIvLZbW+dO5sbSdTVGRNQ9+4UVq3rmTbvt163xuQDh4s33dHREipqdb3ut1WiPJuhw9bg6sLCqzHI0esVZ67drW29u2tcwA4Hy03AELGG0IOHSrZjhyx9nkfDx+WCgut10eOWMds3Ch9+qn0449n/t1RUSULHsbFWTPJataUkpOtW1t4t/r1rYHVACoWWm4AVEjVqlnbmfz/CmOkbdukTz6xwo7LVfJ51apZ4SUurmRQc3i49Nln1krPK1da0+Ozs8v2XWFh1rgg71atmtVqFB5uPVavLp17rtUydP751mPt2tZYJO8mlQSo8PDS3+ENerGxjD8CAo2WGwCOZ4z0zTfSd99ZXVZ5eda2f7/VGrR9u7Xt2PHb44XORFycFX7Cwkq6zgoLrffcbqu1qEEDa4uPt97zbkVF1nknBjnJv/Xr0CErcJ24OnV8fMmYJW9Ii4uz3vdutWvTVYfKo9K13DzzzDN67LHHlJOTo3bt2mnatGnq1KnTKY9fsGCBHnzwQW3dulUtWrTQo48+qn79+oW0ZgCVh8tltbSce+7pjysutkLPieGisFA6dsx/y8uzFk788ktrlegvv7TGD4WHW1tYmBWoDh+2Pjc/39pOpqjIapHati3wv7ss3O6S9Ytq1LBCkjcQeR+9rVZhYSW/8devo6Otc73biTPivI+Rkf6DyN1u63p6x00dO2bVkJAg1a1rPUZGWtcoL88ax7V/vxXw6tWzjjndcgPGWJ9bVGRtbrf1vSdrSYOz2B5u5s+fr7Fjx2rmzJnq3Lmzpk6dqt69e2vLli1KTEwsdfz//vc/DRkyRJMmTdIVV1yhOXPm6KqrrtL69evVunVrW34DAGcID7daM8piwIDfPqaoyPpjvG+ftWK0MVY3lHcRxKgo6yapP/1Ush044N8l5nZbd433hoBjx6zPOTFIREVZQWrv3pIVqvfv9w9oR45YAeGXX6xt3z7rc7x/+PftO+vLFxSRkSWtXL8WFmYFnLp1reviHbPlHbdVVHTy87zdmN5Qd+Lmdpd0P54Y4lwuawsLK725XP7X2RuIvf0i3sfoaP8FMd3u0l2Sxlj/e3s3HR9I793Cw0u6SSMiSlrzTvxnJjLSOvfXLYBSye9wufy7cuPirN9vjBXyPR7/blbva4/H//vcbuv1r39HZKQVQO1ie7dU586ddeGFF2r69OmSJI/Ho5SUFI0ePVr33XdfqeMHDx6sgwcPatGiRb59Xbp0Ufv27TVz5szf/D66pQCgpJXqxHWMDhwoGdx94qDuX/+BO9nmHfzt3byB5MQ/8IWFJeHj0CHrD673D7X38cABac8ea/v1ekregeBHj0q7dpX88UfFk5ZmjXcLpErTLVVUVKSsrCyNGzfOty8sLEzp6elauXLlSc9ZuXKlxo4d67evd+/eWrhw4UmPLywsVOEJsT//VG3DAFCFeFupytpSFWoej9X6lJ9f0rJwYhdUcbEVgHJyrNavatX810v6detXtWpWmDpxVt6Joa6gwHr0to6dGNy8rSm/fvSGPWP8u/G8g9C9XC7r2CNHrO87dMh6/HXLkjGlW4W8LSneLtGjR63X3jqPHbM+59ctR78eFH/i2CpjrO3IkZIu0/x86/e7XCfvdvS+DguzvvvEFqGTjVOLjAzGPxVlZ2u42bNnj4qLi5WUlOS3PykpSV999dVJz8nJyTnp8TneOwD+yqRJk/Twww8HsGoAQLCFhZ0+fIWHW7cA+dWfg9OqVs
3qDoLzhdldQLCNGzdOeXl5vm3Hjh12lwQAAILI1pabhIQEhYeHKzc3129/bm6u6p1iJFK9evXKdXxkZKQi7W4fAwAAIWNry43b7VZqaqoyMzN9+zwejzIzM5WWlnbSc9LS0vyOl6SlS5ee8ngAAFC12D4VfOzYscrIyFDHjh3VqVMnTZ06VQcPHtTw4cMlScOGDVODBg00adIkSdKYMWPUvXt3PfHEE+rfv7/mzZundevW6fnnn7f5lwAAgIrA9nAzePBg7d69W+PHj1dOTo7at2+vJUuW+AYNb9++XWFhJQ1MXbt21Zw5c/TXv/5V999/v1q0aKGFCxeyxg0AAJAqwjo3ocY6NwAAVD7l+fvt+NlSAACgaiHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAAR7F9heJQ865ZmJ+fb3cpAACgjLx/t8uy9nCVCzcFBQWSpJSUFLtLAQAA5VRQUKD4+PjTHlPlbr/g8Xi0c+dOxcbGyuVyBfSz8/PzlZKSoh07dnBrhyDjWocO1zp0uNahw7UOnUBda2OMCgoKlJyc7HfPyZOpci03YWFhatiwYVC/Iy4ujn9ZQoRrHTpc69DhWocO1zp0AnGtf6vFxosBxQAAwFEINwAAwFEINwEUGRmpCRMmKDIy0u5SHI9rHTpc69DhWocO1zp07LjWVW5AMQAAcDZabgAAgKMQbgAAgKMQbgAAgKMQbgAAgKMQbgLkmWeeUZMmTRQVFaXOnTtrzZo1dpdU6U2aNEkXXnihYmNjlZiYqKuuukpbtmzxO+bIkSO6/fbbVadOHcXExGjgwIHKzc21rWanmDx5slwul+666y7fPq514Pz000/605/+pDp16ig6Olpt2rTRunXrfO8bYzR+/HjVr19f0dHRSk9P1zfffGNrzZVRcXGxHnzwQTVt2lTR0dFq3ry5/va3v/ndm4hrfeY++eQTDRgwQMnJyXK5XFq4cKHf+2W5tnv37tXQoUMVFxenmjVr6qabbtKBAwfOvjiDszZv3jzjdrvN7NmzzRdffGFuueUWU7NmTZObm2t3aZVa7969zYsvvmg+//xzk52dbfr162caNWpkDhw44DtmxIgRJiUlxWRmZpp169aZLl26mK5du9pad2W3Zs0a06RJE9O2bVszZswY336udWDs3bvXNG7c2Nxwww1m9erV5vvvvzfvv/+++fbbb33HTJ482cTHx5uFCxeaDRs2mCuvvNI0bdrUHD582NbaK5uJEyeaOnXqmEWLFpkffvjBLFiwwMTExJinnnrKdwzX+sy999575oEHHjBvvfWWkWTefvttv/fLcm379Olj2rVrZ1atWmU+/fRTc84555ghQ4acdW2EmwDo1KmTuf32232vi4uLTXJyspk0aZKtdTnNrl27jCTz8ccfG2OM2b9/v6lWrZpZsGCB75jNmzcbSWblypU2Vlp5FRQUmBYtWpilS5ea7t27+8IN1zpw7r33XnPRRRed8n2Px2Pq1atnHnvsMd++/fv3m8jISDN37twQVekM/fv3NzfeeKPfvj/84Q9m6NChxnCtA+rX4aYs1/bLL780kszatWt9xyxevNi4XC7z008/nVU9dEudpaKiImVlZSk9Pd23LywsTOnp6Vq5cqWttTlNXl6eJKl27dqSpKysLB09etTv2rds2VKNGjXi2p+h22+/Xf379/e7puJaB9Q777yjjh076tprr1ViYqI6dOigWbNm+d7/4YcflJOT43et4+Pj1blzZ651OXXt2lWZmZn6+uuvJUkbNmzQ8uXL1bdvX4lrHVRlubYrV65UzZo11bFjR98x6enpCgsL0+rVq8/q+6vcjTMDbc+ePSouLlZSUpLf/qSkJH311Ve21eU0Ho9Hd911l7p166bWrVtLknJycuR2u1WzZk2/Y5OSkpSTk2NTpZXXvHnztH79eq1du7bUe1zrwPn+++81Y8YMjR07Vvfff7/Wrl2rO++8U263Wx
kZGb7rebL/pnCty+e+++5Tfn6+WrZsqfDwcBUXF2vixIkaOnSodPyfa3Gtg6Is1zYnJ0eJiYl+70dERKh27dpnff0JN6gUbr/9dn3++edavny53aU40o4dOzRmzBgtXbpUUVFRdpfjaB6PRx07dtTf//53SVKHDh30+eefa+bMmcrIyLC7PEd544039Prrr2vOnDlq1aqVsrOzdddddyk5OZlr7XB0S52lhIQEhYeHl5o1kpubq3r16tlWl5PccccdWrRokZYtW6aGDRv69terV09FRUXav3+/3/Fc+/LLysrSrl279Lvf/U4RERGKiIjQxx9/rKeffloRERFKSkriWgdI/fr1dcEFF/jtO//887V9+3bp+D/XOn5tT8S1Lr977rlH9913n6677jq1adNG119/vf785z9r0qRJEtc6qMpybevVq6ddu3b5vX/s2DHt3bv3rK8/4eYsud1upaamKjMz07fP4/EoMzNTaWlpttZW2RljdMcdd+jtt9/Whx9+qKZNm/q9n5qaqmrVqvld+y1btmj79u1c+3K67LLLtGnTJmVnZ/u2jh07aujQob7nXOvA6NatW6klDb7++ms1btxYktS0aVPVq1fP71rn5+dr9erVXOtyOnTokMLC/P/MhYeHy+PxSFzroCrLtU1LS9P+/fuVlZXlO+bDDz+Ux+NR586dz66AsxqODGOOTwWPjIw0L730kvnyyy/NrbfeamrWrGlycnLsLq1SGzlypImPjzcfffSR+fnnn33boUOHfMeMGDHCNGrUyHz44Ydm3bp1Ji0tzaSlpdlat1OcOFvKcK0DZs2aNSYiIsJMnDjRfPPNN+b111831atXN6+99prvmMmTJ5uaNWuaf//732bjxo3m97//PdOTz0BGRoZp0KCBbyr4W2+9ZRISEsxf/vIX3zFc6zNXUFBgPvvsM/PZZ58ZSebJJ580n332mdm2bZsxZby2ffr0MR06dDCrV682y5cvNy1atGAqeEUybdo006hRI+N2u02nTp3MqlWr7C6p0pN00u3FF1/0HXP48GEzatQoU6tWLVO9enVz9dVXm59//tnWup3i1+GGax047777rmndurWJjIw0LVu2NM8//7zf+x6Pxzz44IMmKSnJREZGmssuu8xs2bLFtnorq/z8fDNmzBjTqFEjExUVZZo1a2YeeOABU1hY6DuGa33mli1bdtL/RmdkZBhTxmv7yy+/mCFDhpiYmBgTFxdnhg8fbgoKCs66Npc5calGAACASo4xNwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwCqJJfLpYULF9pdBoAgINwACLkbbrhBLper1NanTx+7SwPgABF2FwCgaurTp49efPFFv32RkZG21QPAOWi5AWCLyMhI1atXz2+rVauWdLzLaMaMGerbt6+io6PVrFkzvfnmm37nb9q0SZdeeqmio6NVp04d3XrrrTpw4IDfMbNnz1arVq0UGRmp+vXr64477vB7f8+ePbr66qtVvXp1tWjRQu+8847vvX379mno0KGqW7euoqOj1aJFi1JhDEDFRLgBUCE9+OCDGjhwoDZs2KChQ4fquuuu0+bNmyVJBw8eVO/evVWrVi2tXbtWCxYs0AcffOAXXmbMmKHbb79dt956qzZt2qR33nlH55xzjt93PPzwwxo0aJA2btyofv36aejQodq7d6/v+7/88kstXrxYmzdv1owZM5SQkBDiqwDgjJz1fcUBoJwyMjJMeHi4qVGjht82ceJEY4wxksyIESP8zuncubMZOXKkMcaY559/3tSqVcscOHDA9/5//vMfExYWZnJycowxxiQnJ5sHHnjglDVIMn/96199rw8cOGAkmcWLFxtjjBkwYIAZPnx4gH85gFBgzA0AW/Ts2VMzZszw21e7dm3f87S0NL/30tLSlJ2dLUnavHmz2rVrpxo1avje79atmzwej7Zs2SKXy6WdO3fqss
suO20Nbdu29T2vUaOG4uLitGvXLknSyJEjNXDgQK1fv16XX365rrrqKnXt2vUsfzWAUCDcALBFjRo1SnUTBUp0dHSZjqtWrZrfa5fLJY/HI0nq27evtm3bpvfee09Lly7VZZddpttvv12PP/54UGoGEDiMuQFQIa1atarU6/PPP1+SdP7552vDhg06ePCg7/0VK1YoLCxM5513nmJjY9WkSRNlZmaeVQ1169ZVRkaGXnvtNU2dOlXPP//8WX0egNCg5QaALQoLC5WTk+O3LyIiwjdod8GCBerYsaMuuugivf7661qzZo1eeOEFSdLQoUM1YcIEZWRk6KGHHtLu3bs1evRoXX/99UpKSpIkPfTQQxoxYoQSExPVt29fFRQUaMWKFRo9enSZ6hs/frxSU1PVqlUrFRYWatGiRb5wBaBiI9wAsMWSJUtUv359v33nnXeevvrqK+n4TKZ58+Zp1KhRql+/vubOnasLLrhAklS9enW9//77GjNmjC688EJVr15dAwcO1JNPPun7rIyMDB05ckT/+Mc/dPfddyshIUHXXHNNmetzu90aN26ctm7dqujoaF188cWaN29ewH4/gOBxGWvWAABUGC6XS2+//bauuuoqu0sBUAkx5gYAADgK4QYAADgKY24AVDj0lgM4G7TcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAAR/n/66M0wFecaLoAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the per-epoch fine-tuning loss collected in tune_train_err.\n",
"plt.suptitle('MSE vs Epochs')\n",
"plt.plot(tune_train_err, label='Train', color='blue')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('MSE')\n",
"plt.legend()  # without this the 'Train' label above is never displayed\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.001733972690999508"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def tune_evaluate(num_batches=1):\n",
"    \"\"\"Return the mean `criterion` loss of `model` on freshly sampled tuning batches.\n",
"\n",
"    Runs with gradients disabled and the model in eval mode. Each batch comes from\n",
"    mktunebatch(BSZ). The default num_batches=1 gives the original single-batch\n",
"    estimate; larger values average several batches for a less noisy number.\n",
"    Note: the model is left in eval mode on return.\n",
"\n",
"    Raises ValueError if num_batches < 1.\n",
"    \"\"\"\n",
"    if num_batches < 1:\n",
"        raise ValueError(\"num_batches must be >= 1\")\n",
"    model.eval()\n",
"    total_loss = 0.0\n",
"    with torch.no_grad():\n",
"        for _ in range(num_batches):\n",
"            batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ))\n",
"            output = model(batch_src, batch_padding_mask)\n",
"            total_loss += criterion(output.squeeze(1), batch_labels).item()\n",
"    return total_loss / num_batches\n",
"\n",
"tune_evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[6.0290e+03, 1.8597e+04, 8.4000e+01, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 1.0000e+00]]),\n",
" array([1. , 1.1 , 1.2 , 1.3 , 1.4 , 1.5 , 1.6 , 1.699, 1.8 ,\n",
" 1.9 , 2. , 2.1 , 2.2 , 2.3 , 2.398, 2.5 , 2.6 , 2.7 ,\n",
" 2.8 , 2.898, 3. , 3.1 , 3.2 , 3.299, 3.398, 3.5 , 3.6 ,\n",
" 3.7 , 3.799, 3.898, 4. , 4.1 , 4.2 , 4.297, 4.4 , 4.5 ,\n",
" 4.6 , 4.7 , 4.797, 4.9 , 5. , 5.098, 5.2 , 5.3 , 5.4 ,\n",
" 5.5 , 5.598, 5.7 , 5.797, 5.9 , 6. ], dtype=float16),\n",
" array([0.9014, 0.998 , 1.094 , 1.19 , 1.287 , 1.384 , 1.48 , 1.576 ,\n",
" 1.673 , 1.77 , 1.865 , 1.962 , 2.059 , 2.156 , 2.25 , 2.348 ,\n",
" 2.445 , 2.541 , 2.637 , 2.734 , 2.83 , 2.926 , 3.023 , 3.121 ,\n",
" 3.215 , 3.312 , 3.41 , 3.504 , 3.602 , 3.7 , 3.793 , 3.89 ,\n",
" 3.988 , 4.082 , 4.18 , 4.277 , 4.375 , 4.47 , 4.566 , 4.664 ,\n",
" 4.758 , 4.855 , 4.953 , 5.05 , 5.145 , 5.242 , 5.34 , 5.434 ,\n",
" 5.53 , 5.63 , 5.723 ], dtype=float16),\n",
" <matplotlib.collections.QuadMesh at 0x7f201dae35c0>)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhsAAAGdCAYAAAC7JrHlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVOklEQVR4nO3df4yV9Z3o8c8wI0cWZ0ZhYYUyTDSuUKEQf/QSpFrrr2RiJvYf6xpaiTXZtBkVatyYaXKj5FYH/2gTTc1UvY3c3EhoaYq23ouUJgLpZqkwXm5A11asiaNg3CV1hqHpWWXO/WNxbtGic4b5zDNzzuuVnD/O+J3zfJJjnLffc57naahUKpUAAEgyregBAIDaJjYAgFRiAwBIJTYAgFRiAwBIJTYAgFRiAwBIJTYAgFRNE33A4eHhOHz4cDQ3N0dDQ8NEHx4AGINKpRLHjh2L+fPnx7Rp1e1VTHhsHD58ONra2ib6sADAOOjv748FCxZU9TsTHhvNzc0RJ4dtaWmZ6MMDAB9zc+vtn7nmw/ggfhP/e+TveDUmPDY++uikpaVFbADAJNDUcNZnLzp5J7WxfAXCF0QBgFRiAwBIJTYAgFRiAwBINeFfEAUAJpcdw1s+c83g4GC0traO6fXtbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqZqKHgAAKNYN0275zDUfVj4Y8+vb2QAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCVW8wDMGnduPK/jWrdr/7lv6bPUst2DG/5zDWDg4PR2to6pte3swEApBIbAEAqsQEApBIbAEAqsQEApBIbAECqqk59ffDBB2P9+vWn/GzRokXx2muvjfdcABDvXzyz6BEYB1VfZ2PJkiXx61//+v+/QJNLdQAAp1d1KTQ1NcX555+fMw0AUHOq/s7G66+/HvPnz48LL7wwVq9eHW+99danri+XyzE4OHjKAwCoH1XFxooVK2Ljxo3xwgsvRG9vb7z55ptx1VVXxbFjx077Oz09PdHa2jryaGtrG4+5AYApoqFSqVTG+svvv/9+tLe3xw9+8IO48847/+qacrkc5XJ55Png4GC0tbXFwMBAtLS0jPXQANSB/7LmB6Na99L/uDd9lnr30b1RxvL3+4y+3XnuuefGxRdfHIcOHTrtmlKpFKVS6UwOAwBMYWd0nY2hoaF44403Yt68eeM3EQBQU6ra2bjvvvuis7Mz2tvb4/Dhw/HAAw9EY2Nj3HbbbXkTAlC3Zr/0b0WPwDioKjbefvvtuO222+Lo0aMxZ86c+NKXvhR79uyJOXPm5E0IAExpVcXG5s2b8yYBAGqSe6MAAKnEBgCQSmwAAKnEBgCQyi1bAZi0tv1rT9EjMA7sbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqZqKHgBgsrlh2i2jWrdjeEv6LFAL7GwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnc9RXgY9zNFcaXnQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXrbAB8TMeCe0a1btvbj6XPArXAzgYAkEpsAACpxAYAkEpsAACpxAYAkEpsAACpnPoK8DF/+McLih4BaoqdDQAgldgAAFKdUWxs2LAhGhoaYt26deM3EQBQU8YcG3v37o0nnngili1bNr4TAQ
A1ZUyxMTQ0FKtXr46nnnoqzjvvvPGfCgCoGWOKja6urrjpppvi+uuv/8y15XI5BgcHT3kAAPWj6lNfN2/eHC+//HLs3bt3VOt7enpi/fr1Y5kNAKgBVe1s9Pf3x9q1a+OZZ56Js88+e1S/093dHQMDAyOP/v7+sc4KMCE+/JvKqB7A6FS1s9HX1xfvvfdeXHbZZSM/O3HiROzevTt++MMfRrlcjsbGxlN+p1QqRalUGr+JAYApparYuO666+LAgQOn/OyOO+6IxYsXx/333/+J0AAAqCo2mpubY+nSpaf8bObMmTF79uxP/BwAIFxBFADIdsY3Ytu5c+f4TAIA1CQ7GwBAKreYB/iYv/2/RU8AtcXOBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQynU2AD6m5Zl/Gd3C/5k9CdQGOxsAQCqxAQCkEhsAQCqxAQCkEhsAQCqxAQCkcuorwMfsGN5S9AhQU+xsAACpxAYAkEpsAACpxAYAkEpsAACpxAYAkMqprwAfc8O0W0a1zimyMDp2NgCAVGIDAEglNgCAVGIDAEglNgCAVGIDAEglNgCAVK6zAZNEx4J7RrVu29uPpc8CMJ7sbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJDKqa8wSQz/3XlFj8BJTfPnFT0C1BQ7GwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKRy6itMEscuail6BIAUdjYAgFRiAwBIJTYAgFRiAwBIJTYAgFRiAwBIJTYAgFSuswGTxNDntD9Qm/zXDQBIJTYAgFRiAwBIJTYAgFRiAwBIJTYAgFROfYVJYnD5fxQ9Aidte/uxokeAmmJnAwBIJTYAgFRiAwBIJTYAgFRVxUZvb28sW7YsWlpaoqWlJVauXBnbtm3Lmw4AmPKqio0FCxbEhg0boq+vL/bt2xfXXntt3HzzzfHKK6/kTQgATGlVnfra2dl5yvOHHnooent7Y8+ePbFkyZLxng0AqAFjvs7GiRMnYsuWLXH8+PFYuXLladeVy+Uol8sjzwcHB8d6SKhpf9/+btEjAKSo+guiBw4ciHPOOSdKpVJ861vfiq1bt8Yll1xy2vU9PT3R2to68mhrazvTmQGAKaTq2Fi0aFHs378/fvvb38a3v/3tWLNmTbz66qunXd/d3R0DAwMjj/7+/jOdGQCYQqr+GGX69Olx0UUXRUTE5ZdfHnv37o1HH300nnjiib+6vlQqRalUOvNJAYAp6YyvszE8PHzKdzIAAP5SVTsb3d3d0dHREQsXLoxjx47Fpk2bYufOnbF9+/a8CQGAKa2q2Hjvvffi9ttvjyNHjkRra2ssW7Ystm/fHjfccEPehADAlFZVbPz4xz/Om4RC3DDtllGt2zG8JX2Wenfkfy0c3cJrsicBGF/ujQIApBIbAEAqsQEApBIbAEAqsQEApBIbAECqMd/1FRhfs373YdEjAKSwswEApBIbAEAqsQEApBIbAEAqsQEApBIbAEAqsQEApHKdjTrn1vGTx4x3hooeASCFnQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSOfUVJontfeuLHgEghZ0NACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV62zUuY4F94xq3ba3H0ufpd55L4BaZWcDAEglNgCAVGIDAEglNgCAVGIDAEglNgCAVE59rXMn2uYUPQInOaUVqFV2NgCAVGIDAEglNgCAVGIDAEglNgCAVGIDAEjl1Nc6Nzy9segROKnj892jWrftX3vSZwEYT3Y2AIBUYgMASCU2AIBUYgMASCU2AIBUYgMASCU2AIBUrrNR5/7t0hlFj8BJrp8B1Co7GwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKRy6mude3/ph0WPAECNs7MBAKQSGwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKSq6jobPT098fOf/zxee+21mDFjRlx55ZXxyCOPxKJFi/ImJNVj1z4zyp
X/lDwJN0y7ZVTrdgxvSZ8FYDxVtbOxa9eu6Orqij179sSOHTvigw8+iBtvvDGOHz+eNyEAMKVVtbPxwgsvnPJ848aNMXfu3Ojr64urr756vGcDAGrAGV2ufGBgICIiZs2addo15XI5yuXyyPPBwcEzOSQAMMWM+Quiw8PDsW7duli1alUsXbr0tOt6enqitbV15NHW1jbWQwIAU9CYY6OrqysOHjwYmzdv/tR13d3dMTAwMPLo7+8f6yEBgCloTB+j3HXXXfH888/H7t27Y8GCBZ+6tlQqRalUGut8AMAUV1VsVCqVuPvuu2Pr1q2xc+fOuOCCC/ImY0Ks/ed/GNW6my9MH6XuOaUVqFVVxUZXV1ds2rQpnnvuuWhubo533303IiJaW1tjxowZWTMCAFNYVd/Z6O3tjYGBgbjmmmti3rx5I4+f/OQneRMCAFNa1R+jAABUw71RAIBUYgMASCU2AIBUZ3S5cqa+psOugQJALjsbAEAqsQEApBIbAEAqsQEApBIbAEAqsQEApBIbAECqwq6zcXPr7dHUcNZp/7nbbU+Mme8UPQEAtc7OBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKkKO/W16e8vjKZGtzcv2jlHhoseAYAaZ2cDAEglNgCAVGIDAEglNgCAVGIDAEglNgCAVGIDAEhV2HU2yvNa4kTT2UUdnpNmvv2nokcAoMbZ2QAAUokNACCV2AAAUokNACCV2AAAUokNACBVYae+lo4MRlNjuajDc1LDf3xY9AgA1Dg7GwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKQq7NTXf1/xt9E43V1fizb8f14tegQAapydDQAgldgAAFKJDQAgldgAAFKJDQAgldgAAFKJDQAgVWHX2YiGkw8KtWN4S9EjAFDj7GwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQqrBTXwfbG6LxbOe+AkCts7MBAKQSGwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKQq7Dobfbf/92hpbvyUFd+ZwGkAgCx2NgCAVGIDAEhVdWzs3r07Ojs7Y/78+dHQ0BDPPvtszmQAQE2oOjaOHz8ey5cvj8cffzxnIgCgplT9BdGOjo7o6OjImQYAqDnpZ6OUy+Uol8sjzwcHB7MPCQBMIumx0dPTE+vXr//Ez7f/6ez4m8bTn/ramTwXADAx0s9G6e7ujoGBgZFHf39/9iEBgEkkfWejVCpFqVTKPgwAMEm5zgYAkKrqnY2hoaE4dOjQyPM333wz9u/fH7NmzYqFCxeO93wAwBRXdWzs27cvvvKVr4w8v/feeyMiYs2aNbFx48bxnQ4AmPKqjo1rrrkmKpXKmA/40e/+aejEp65ziiwATB4f/V0eSwNM+F1fjx49GhER3/zS7z5jZeuEzAMAjN7Ro0ejtbW6v9ETHhuzZs2KiIi33nqr6mEZX4ODg9HW1hb9/f3R0tJS9Dh1zXsxeXgvJg/vxeQyMDAQCxcuHPk7Xo0Jj41p0/7zBJjW1lb/8kwSLS0t3otJwnsxeXgvJg/vxeTy0d/xqn4nZRIAgJPEBgCQasJjo1QqxQMPPOCqopOA92Ly8F5MHt6LycN7MbmcyfvRUDmT81gBAD6Dj1EAgFRiAwBIJTYAgFRiAwBINWGxsXv37ujs7Iz58+dHQ0NDPPvssxN1aP5CT09PfPGLX4zm5uaYO3dufPWrX43f/e6zLh1Plt7e3li2bNnIRYtWrlwZ27ZtK3osImLDhg3R0NAQ69atK3qUuvPggw9GQ0PDKY/FixcXPVbdeuedd+LrX/96zJ49O2bMmBFf+MIXYt++fVW9xoTFxvHjx2P58uXx+OOPT9Qh+St27doVXV1dsWfPntixY0d88MEHceONN8bx48eLHq0uLViwIDZs2BB9fX2xb9++uPbaa+Pmm2+OV155pejR6trevXvjiSeeiGXLlhU9St1asmRJHDlyZOTxm9/8puiR6tIf//jHWLVqVZx11l
mxbdu2ePXVV+P73/9+nHfeeVW9zoRdrryjoyM6Ojom6nCcxgsvvHDK840bN8bcuXOjr68vrr766sLmqlednZ2nPH/ooYeit7c39uzZE0uWLClsrno2NDQUq1evjqeeeiq+973vFT1O3Wpqaorzzz+/6DHq3iOPPBJtbW3x9NNPj/zsggsuqPp1fGejzg0MDET8xQ3yKM6JEydi8+bNcfz48Vi5cmXR49Strq6uuOmmm+L6668vepS69vrrr8f8+fPjwgsvjNWrV8dbb71V9Eh16Re/+EVcccUVccstt8TcuXPj0ksvjaeeeqrq15nwG7ExeQwPD8e6deti1apVsXTp0qLHqVsHDhyIlStXxp///Oc455xzYuvWrXHJJZcUPVZd2rx5c7z88suxd+/eokepaytWrIiNGzfGokWL4siRI7F+/fq46qqr4uDBg9Hc3Fz0eHXlD3/4Q/T29sa9994b3/3ud2Pv3r1xzz33xPTp02PNmjWjfh2xUce6urri4MGDPgst2KJFi2L//v0xMDAQP/vZz2LNmjWxa9cuwTHB+vv7Y+3atbFjx444++yzix6nrv3lR+7Lli2LFStWRHt7e/z0pz+NO++8s9DZ6s3w8HBcccUV8fDDD0dExKWXXhoHDx6MH/3oR1XFho9R6tRdd90Vzz//fLz44ouxYMGCosepa9OnT4+LLrooLr/88ujp6Ynly5fHo48+WvRYdaevry/ee++9uOyyy6KpqSmamppi165d8dhjj0VTU1OcOHGi6BHr1rnnnhsXX3xxHDp0qOhR6s68efM+8T8+n//856v+WMvORp2pVCpx9913x9atW2Pnzp1j+qIPuYaHh6NcLhc9Rt257rrr4sCBA6f87I477ojFixfH/fffH42NjYXNVu+GhobijTfeiG984xtFj1J3Vq1a9YnLI/z+97+P9vb2ql5nwmJjaGjolCp98803Y//+/TFr1qxYuHDhRI1R97q6umLTpk3x3HPPRXNzc7z77rsREdHa2hozZswoery6093dHR0dHbFw4cI4duxYbNq0KXbu3Bnbt28verS609zc/InvLs2cOTNmz57tO00T7L777ovOzs5ob2+Pw4cPxwMPPBCNjY1x2223FT1a3fnOd74TV155ZTz88MPxta99LV566aV48skn48knn6zuhSoT5MUXX6xExCcea9asmagR+M87/P7Vx9NPP130aHXpm9/8ZqW9vb0yffr0ypw5cyrXXXdd5Ve/+lXRY3HSl7/85cratWuLHqPu3HrrrZV58+ZVpk+fXvnc5z5XufXWWyuHDh0qeqy69ctf/rKydOnSSqlUqixevLjy5JNPVv0abjEPAKTyBVEAIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABS/T/VxjQRzeYx5wAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ))\n",
"model.eval()\n",
"with torch.no_grad():\n",
" output = model(batch_src, batch_padding_mask)\n",
"x = batch_labels.detach().to(torch.float16).cpu().numpy().flatten()\n",
"y = output.detach().to(torch.float16).cpu().numpy().flatten()\n",
"plt.hist2d(x, y, bins=50, norm=mpl.colors.LogNorm())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JtTLXn4zC1z_"
},
"source": [
"# Step 6: Test generalization"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.20130713284015656\n"
]
},
{
"data": {
"text/plain": [
"(array([[1.7789e+04, 6.5000e+02, 4.8000e+02, ..., 0.0000e+00, 1.0000e+00,\n",
" 1.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00]]),\n",
" array([ 1. , 1.18 , 1.36 , 1.54 , 1.721, 1.9 , 2.08 , 2.262,\n",
" 2.441, 2.621, 2.8 , 2.98 , 3.16 , 3.34 , 3.521, 3.701,\n",
" 3.88 , 4.062, 4.242, 4.42 , 4.6 , 4.78 , 4.96 , 5.14 ,\n",
" 5.32 , 5.5 , 5.68 , 5.863, 6.043, 6.223, 6.402, 6.582,\n",
" 6.76 , 6.94 , 7.12 , 7.3 , 7.48 , 7.66 , 7.844, 8.02 ,\n",
" 8.2 , 8.38 , 8.56 , 8.74 , 8.92 , 9.1 , 9.28 , 9.46 ,\n",
" 9.64 , 9.82 , 10. ], dtype=float16),\n",
" array([ 0.879, 1.156, 1.434, 1.71 , 1.987, 2.266, 2.541, 2.818,\n",
" 3.096, 3.373, 3.65 , 3.928, 4.203, 4.48 , 4.758, 5.035,\n",
" 5.312, 5.59 , 5.867, 6.145, 6.42 , 6.7 , 6.977, 7.254,\n",
" 7.527, 7.805, 8.08 , 8.36 , 8.64 , 8.92 , 9.19 , 9.47 ,\n",
" 9.75 , 10.016, 10.3 , 10.58 , 10.86 , 11.125, 11.41 , 11.69 ,\n",
" 11.97 , 12.234, 12.516, 12.8 , 13.08 , 13.34 , 13.625, 13.91 ,\n",
" 14.17 , 14.45 , 14.734], dtype=float16),\n",
" <matplotlib.collections.QuadMesh at 0x7f201daaff50>)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfp0lEQVR4nO3de3BU9f3/8dcmIUuMyUr4Si6SYEAKym2gWH6AbaFE/eaHFKZTVCZiCnU67UQh0lpINSqDENHWKpUJYjuAF/DSacDyG6Qp5VJ+VQjBONALF6WYcp35DeyS8GUN2fP7o2OmEYTEPZvz3s3zMXP+2LOf7Hkdg+yLz57PWZ/jOI4AAAAMSfI6AAAAwOdRUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYk+J1gM+LRCI6fvy4MjIy5PP5vI4DAAA6wHEcnTt3Tnl5eUpKin7+w1xBOX78uPLz872OAQAAvoTGxkb17ds36tcxV1AyMjIkSbfpfytFPa44dkPwlS5KBQAAriQUCik/P7/tfTxa5grKZx/rpKiHUnxXLiiZmZldlAoAAHSEW5dncJEsAAAwh4ICAADMoaAAAABzKCgAAMAccxfJAug6tydN79C42sjbMc8CAP+JGRQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5vgcx3E68wM7duzQs88+q/r6ep04cUI1NTWaNm3aZcf+8Ic/1EsvvaRf/vKXKi8v79Drh0IhBQIBBYNBvgwQAIA44fb7d6dnUJqbmzVixAgtX778iuNqamr0/vvvKy8vL5p8AACgG+r0nWSLi4tVXFx8xTHHjh3TQw89pM2bN2vy5MnR5AMAAN2Q67e6j0Qimjlzph555BENGTLkquPD4bDC4XDb41Ao5HYkAAAQZ1y/SHbp0qVKSUnRnDlzOjS+qqpKgUCgbcvPz3c7EgAAiDOuFpT6+nq98MILWr16tXw+X4d+pqKiQsFgsG1rbGx0MxIAAIhDrhaUP//5zzp9+rQKCgqUkpKilJQUHT16VD/+8Y914403XvZn/H6/MjMz220AAKB7c/UalJkzZ6qoqKjdvjvvvFMzZ87UrFmz3DwUAABIYJ0uKE1NTTp8+HDb4yNHjqihoUFZWVkqKChQ7969243v0aOHcnJyNGjQIHcSAwCAhNfpgrJnzx5NnDix7fG8efMkSaWlpVq9erW76QAAQLfU6YIyYcIEdebms//85z87ewgAANDN8V08AADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAnBSvA3yRqYH7leLrccUxtZG3uywPAADoOsygAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADDH7H1QNgRfUWZmptcxAACAB5hBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJjT6YKyY8cOTZkyRXl5efL5fFq/fn3bcy0tLZo/f76GDRum9PR05eXl6f7779fx48fdzg0AABJYpwtKc3OzRowYoeXLl1/y3Pnz57V3715VVlZq7969+t3vfqcDBw7o29/+tlt5AQBAN+BzHMf50j/s86mmpkbTpk37wjF1dXX62te+pqNHj6qgoOCqrxkKhRQIBBQMBrnVPQAAccLt9++YX4MSDAbl8/l03XXXxfpQAAAgQcT0ywIvXLig+fPna8aMGV/YpsLhsMLhcNvjUCgUy0gAACAOxGwGpaWlRXfffbccx1F1dfUXjquqqlIgEGjb8vPzYxUJAADEiZgUlM/KydGjR1VbW3vFz6
IqKioUDAbbtsbGxlhEAgAAccT1j3g+KyeHDh3S1q1b1bt37yuO9/v98vv9bscAAABxrNMFpampSYcPH257fOTIETU0NCgrK0u5ubn67ne/q71792rjxo1qbW3VyZMnJUlZWVlKTU11Nz0AAEhInV5mvG3bNk2cOPGS/aWlpXryySdVWFh42Z/bunWrJkyYcNXXZ5kxAADxx+33707PoEyYMEFX6jRR3FYFAABA4rt4AACARRQUAABgTkxv1BaNqYH7leLrccUxtZG3uywPAADoOsygAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcs8uMNwRf4Vb3AAB0U8ygAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHM6XVB27NihKVOmKC8vTz6fT+vXr2/3vOM4evzxx5Wbm6u0tDQVFRXp0KFDbmYGAAAJrtMFpbm5WSNGjNDy5csv+/wzzzyjZcuWacWKFdq1a5fS09N155136sKFC27kBQAA3UBKZ3+guLhYxcXFl33OcRw9//zzeuyxxzR16lRJ0iuvvKLs7GytX79e9957b/SJAQBAwnP1GpQjR47o5MmTKioqatsXCAQ0ZswYvffee5f9mXA4rFAo1G4DAADdm6sF5eTJk5Kk7Ozsdvuzs7Pbnvu8qqoqBQKBti0/P9/NSAAAIA55voqnoqJCwWCwbWtsbPQ6EgAA8JirBSUnJ0eSdOrUqXb7T5061fbc5/n9fmVmZrbbAABA9+ZqQSksLFROTo62bNnSti8UCmnXrl0aO3asm4cCAAAJrNOreJqamnT48OG2x0eOHFFDQ4OysrJUUFCg8vJyPfXUUxo4cKAKCwtVWVmpvLw8TZs2ze3sAAAgQXW6oOzZs0cTJ05sezxv3jxJUmlpqVavXq2f/vSnam5u1g9+8AOdPXtWt912m95991317NnT3eQAACBh+RzHcbwO8Z9CoZACgYCCwSDXowAAECfcfv/u9AxKV5kauF8pvh5XHFMbebvL8qB7uD1peofG8WcPAGLL82XGAAAAn0dBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmcB8UAAAQNbffv5lBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDmuF5TW1lZVVlaqsLBQaWlpGjBggBYtWiTHcdw+FAAASFApbr/g0qVLVV1drTVr1mjIkCHas2ePZs2apUAgoDlz5rh9OAAAkIBcLyh/+ctfNHXqVE2ePFmSdOONN2rdunXavXu324cCAAAJyvWPeMaNG6ctW7bo4MGDkqQPP/xQO3fuVHFx8WXHh8NhhUKhdhsAAOjeXJ9BWbBggUKhkAYPHqzk5GS1trZq8eLFKikpuez4qqoqLVy40O0YAAAgjrk+g/LWW2/p9ddf19q1a7V3716tWbNGP//5z7VmzZrLjq+oqFAwGGzbGhsb3Y4EAADijM9xeXlNfn6+FixYoLKysrZ9Tz31lF577TX94x//uOrPh0IhBQIBBYNBZWZmuhkNAADEiNvv367PoJw/f15JSe1fNjk5WZFIxO1DAQCABO
X6NShTpkzR4sWLVVBQoCFDhuiDDz7Qc889p9mzZ7t9KAAAkKBc/4jn3LlzqqysVE1NjU6fPq28vDzNmDFDjz/+uFJTU6/683zEAwBA/HH7/dv1ghKtz05wgqYqxdfjimNrI293WS4AAPDFzF+DAgAAEC0KCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx/UbtbllQ/AV7oMCAEA3xQwKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMSfE6QCK7PWl6h8bVRt6OeRYAAOIJMygAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIdlxjHE8mEAQCLoyG0zLjotrh6TGRQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmMMyYwAAcEUduW1GKBRSIBBw7ZjMoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAnJgUlGPHjum+++5T7969lZaWpmHDhmnPnj2xOBQAAEhArt8H5cyZMxo/frwmTpyoTZs26frrr9ehQ4fUq1cvtw8FAB36Gnh18D4O8aC7nS+6L9cLytKlS5Wfn69Vq1a17SssLHT7MAAAIIG5/hHPO++8o9GjR2v69Onq06ePRo4cqZdfftntwwAAgATmekH5+OOPVV1drYEDB2rz5s360Y9+pDlz5mjNmjWXHR8OhxUKhdptAACge3P9I55IJKLRo0dryZIlkqSRI0dq//79WrFihUpLSy8ZX1VVpYULF7odAwAAxDHXZ1Byc3N1yy23tNt3880365NPPrns+IqKCgWDwbatsbHR7UgAACDOuD6DMn78eB04cKDdvoMHD6pfv36XHe/3++X3+92OAQAA4pjrBeXhhx/WuHHjtGTJEt19993avXu3Vq5cqZUrV7p9KADodstpU/JyvY6Abqgjy9svOi2uHtP1j3huvfVW1dTUaN26dRo6dKgWLVqk559/XiUlJW4fCgAAJCjXZ1Ak6a677tJdd90Vi5cGAADdAN/FAwAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMickqHjdMDdyvFF+PK47pbvc/QOx1t6+yL+47p0PjNv1rWcyzfFl3fvWJDo3bXJ8gX6mRke51Ald05P+1RPn/LBF05HcRCoUUCARcOyYzKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx+wy40SQCEs4OypRludaz+e2f91b6HWEqDUWX+d1hC4VGtLb6wiu6G7/r6HzmEEBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDlmlxl/evsoRXr09DpGVC4W9PE6QpdpuWO01xFcMeG/l3Zo3LZ358c8S1dozne8jhC1Xv9o9TpCl/q/v/2J1xGALsEMCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABz7N4HJTNZkR7JXseIyqe9Ur2O0GUu/FcPryO4Ijig+/zOJKnnKZ/XEaLmP3PR6whd6vak6R0aVxt5O+ZZgFhiBgUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5phdZpzU4ihJ8f1V8OFeZv/zus7XGt+/q8+0XON1gq6VcsHrBNFL3lrvdYQuxfJhdBfMoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx+yNOq45cUEpZtN1jNON6l+kh8/rCK44P+p/vI7QpZI+9TpB9FIG3eR1hC51e9L0Do3jfimIdzF/C3366afl8/lUXl4e60MBAIAEEdOCUldXp5deeknDhw+P5WEAAECCiVlBaWpqUklJiV5++WX16tUrVocBAAAJKGYFpaysTJMnT1ZRUdEVx4XDYYVCoXYbAADo3mJyGeobb7yhvXv3qq6u7qpjq6qqtHDhwl
jEAAAAccr1GZTGxkbNnTtXr7/+unr27HnV8RUVFQoGg21bY2Oj25EAAECccX0Gpb6+XqdPn9aoUaPa9rW2tmrHjh168cUXFQ6HlZyc3Pac3++X3++/5HXODr5GyalXLziWnb8+MZbedkSoX2Kcq9OaGOfRUWn/L+J1BAC4LNcLyqRJk7Rv3752+2bNmqXBgwdr/vz57coJAADA5bheUDIyMjR06NB2+9LT09W7d+9L9gMAAFxON7rXKQAAiBddcjP5bdu2dcVhAABAgmAGBQAAmENBAQAA5pj9vuD/6e1Tsj++l3w63WjBUkum43UEV6Q0xvfS9s5KDbV6HQGdxLcUo7tgBgUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAOWbvg+Kk/HuLZ5FUrxN0nYvXJcb9NHwRrxN0LV9r/N+/5uKBw15H6FK3J03v0Djul4J4xwwKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMwxu5C3NVWS3+sU0bnY0+sEXWfq6L1eR3CF/4zXCbpWUkv8LzPW/xrhdYIuxfJhdBfMoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx+x9UC5mOIr0jO97NHz6X61eR+gy6+tHdWjcCyNjHiUq1x6PeB2hS/mPB72OELXavzzmdQR8CbcnTb/qGO750r0xgwIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAc8wuM470apHSkr2OEZ30i14n6DL+k2b/KHXKxZ4+ryMA3QJLiHE1zKAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwByza0P96WElXxPfSz4fG/N/OjhyQYyTxJ4vvr94us01p7rP0nBJunjgsNcRAOCymEEBAADmUFAAAIA5rheUqqoq3XrrrcrIyFCfPn00bdo0HThwwO3DAACABOZ6Qdm+fbvKysr0/vvvq7a2Vi0tLbrjjjvU3Nzs9qEAAECCcv0i2Xfffbfd49WrV6tPnz6qr6/XN77xDbcPBwAAElDMV/EEg0FJUlZW1mWfD4fDCofDbY9DoVCsIwEAAONiepFsJBJReXm5xo8fr6FDh152TFVVlQKBQNuWn58fy0gAACAOxHQGpaysTPv379fOnTu/cExFRYXmzZvX9jgUCik/P1/+1ItKTm2JZbyYO/5pL68jdJlrTnqdwB2pm+q8jtCl+Mp7AFbFrKA8+OCD2rhxo3bs2KG+fft+4Ti/3y+/3x+rGAAAIA65XlAcx9FDDz2kmpoabdu2TYWFhW4fAgAAJDjXC0pZWZnWrl2rDRs2KCMjQydP/nvuPxAIKC0tze3DAQCABOT6RbLV1dUKBoOaMGGCcnNz27Y333zT7UMBAIAEFZOPeAAAAKLBd/EAAABzYn6jti+r5WKyIheTvY4Rld+8//UOjXv88reIiSvXHm/1OoIrkkbe4nWELlXcd06Hxm3617KYZwGA/8QMCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMMbvMON3/qZL9Pq9jRKXnCbP/eV3Xc8MuryO4YnP9Qq8jdCmWDwOwihkUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOaYvVHHn4avV2ZG8lVGLemiNF/ODdvCHRtYGeskAADEF2ZQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA55lbxOI4jSQo1Ra46NumaUBck+vIuXrzQoXGhkO3z6IiLTkuHxiXCuQIALvXZ3++fvY9Hy+e49Uou+fjjjzVgwACvYwAAgC/ho48+Uv/+/aN+HXMzKFlZWZKkTz75RIFAwOs4rgqFQsrPz1djY6MyMzO9juO6RD4/zi0+cW7xiXOLT8FgUAUFBW3v49EyV1CSkv59WUwgEEi4X95nMjMzE/bclODnx7nFJ84tPnFu8emz9/GoX8eVVwEAAHARBQ
UAAJhjrqD4/X498cQT8vv9XkdxXSKfmxL8/Di3+MS5xSfOLT65fW7mVvEAAACYm0EBAACgoAAAAHMoKAAAwBwKCgAAMMdMQdmxY4emTJmivLw8+Xw+rV+/3utIrqmqqtKtt96qjIwM9enTR9OmTdOBAwe8juWK6upqDR8+vO2mQ2PHjtWmTZu8jhUTTz/9tHw+n8rLy72OErUnn3xSPp+v3TZ48GCvY7nm2LFjuu+++9S7d2+lpaVp2LBh2rNnj9exXHHjjTde8rvz+XwqKyvzOlrUWltbVVlZqcLCQqWlpWnAgAFatGiRa9/t4rVz586pvLxc/fr1U1pamsaNG6e6ujqvY3Xa1d6vHcfR448/rtzcXKWlpamoqEiHDh3q9HHMFJTm5maNGDFCy5cv9zqK67Zv366ysjK9//77qq2tVUtLi+644w41Nzd7HS1qffv21dNPP636+nrt2bNH3/rWtzR16lT99a9/9Tqaq+rq6vTSSy9p+PDhXkdxzZAhQ3TixIm2befOnV5HcsWZM2c0fvx49ejRQ5s2bdLf/vY3/eIXv1CvXr28juaKurq6dr+32tpaSdL06dO9jha1pUuXqrq6Wi+++KL+/ve/a+nSpXrmmWf0q1/9yutornjggQdUW1urV199Vfv27dMdd9yhoqIiHTt2zOtonXK19+tnnnlGy5Yt04oVK7Rr1y6lp6frzjvv1IULHfsC3TaOQZKcmpoar2PEzOnTpx1Jzvbt272OEhO9evVyfv3rX3sdwzXnzp1zBg4c6NTW1jrf/OY3nblz53odKWpPPPGEM2LECK9jxMT8+fOd2267zesYXWbu3LnOgAEDnEgk4nWUqE2ePNmZPXt2u33f+c53nJKSEs8yueX8+fNOcnKys3Hjxnb7R40a5Tz66KOe5YrW59+vI5GIk5OT4zz77LNt+86ePev4/X5n3bp1nXptMzMo3UkwGJT+44sRE0Vra6veeOMNNTc3a+zYsV7HcU1ZWZkmT56soqIir6O46tChQ8rLy1P//v1VUlKiTz75xOtIrnjnnXc0evRoTZ8+XX369NHIkSP18ssvex0rJj799FO99tprmj17tnw+n9dxojZu3Dht2bJFBw8elCR9+OGH2rlzp4qLi72OFrWLFy+qtbVVPXv2bLc/LS0tYWYvJenIkSM6efJku78vA4GAxowZo/fee69Tr2XuywITXSQSUXl5ucaPH6+hQ4d6HccV+/bt09ixY3XhwgVde+21qqmp0S233OJ1LFe88cYb2rt3b1x+TnwlY8aM0erVqzVo0CCdOHFCCxcu1Ne//nXt379fGRkZXseLyscff6zq6mrNmzdPP/vZz1RXV6c5c+YoNTVVpaWlXsdz1fr163X27Fl973vf8zqKKxYsWKBQKKTBgwcrOTlZra2tWrx4sUpKSryOFrWMjAyNHTtWixYt0s0336zs7GytW7dO7733nm666Sav47nm5MmTkqTs7Ox2+7Ozs9ue6ygKShcrKyvT/v37E6oxDxo0SA0NDQoGg/rtb3+r0tJSbd++Pe5LSmNjo+bOnava2tpL/tUT7/7zX6TDhw/XmDFj1K9fP7311lv6/ve/72m2aEUiEY0ePVpLliyRJI0cOVL79+/XihUrEq6g/OY3v1FxcbHy8vK8juKKt956S6+//rrWrl2rIUOGqKGhQeXl5crLy0uI392rr76q2bNn64YbblBycrJGjRqlGTNmqL6+3utoJvERTxd68MEHtXHjRm3dulV9+/b1Oo5rUlNTddNNN+mrX/2qqqqqNGLECL3wwgtex4pafX29Tp8+rVGjRiklJUUpKSnavn27li1bppSUFLW2tnod0TXXXXedvvKVr+jw4cNeR4labm7uJeX45ptvTpiPsD5z9OhR/fGPf9QDDzzgdRTXPPLII1qwYIHuvfdeDRs2TDNnztTDDz+sqqoqr6O5YsCAAdq+fbuamprU2Nio3bt3q6WlRf379/c6mmtycnIkSadOnW
q3/9SpU23PdRQFpQs4jqMHH3xQNTU1+tOf/qTCwkKvI8VUJBJROBz2OkbUJk2apH379qmhoaFtGz16tEpKStTQ0KDk5GSvI7qmqalJH330kXJzc72OErXx48dfsoz/4MGD6tevn2eZYmHVqlXq06ePJk+e7HUU15w/f15JSe3flpKTkxWJRDzLFAvp6enKzc3VmTNntHnzZk2dOtXrSK4pLCxUTk6OtmzZ0rYvFApp165dnb420cxHPE1NTe3+9XbkyBE1NDQoKytLBQUFnmaLVllZmdauXasNGzYoIyOj7XO4QCCgtLQ0r+NFpaKiQsXFxSooKNC5c+e0du1abdu2TZs3b/Y6WtQyMjIuuU4oPT1dvXv3jvvrh37yk59oypQp6tevn44fP64nnnhCycnJmjFjhtfRovbwww9r3LhxWrJkie6++27t3r1bK1eu1MqVK72O5ppIJKJVq1aptLRUKSlm/hqP2pQpU7R48WIVFBRoyJAh+uCDD/Tcc89p9uzZXkdzxebNm+U4jgYNGqTDhw/rkUce0eDBgzVr1iyvo3XK1d6vy8vL9dRTT2ngwIEqLCxUZWWl8vLyNG3atM4dyNX1RlHYunWrI+mSrbS01OtoUbvceUlyVq1a5XW0qM2ePdvp16+fk5qa6lx//fXOpEmTnD/84Q9ex4qZRFlmfM899zi5ublOamqqc8MNNzj33HOPc/jwYa9jueb3v/+9M3ToUMfv9zuDBw92Vq5c6XUkV23evNmR5Bw4cMDrKK4KhULO3LlznYKCAqdnz55O//79nUcffdQJh8NeR3PFm2++6fTv399JTU11cnJynLKyMufs2bNex+q0q71fRyIRp7Ky0snOznb8fr8zadKkL/Vn1eckyi36AABAwuAaFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDn/H/iu+7rgR1tVAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ, test=True))\n",
"model.eval()\n",
"with torch.no_grad():\n",
" output = model(batch_src, batch_padding_mask)\n",
"print(criterion(output.squeeze(1), batch_labels).item())\n",
"x = batch_labels.detach().to(torch.float16).cpu().numpy().flatten()\n",
"y = output.detach().to(torch.float16).cpu().numpy().flatten()\n",
"plt.hist2d(x, y, bins=50, norm=mpl.colors.LogNorm())"
]
}
],
"metadata": {
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|