1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
|
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "LPphBnKR-aWF"
},
"source": [
"# Step 0: Imports"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Questions:\n",
"\n",
"- Do the attention heads learn to attend to the same positional encodings?\n",
"- Interpretability: what is the model doing, and can we figure it out?\n",
"- Update: I think we should do the interpretability work once the model is bigger."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ge5QvElvhCOw",
"outputId": "c7cdaefa-d6dc-44ad-c258-e4fb2aca97a5"
},
"outputs": [],
"source": [
"# using tqdm.auto glitches out collaborative editing\n",
"from tqdm import tqdm\n",
"import torch\n",
"import torch.nn as nn\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"from math import sqrt\n",
"from collections import deque\n",
"import os\n",
"import random\n",
"from concurrent.futures import ProcessPoolExecutor\n",
"import pickle\n",
"\n",
"# Seed both RNGs the notebook draws from (graph generation uses Python's `random`).\n",
"random.seed(42)\n",
"torch.manual_seed(42)\n",
"\n",
"# Pick the GPU when present; the assert below then makes it a hard requirement.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_state": "idle",
"metadata": {
"id": "lylOX2POPwFL"
},
"outputs": [],
"source": [
"# VTXS bounds are inclusive: graphs have MIN_VTXS..MAX_VTXS vertices.\n",
"MIN_VTXS = 3\n",
"MAX_VTXS = 15\n",
"MAX_TUNE_VTXS = 7  # cap on fine-tuning graph sizes (outside test mode) and on tune target vertices\n",
"AVG_DEG = 2  # random_graph draws AVG_DEG * n endpoints and pairs them into candidate edges\n",
"# Each example is 2*SEQ_LEN - 1 = 31 tokens: up to AVG_DEG*MAX_VTXS/2 = 15 edges\n",
"# (2 tokens each), PAD_TOKEN padding, and the target vertex as the final token.\n",
"SEQ_LEN = MAX_VTXS + 1\n",
"PAD_TOKEN = 0  # vertices are labelled 1..MAX_VTXS, so 0 is free for padding\n",
"EPOCH_SIZE = 2**20\n",
"\n",
"# random_graph pads to 2*SEQ_LEN - 1 tokens and callers overwrite the last one,\n",
"# so every edge token must fit in the first 2*SEQ_LEN - 2 slots.\n",
"assert AVG_DEG * MAX_VTXS <= 2 * SEQ_LEN - 2, \"edge tokens would not fit in SEQ_LEN\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gKt-yIpDebF1"
},
"source": [
"# Step 1: Generate synthetic data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# Original task data.\n",
"# Each example is a flat edge list, like this: [1 3 1 5 2 4 0 0 0 0 2]\n",
"# This represents the edges (1,3), (1,5), (2,4);\n",
"# the zeros are padding tokens.\n",
"# The final 2 is the target vertex whose distance from vertex 1 we want.\n",
"\n",
"# The label is the shortest-path distance from vertex 1 to vertex 2,\n",
"# or MAX_VTXS if no path exists.\n",
"\n",
"def random_graph(n):\n",
"    \"\"\"Sample a random undirected graph on vertices 1..n.\n",
"\n",
"    AVG_DEG * n endpoints are drawn uniformly from 1..n and paired up into\n",
"    candidate edges. Self-loops are discarded; repeated edges stay in the\n",
"    token list but collapse in the adjacency sets.\n",
"\n",
"    Returns:\n",
"        edge_list: flat tokens [u1, v1, u2, v2, ...] with u < v in each pair,\n",
"            right-padded with PAD_TOKEN to 2*SEQ_LEN - 1 tokens (callers\n",
"            overwrite the last slot with the target vertex).\n",
"        adjacencies: list of n + 1 neighbour sets; index 0 is unused.\n",
"    \"\"\"\n",
"    endpoints = [random.randint(1, n) for _ in range(AVG_DEG * n)]\n",
"    neighbours = [set() for _ in range(n + 1)]\n",
"    tokens = []\n",
"    for k in range(0, len(endpoints), 2):\n",
"        a, b = endpoints[k], endpoints[k + 1]\n",
"        if a == b:\n",
"            continue  # no self-loops\n",
"        tokens.extend((min(a, b), max(a, b)))\n",
"        neighbours[a].add(b)\n",
"        neighbours[b].add(a)\n",
"    pad_count = 2 * SEQ_LEN - 1 - len(tokens)\n",
"    tokens.extend([PAD_TOKEN] * pad_count)\n",
"    return tokens, neighbours\n",
"\n",
"def SSSP(G, target=2):\n",
"    \"\"\"Breadth-first-search distances from vertex 1 in adjacency list G.\n",
"\n",
"    G is indexed by vertex (for graphs from random_graph, index 0 is unused).\n",
"    Vertices that are never reached keep the sentinel distance MAX_VTXS.\n",
"\n",
"    If target is a vertex, return its distance from vertex 1 (MAX_VTXS when it\n",
"    is unreachable); the search stops as soon as target is discovered.\n",
"    If target is None, return the whole distance list, one entry per index of G.\n",
"    \"\"\"\n",
"    dist = [MAX_VTXS] * len(G)\n",
"    dist[1] = 0\n",
"    queue = deque([1])\n",
"    while queue:\n",
"        current = queue.popleft()\n",
"        for nbr in G[current]:\n",
"            if dist[nbr] != MAX_VTXS:\n",
"                continue  # already visited\n",
"            dist[nbr] = dist[current] + 1\n",
"            queue.append(nbr)\n",
"            if nbr == target:\n",
"                return dist[nbr]\n",
"    return dist if target is None else dist[target]\n",
"\n",
"def mkbatch(size):\n",
"    \"\"\"Build `size` original-task examples.\n",
"\n",
"    Each example is a random graph with MIN_VTXS..MAX_VTXS vertices whose final\n",
"    token is the target vertex 2; its label is SSSP's distance from 1 to 2.\n",
"\n",
"    Returns (data, labels, padding): the token tensor, a float32 label tensor,\n",
"    and a boolean mask that is True wherever data == PAD_TOKEN.\n",
"    \"\"\"\n",
"    token_rows = []\n",
"    targets = []\n",
"    for _ in range(size):\n",
"        num_vtxs = random.randint(MIN_VTXS, MAX_VTXS)\n",
"        tokens, adjacency = random_graph(num_vtxs)\n",
"        targets.append(SSSP(adjacency))\n",
"        tokens[-1] = 2  # target token\n",
"        token_rows.append(tokens)\n",
"\n",
"    data = torch.tensor(token_rows)\n",
"    labels = torch.tensor(targets, dtype=torch.float32)\n",
"    return data, labels, data == PAD_TOKEN\n",
"\n",
"def savebatch(size, idx, out_dir='data31'):\n",
"    \"\"\"Generate mkbatch(size) and pickle it to <out_dir>/<idx>.pickle.\n",
"\n",
"    The pickle holds a dict with keys \"data\", \"labels\" and \"padding\".\n",
"    out_dir is created first if it does not exist yet.\n",
"    \"\"\"\n",
"    data, labels, padding = mkbatch(size)\n",
"    everything = {\n",
"        \"data\": data,\n",
"        \"labels\": labels,\n",
"        \"padding\": padding,\n",
"    }\n",
"    # The data-generation cell submits this to a process pool; exist_ok=True\n",
"    # also tolerates the directory being created concurrently by another worker.\n",
"    os.makedirs(out_dir, exist_ok=True)\n",
"    with open(os.path.join(out_dir, f'{idx}.pickle'), 'wb') as file:\n",
"        pickle.dump(everything, file)\n",
"\n",
"def vertices_on_shortest_12_path(G, target=2):\n",
"    \"\"\"Return the vertices on a BFS shortest path from vertex 1 to target.\n",
"\n",
"    The path runs from 1 to target inclusive, so the vertex at index k is at\n",
"    distance k from vertex 1. Returns [] when target is unreachable from 1,\n",
"    and [1] when target is 1 itself.\n",
"    \"\"\"\n",
"    if target == 1:\n",
"        # The BFS below only returns on *discovering* a vertex, never the source.\n",
"        return [1]\n",
"    dist = [MAX_VTXS for _ in G]\n",
"    parent = [-1 for _ in G]\n",
"    dist[1] = 0\n",
"    frontier = deque([1])\n",
"    while frontier:\n",
"        vtx = frontier.popleft()\n",
"        for x in G[vtx]:\n",
"            if dist[x] == MAX_VTXS:  # not yet visited\n",
"                parent[x] = vtx\n",
"                dist[x] = 1 + dist[vtx]\n",
"                frontier.append(x)\n",
"                if x == target:\n",
"                    # Walk parent links back to vertex 1 (whose parent stays -1).\n",
"                    path = [x]\n",
"                    node = x\n",
"                    while parent[node] != -1:\n",
"                        node = parent[node]\n",
"                        path.append(node)\n",
"                    return path[::-1]\n",
"    return []\n",
"\n",
"def mktunebatch(size, test=False):\n",
"    \"\"\"Build `size` fine-tuning examples.\n",
"\n",
"    For each example a graph size is drawn from MIN_VTXS..MAX_TUNE_VTXS\n",
"    (MIN_VTXS..MAX_VTXS when test=True), and graphs of that size are resampled\n",
"    until one has a 1 -> 2 path. A vertex other than 1 is then picked uniformly\n",
"    from that shortest path; it replaces the final token, and its index on the\n",
"    path (its BFS distance from vertex 1) is the label. Picks whose vertex label\n",
"    exceeds MAX_TUNE_VTXS are rejected and a fresh graph is drawn.\n",
"\n",
"    Returns (data, labels, padding) with the same layout as mkbatch.\n",
"    \"\"\"\n",
"    max_n = MAX_VTXS if test else MAX_TUNE_VTXS\n",
"    token_rows, hops = [], []\n",
"    for _ in range(size):\n",
"        n = random.randint(MIN_VTXS, max_n)\n",
"        while True:\n",
"            tokens, adjacency = random_graph(n)\n",
"            path = vertices_on_shortest_12_path(adjacency)\n",
"            if len(path) <= 1:\n",
"                continue  # vertex 2 unreachable from 1; resample\n",
"            hop = random.randrange(1, len(path))\n",
"            target_vtx = path[hop]\n",
"            if target_vtx <= MAX_TUNE_VTXS:\n",
"                break\n",
"        tokens[-1] = target_vtx\n",
"        token_rows.append(tokens)\n",
"        hops.append(hop)\n",
"\n",
"    data = torch.tensor(token_rows)\n",
"    labels = torch.tensor(hops, dtype=torch.float32)\n",
"    return data, labels, data == PAD_TOKEN"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# Only need to run this once to generate training data.\n",
"# RESTART THE KERNEL BEFORE RUNNING AND ONLY RUN THE CELLS ABOVE,\n",
"# then set GENERATE_DATA = True for that one run.\n",
"# Pure-Python generation is slow, so it is spread across worker processes.\n",
"GENERATE_DATA = False\n",
"NUM_DATA_FILES = 1000\n",
"\n",
"if GENERATE_DATA:\n",
"    # Leaving the with-block shuts the pool down and waits for all jobs.\n",
"    with ProcessPoolExecutor() as executor:\n",
"        futures = [executor.submit(savebatch, EPOCH_SIZE, i) for i in range(NUM_DATA_FILES)]\n",
"        for future in futures:\n",
"            future.result()  # re-raise any worker exception instead of dropping it"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 1, 2, 5, 12, 3, 12, 2, 11, 9, 12, 2, 10, 1, 7, 1, 2, 9, 10,\n",
" 1, 9, 4, 12, 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
" [ 9, 12, 4, 7, 8, 10, 5, 13, 1, 13, 3, 13, 7, 12, 5, 6, 3, 4,\n",
" 6, 13, 2, 7, 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
" [ 1, 5, 8, 12, 2, 9, 2, 7, 5, 9, 10, 11, 6, 10, 4, 12, 1, 2,\n",
" 4, 11, 2, 5, 2, 4, 0, 0, 0, 0, 0, 0, 2],\n",
" [ 5, 8, 3, 6, 4, 5, 2, 3, 4, 9, 3, 8, 5, 7, 4, 9, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2],\n",
" [ 6, 13, 1, 13, 1, 4, 6, 13, 5, 7, 2, 4, 10, 12, 4, 6, 8, 11,\n",
" 7, 11, 3, 8, 3, 5, 4, 12, 0, 0, 0, 0, 2]]),\n",
" tensor([ 1., 5., 1., 15., 2.]),\n",
" tensor([[False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, True, True, True, True, True, True, True, True,\n",
" False],\n",
" [False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, True, True, True, True, True, True, True, True,\n",
" False],\n",
" [False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, True, True, True, True, True, True,\n",
" False],\n",
" [False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True,\n",
" False],\n",
" [False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, True, True, True, True,\n",
" False]]))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mkbatch(size=5)  # (tokens, labels, padding mask) for a handful of example graphs"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([523., 0., 0., 0., 390., 0., 0., 0., 0., 238., 0.,\n",
" 0., 0., 92., 0., 0., 0., 0., 40., 0., 0., 0.,\n",
" 15., 0., 0., 0., 0., 5., 0., 0., 0., 0., 2.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 743.]),\n",
" array([ 1. , 1.21875, 1.4375 , 1.65625, 1.875 , 2.09375,\n",
" 2.3125 , 2.53125, 2.75 , 2.96875, 3.1875 , 3.40625,\n",
" 3.625 , 3.84375, 4.0625 , 4.28125, 4.5 , 4.71875,\n",
" 4.9375 , 5.15625, 5.375 , 5.59375, 5.8125 , 6.03125,\n",
" 6.25 , 6.46875, 6.6875 , 6.90625, 7.125 , 7.34375,\n",
" 7.5625 , 7.78125, 8. , 8.21875, 8.4375 , 8.65625,\n",
" 8.875 , 9.09375, 9.3125 , 9.53125, 9.75 , 9.96875,\n",
" 10.1875 , 10.40625, 10.625 , 10.84375, 11.0625 , 11.28125,\n",
" 11.5 , 11.71875, 11.9375 , 12.15625, 12.375 , 12.59375,\n",
" 12.8125 , 13.03125, 13.25 , 13.46875, 13.6875 , 13.90625,\n",
" 14.125 , 14.34375, 14.5625 , 14.78125, 15. ]),\n",
" <BarContainer object of 64 artists>)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAl0UlEQVR4nO3df1RU953/8Rc/B0RnCGyYkRWQ7rpFEoyJpjIxu+0qK7WsW480qTmU0MSNp57RBGmtsqsm1UQMu42pXdSa46p7Gtat56xpxcQEiSFtBSS47jGaJWZrAgmZoWctjJrDgDDfP75l2vFHklFkPsDzcc49J3M/n2Hel2PMM8MME+H3+/0CAAAwSGS4BwAAALgSgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAONHhHuBGDAwMqKOjQxMmTFBERES4xwEAAJ+D3+/XhQsXlJqaqsjIT3+OZEQGSkdHh9LS0sI9BgAAuAHt7e2aNGnSp+4ZkYEyYcIE6fcXaLVawz0OAAD4HLxer9LS0gL/Hf80IzJQBn+sY7VaCRQAAEaYz/PyDF4kCwAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA40SHewAAABAek9ccuu7a+5sLhnWWK/EMCgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA44QUKJMnT1ZERMRVh8vlkiT19PTI5XIpOTlZ48ePV2FhoTweT9DXaGtrU0FBgcaNG6eUlBStWrVKly9fHtqrAgAAI1pIgdLc3KyPP/44cNTW1kqSHnjgAUnSypUrdfDgQe3fv1/19fXq6OjQokWLAvfv7+9XQUGBent7dezYMe3du1d79uzR+vXrh/q6AADACBbh9/v9N3rn0tJS1dTU6OzZs/J6vbr99ttVXV2tb3zjG5Kk//mf/9HUqVPV0NCg3NxcvfLKK/rbv/1bdXR0yG63S5J27Nih1atX67e//a1iY2M/1+N6vV7ZbDZ1d3fLarXe6PgAAIxpk9ccuu7a+5sLhvzxQvnv9w2/BqW3t1c//elP9eijjyoiIkItLS3q6+tTXl5eYE9WVpbS09PV0NAgSWpoaFBOTk4gTiQpPz9fXq9Xp0+fvtFRAADAKBN9o3d86aWX1NXVpW9/+9uSJLfbrdjYWCUmJgbts9vtcrvdgT1/HCeD64Nr1+Pz+eTz+QK3vV7vjY4NAABGgBt+BmXXrl2aP3++UlNTh3aia6ioqJDNZgscaWlpt/wxAQBA+NxQoHzwwQc6cuSI/v7v/z5wzuFwqLe3V11dXUF7PR6PHA5HYM+V7+oZvD2451rKy8vV3d0dONrb229kbAAAMELcUKDs3r1bKSkpKij4wwtoZsyYoZiYGNXV1QXOtba2qq2tTU6nU5LkdDp16tQpdXZ2BvbU1tbKarUqOzv7uo9nsVhktVqDDgAAMHqF/BqUgYEB7d69WyUlJYqO/sPdbTablixZorKyMiUlJclqtWrFihVyOp3Kzc2VJM2bN0/Z2dkqLi5WZWWl3G631q5dK5fLJYvFMrRXBgAARqyQA+XIkSNqa2vTo48+etXali1bFBkZqcLCQvl8PuXn52vbtm2B9aioKNXU1GjZsmVyOp1KSEhQSUmJNmzYcPNXAgAARo2b+j0o4cLvQQEA4OaNyt+DAgAAcKsQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAAD
AOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOOEHCgfffSRvvWtbyk5OVnx8fHKycnRW2+9FVj3+/1av369Jk6cqPj4eOXl5ens2bNBX+P8+fMqKiqS1WpVYmKilixZoosXLw7NFQEAgBEvpED53e9+p9mzZysmJkavvPKKzpw5ox/+8Ie67bbbAnsqKyu1detW7dixQ01NTUpISFB+fr56enoCe4qKinT69GnV1taqpqZGb775ppYuXTq0VwYAAEasCL/f7/+8m9esWaNf//rX+uUvf3nNdb/fr9TUVH33u9/V9773PUlSd3e37Ha79uzZo8WLF+udd95Rdna2mpubNXPmTEnS4cOH9bWvfU0ffvihUlNTP3MOr9crm82m7u5uWa3Wz3+1AAAgYPKaQ9dde39zwZA/Xij//Q7pGZRf/OIXmjlzph544AGlpKTo7rvv1gsvvBBYP3funNxut/Ly8gLnbDabZs2apYaGBklSQ0ODEhMTA3EiSXl5eYqMjFRTU9M1H9fn88nr9QYdAABg9AopUH7zm99o+/btmjJlil599VUtW7ZMjz/+uPbu3StJcrvdkiS73R50P7vdHlhzu91KSUkJWo+OjlZSUlJgz5UqKipks9kCR1paWmhXCQAARpSQAmVgYED33HOPNm3apLvvvltLly7VY489ph07dty6CSWVl5eru7s7cLS3t9/SxwMAAOEVUqBMnDhR2dnZQeemTp2qtrY2SZLD4ZAkeTyeoD0ejyew5nA41NnZGbR++fJlnT9/PrDnShaLRVarNegAAACjV0iBMnv2bLW2tgade/fdd5WRkSFJyszMlMPhUF1dXWDd6/WqqalJTqdTkuR0OtXV1aWWlpbAntdff10DAwOaNWvWzV4PAAAYBaJD2bxy5Urdd9992rRpkx588EEdP35cO3fu1M6dOyVJERERKi0t1dNPP60pU6YoMzNT69atU2pqqhYuXCj9/hmXr371q4EfDfX19Wn58uVavHjx53oHDwAAGP1CCpR7771XBw4cUHl5uTZs2KDMzEw9//zzKioqCuz5/ve/r0uXLmnp0qXq6urS/fffr8OHDysuLi6w58UXX9Ty5cs1d+5cRUZGqrCwUFu3bh3aKwMAACNWSL8HxRT8HhQAAG7eqPk9KAAAAMOBQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxokO9wAmmrzm0HXX3t9cMKyzAAAwFvEMCgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwTkiB8tRTTykiIiLoyMrKCqz39PTI5XIpOTlZ48ePV2FhoTweT9DXaGtrU0FBgcaNG6eUlBStWrVKly9fHrorAgAAI17IbzO+4447dOTIkT98geg/fImVK1fq0KFD2r9/v2w2m5YvX65Fixbp17/+tSSpv79fBQUFcjgcOnbsmD7++GM9/PDDiomJ0aZNm4bqmgAAwAgXcqBER0fL4XBcdb67u1u7du1SdXW15syZI0navXu3pk6dqsbGRuXm5uq1117TmTNndOTIEdntdk2fPl0bN27U6tWr9dRTTyk2NnZorgoAAIxoIb8G5ezZs0pNTdUXvvAFFRUVqa2tTZLU0tKivr4+5eXlBfZmZWUpPT1dDQ0NkqSGhgbl5OTIbrcH9uTn58vr9er06dPXfUyfzyev1xt0AACA0SukQJk1a5b27Nmjw4cPa/v27Tp37pz+8i//UhcuXJDb7VZsbKwSExOD7mO32+
V2uyVJbrc7KE4G1wfXrqeiokI2my1wpKWlhTI2AAAYYUL6Ec/8+fMD/zxt2jTNmjVLGRkZ+tnPfqb4+PhbMZ8kqby8XGVlZYHbXq+XSAEAYBS7qbcZJyYm6i/+4i/03nvvyeFwqLe3V11dXUF7PB5P4DUrDofjqnf1DN6+1utaBlksFlmt1qADAACMXjcVKBcvXtT//u//auLEiZoxY4ZiYmJUV1cXWG9tbVVbW5ucTqckyel06tSpU+rs7Azsqa2tldVqVXZ29s2MAgAARpGQfsTzve99TwsWLFBGRoY6Ojr05JNPKioqSg899JBsNpuWLFmisrIyJSUlyWq1asWKFXI6ncrNzZUkzZs3T9nZ2SouLlZlZaXcbrfWrl0rl8sli8Vyq64RAACMMCEFyocffqiHHnpI//d//6fbb79d999/vxobG3X77bdLkrZs2aLIyEgVFhbK5/MpPz9f27ZtC9w/KipKNTU1WrZsmZxOpxISElRSUqINGzYM/ZUBAIARK6RA2bdv36eux8XFqaqqSlVVVdfdk5GRoZdffjmUhwUAAGMMn8UDAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMc1OBsnnzZkVERKi0tDRwrqenRy6XS8nJyRo/frwKCwvl8XiC7tfW1qaCggKNGzdOKSkpWrVqlS5fvnwzowAAgFHkhgOlublZP/nJTzRt2rSg8ytXrtTBgwe1f/9+1dfXq6OjQ4sWLQqs9/f3q6CgQL29vTp27Jj27t2rPXv2aP369Td3JQAAYNS4oUC5ePGiioqK9MILL+i2224LnO/u7tauXbv03HPPac6cOZoxY4Z2796tY8eOqbGxUZL02muv6cyZM/rpT3+q6dOna/78+dq4caOqqqrU29s7dFcGAABGrBsKFJfLpYKCAuXl5QWdb2lpUV9fX9D5rKwspaenq6GhQZLU0NCgnJwc2e32wJ78/Hx5vV6dPn36mo/n8/nk9XqDDgAAMHpFh3qHffv26cSJE2pubr5qze12KzY2VomJiUHn7Xa73G53YM8fx8ng+uDatVRUVOgHP/hBqKMCAIARKqRnUNrb2/XEE0/oxRdfVFxc3K2b6grl5eXq7u4OHO3t7cP22AAAYPiFFCgtLS3q7OzUPffco+joaEVHR6u+vl5bt25VdHS07Ha7ent71dXVFXQ/j8cjh8MhSXI4HFe9q2fw9uCeK1ksFlmt1qADAACMXiEFyty5c3Xq1CmdPHkycMycOVNFRUWBf46JiVFdXV3gPq2trWpra5PT6ZQkOZ1OnTp1Sp2dnYE9tbW1slqtys7OHsprAwAAI1RIr0GZMGGC7rzzzqBzCQkJSk5ODpxfsmSJysrKlJSUJKvVqhUrVsjpdCo3N1eSNG/ePGVnZ6u4uFiVlZVyu91au3atXC6XLBbLUF4bAAAYoUJ+kexn2bJliyIjI1VYWCifz6f8/Hxt27YtsB4VFaWamhotW7ZMTqdTCQkJKikp0YYNG4Z6lBFj8ppD1117f3PBsM4CAIAJbjpQ3njjjaDbcXFxqqqqUlVV1XXvk5GRoZdffvlmHxoAAIxSfBYPAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAA
DAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOCEFyvbt2zVt2jRZrVZZrVY5nU698sorgfWenh65XC4lJydr/PjxKiwslMfjCfoabW1tKigo0Lhx45SSkqJVq1bp8uXLQ3dFAABgxAspUCZNmqTNmzerpaVFb731lubMmaOvf/3rOn36tCRp5cqVOnjwoPbv36/6+np1dHRo0aJFgfv39/eroKBAvb29OnbsmPbu3as9e/Zo/fr1Q39lAABgxIoOZfOCBQuCbj/zzDPavn27GhsbNWnSJO3atUvV1dWaM2eOJGn37t2aOnWqGhsblZubq9dee01nzpzRkSNHZLfbNX36dG3cuFGrV6/WU089pdjY2KG9OgAAMCLd8GtQ+vv7tW/fPl26dElOp1MtLS3q6+tTXl5eYE9WVpbS09PV0NAgSWpoaFBOTo7sdntgT35+vrxeb+BZmGvx+Xzyer1BBwAAGL1CDpRTp05p/Pjxslgs+s53vqMDBw4oOztbbrdbsbGxSkxMDNpvt9vldrslSW63OyhOBtcH166noqJCNpstcKSlpYU6NgAAGEFCDpQvfvGLOnnypJqamrRs2TKVlJTozJkzt2a63ysvL1d3d3fgaG9vv6WPBwAAwiuk16BIUmxsrP78z/9ckjRjxgw1NzfrRz/6kb75zW+qt7dXXV1dQc+ieDweORwOSZLD4dDx48eDvt7gu3wG91yLxWKRxWIJdVQAADBC3fTvQRkYGJDP59OMGTMUExOjurq6wFpra6va2trkdDolSU6nU6dOnVJnZ2dgT21traxWq7Kzs292FAAAMEqE9AxKeXm55s+fr/T0dF24cEHV1dV644039Oqrr8pms2nJkiUqKytTUlKSrFarVqxYIafTqdzcXEnSvHnzlJ2dreLiYlVWVsrtdmvt2rVyuVw8QwIAAAJCCpTOzk49/PDD+vjjj2Wz2TRt2jS9+uqr+pu/+RtJ0pYtWxQZGanCwkL5fD7l5+dr27ZtgftHRUWppqZGy5Ytk9PpVEJCgkpKSrRhw4ahvzIAADBihRQou3bt+tT1uLg4VVVVqaqq6rp7MjIy9PLLL4fysAAAYIzhs3gAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYJzrcA8Ask9ccuu7a+5sLhnUWAMDYxTMoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA44QUKBUVFbr33ns1YcIEpaSkaOHChWptbQ3a09PTI5fLpeTkZI0fP16FhYXyeDxBe9ra2lRQUKBx48YpJSVFq1at0uXLl4fmigAAwIgXUqDU19fL5XKpsbFRtbW16uvr07x583Tp0qXAnpUrV+rgwYPav3+/6uvr1dHRoUWLFgXW+/v7VVBQoN7eXh07dkx79+7Vnj17tH79+qG9MgAAMGJFh7L58OHDQbf37NmjlJQUtbS06K/+6q/U3d2tXbt2qbq6WnPmzJEk7d69W1OnTlVjY6Nyc3P12muv6cyZMzpy5IjsdrumT5+ujRs3avXq1XrqqacUGxs7tFcIAABGnJt6DUp3d7ckKSkpSZLU0tKivr4+5eXlBfZkZWUpPT1dDQ0NkqSGhgbl5OTIbrcH9uTn58vr9er06dM3Mw4AABglQnoG5Y8NDAyotLRUs2fP1p133ilJcrvdio2NVWJiYtBeu90ut9sd2PPHcTK4Prh2LT6fTz6fL3Db6/Xe6NgAAGAEuOFnUFwul95++23t27dvaCe6hoqKCtlsts
CRlpZ2yx8TAACEzw0FyvLly1VTU6OjR49q0qRJgfMOh0O9vb3q6uoK2u/xeORwOAJ7rnxXz+DtwT1XKi8vV3d3d+Bob2+/kbEBAMAIEVKg+P1+LV++XAcOHNDrr7+uzMzMoPUZM2YoJiZGdXV1gXOtra1qa2uT0+mUJDmdTp06dUqdnZ2BPbW1tbJarcrOzr7m41osFlmt1qADAACMXiG9BsXlcqm6ulo///nPNWHChMBrRmw2m+Lj42Wz2bRkyRKVlZUpKSlJVqtVK1askNPpVG5uriRp3rx5ys7OVnFxsSorK+V2u7V27Vq5XC5ZLJZbc5UAAGBECSlQtm/fLkn6yle+EnR+9+7d+va3vy1J2rJliyIjI1VYWCifz6f8/Hxt27YtsDcqKko1NTVatmyZnE6nEhISVFJSog0bNgzNFQEAgBEvpEDx+/2fuScuLk5VVVWqqqq67p6MjAy9/PLLoTw0AAAYQ/gsHgAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYJyQA+XNN9/UggULlJqaqoiICL300ktB636/X+vXr9fEiRMVHx+vvLw8nT17NmjP+fPnVVRUJKvVqsTERC1ZskQXL168+asBAACjQsiBcunSJd11112qqqq65nplZaW2bt2qHTt2qKmpSQkJCcrPz1dPT09gT1FRkU6fPq3a2lrV1NTozTff1NKlS2/uSgAAwKgRHeod5s+fr/nz519zze/36/nnn9fatWv19a9/XZL0b//2b7Lb7XrppZe0ePFivfPOOzp8+LCam5s1c+ZMSdKPf/xjfe1rX9M///M/KzU19WavCQAAjHBD+hqUc+fOye12Ky8vL3DOZrNp1qxZamhokCQ1NDQoMTExECeSlJeXp8jISDU1NV3z6/p8Pnm93qADAACMXkMaKG63W5Jkt9uDztvt9sCa2+1WSkpK0Hp0dLSSkpICe65UUVEhm80WONLS0oZybAAAYJgR8S6e8vJydXd3B4729vZwjwQAAG6hIQ0Uh8MhSfJ4PEHnPR5PYM3hcKizszNo/fLlyzp//nxgz5UsFousVmvQAQAARq8hDZTMzEw5HA7V1dUFznm9XjU1NcnpdEqSnE6nurq61NLSEtjz+uuva2BgQLNmzRrKcQAAwAgV8rt4Ll68qPfeey9w+9y5czp58qSSkpKUnp6u0tJSPf3005oyZYoyMzO1bt06paamauHChZKkqVOn6qtf/aoee+wx7dixQ319fVq+fLkWL17MO3jGoMlrDl137f3NBcM6CwDAHCEHyltvvaW//uu/DtwuKyuTJJWUlGjPnj36/ve/r0uXLmnp0qXq6urS/fffr8OHDysuLi5wnxdffFHLly/X3LlzFRkZqcLCQm3dunWorgkAAIxwIQfKV77yFfn9/uuuR0REaMOGDdqwYcN19yQlJam6ujrUhwYAAGPEiHgXDwAAGFsIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxosM9AGCKyWsOfer6+5sLhm0WABjreAYFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcfgsHiCMPu3zf/jsHwBjGc+gAAAA4xAoAADAOG
ENlKqqKk2ePFlxcXGaNWuWjh8/Hs5xAACAIcIWKP/xH/+hsrIyPfnkkzpx4oTuuusu5efnq7OzM1wjAQAAQ4TtRbLPPfecHnvsMT3yyCOSpB07dujQoUP613/9V61ZsyZcYwFjFi/YBWCSsARKb2+vWlpaVF5eHjgXGRmpvLw8NTQ0XLXf5/PJ5/MFbnd3d0uSvF7vLZlvwPfJddduxWMO9+OZNMtIuXaN8esf7lnufPLV6669/YP8YZ0FGM2G+9/7wa/p9/s/e7M/DD766CO/JP+xY8eCzq9atcr/pS996ar9Tz75pF8SBwcHBwcHxyg42tvbP7MVRsTvQSkvL1dZWVng9sDAgM6fP6/k5GRFRESEdbah5vV6lZaWpvb2dlmt1nCPM+zG+vWL78GYv37xPeD6R/H1+/1+XbhwQampqZ+5NyyB8id/8ieKioqSx+MJOu/xeORwOK7ab7FYZLFYgs4lJibe8jnDyWq1jro/mKEY69cvvgdj/vrF94DrH6XXb7PZPte+sLyLJzY2VjNmzFBdXV3g3MDAgOrq6uR0OsMxEgAAMEjYfsRTVlamkpISzZw5U1/60pf0/PPP69KlS4F39QAAgLErbIHyzW9+U7/97W+1fv16ud1uTZ8+XYcPH5bdbg/XSEawWCx68sknr/qR1lgx1q9ffA/G/PWL7wHXP8avf1CE/3O91wcAAGD48Fk8AADAOAQKAAAwDoECAACMQ6AAAADjECgGqKio0L333qsJEyYoJSVFCxcuVGtra7jHCqvNmzcrIiJCpaWl4R5l2Hz00Uf61re+peTkZMXHxysnJ0dvvfVWuMcaNv39/Vq3bp0yMzMVHx+vP/uzP9PGjRs/32d2jEBvvvmmFixYoNTUVEVEROill14KWvf7/Vq/fr0mTpyo+Ph45eXl6ezZs2Gb91b4tO9BX1+fVq9erZycHCUkJCg1NVUPP/ywOjo6wjrzUPqsPwN/7Dvf+Y4iIiL0/PPPD+uM4USgGKC+vl4ul0uNjY2qra1VX1+f5s2bp0uXLoV7tLBobm7WT37yE02bNi3cowyb3/3ud5o9e7ZiYmL0yiuv6MyZM/rhD3+o2267LdyjDZtnn31W27dv17/8y7/onXfe0bPPPqvKykr9+Mc/Dvdot8SlS5d01113qaqq6prrlZWV2rp1q3bs2KGmpiYlJCQoPz9fPT09wz7rrfJp34NPPvlEJ06c0Lp163TixAn953/+p1pbW/V3f/d3YZn1VvisPwODDhw4oMbGxs/16+FHlaH8EEAMjc7OTr8kf319fbhHGXYXLlzwT5kyxV9bW+v/8pe/7H/iiSfCPdKwWL16tf/+++8P9xhhVVBQ4H/00UeDzi1atMhfVFQUtpmGiyT/gQMHArcHBgb8DofD/0//9E+Bc11dXX6LxeL/93//9zBNeWtd+T24luPHj/sl+T/44INhm2u4XO/6P/zwQ/+f/umf+t9++21/RkaGf8uWLWGZLxx4BsVA3d3dkqSkpKRwjzLsXC6XCgoKlJeXF+5RhtUvfvELzZw5Uw888IBSUlJ0991364UXXgj3WMPqvvvuU11dnd59911J0n//93/rV7/6lebPnx/u0YbduXPn5Ha7g/49sNlsmjVrlhoaGsI6Wzh1d3crIiJi1H8W26CBgQEVFxdr1apVuuOOO8I9zrAbEZ9mPJYMDAyotLRUs2fP1p133hnucYbVvn37dOLECTU3N4d7lGH3m9/8Rtu3b1dZWZn+4R/+Qc3NzXr88ccVGxurkpKScI83LNasWSOv16usrCxFRUWpv79fzzzzjIqKisI92rBzu92SdNVv1rbb7YG1saanp0erV6/WQw89NCo/QO9ann32WUVHR+vxxx8P9yhhQaAYxuVy6e2339avfvWrcI8yrNrb2/XEE0+otrZWcXFx4R5n2A0MDGjmzJnatGmTJOnuu+/W22+/rR07doyZQPnZz36mF198UdXV1brjjjt08uRJlZaWKjU1dc
x8D3BtfX19evDBB+X3+7V9+/ZwjzMsWlpa9KMf/UgnTpxQREREuMcJC37EY5Dly5erpqZGR48e1aRJk8I9zrBqaWlRZ2en7rnnHkVHRys6Olr19fXaunWroqOj1d/fH+4Rb6mJEycqOzs76NzUqVPV1tYWtpmG26pVq7RmzRotXrxYOTk5Ki4u1sqVK1VRURHu0Yadw+GQJHk8nqDzHo8nsDZWDMbJBx98oNra2jHz7Mkvf/lLdXZ2Kj09PfB34gcffKDvfve7mjx5crjHGxY8g2IAv9+vFStW6MCBA3rjjTeUmZkZ7pGG3dy5c3Xq1Kmgc4888oiysrK0evVqRUVFhW224TB79uyr3lr+7rvvKiMjI2wzDbdPPvlEkZHB/88UFRWlgYGBsM0ULpmZmXI4HKqrq9P06dMlSV6vV01NTVq2bFm4xxs2g3Fy9uxZHT16VMnJyeEeadgUFxdf9Vq8/Px8FRcX65FHHgnbXMOJQDGAy+VSdXW1fv7zn2vChAmBnzHbbDbFx8eHe7xhMWHChKtec5OQkKDk5OQx8VqclStX6r777tOmTZv04IMP6vjx49q5c6d27twZ7tGGzYIFC/TMM88oPT1dd9xxh/7rv/5Lzz33nB599NFwj3ZLXLx4Ue+9917g9rlz53Ty5EklJSUpPT1dpaWlevrppzVlyhRlZmZq3bp1Sk1N1cKFC8M691D6tO/BxIkT9Y1vfEMnTpxQTU2N+vv7A383JiUlKTY2NoyTD43P+jNwZZDFxMTI4XDoi1/8YhimDYNwv40I///tZdc6du/eHe7Rwmosvc3Y7/f7Dx486L/zzjv9FovFn5WV5d+5c2e4RxpWXq/X/8QTT/jT09P9cXFx/i984Qv+f/zHf/T7fL5wj3ZLHD169Jr/3peUlPj9v3+r8bp16/x2u91vsVj8c+fO9be2toZ77CH1ad+Dc+fOXffvxqNHj4Z79CHxWX8GrjTW3mYc4R+tv6YRAACMWLxIFgAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYJz/B0FPZLxO9+txAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"label_sample = mkbatch(2048)[1]  # already a CPU float32 tensor\n",
"# One bin per integer distance 1..MAX_VTXS; MAX_VTXS is the \"no path\" sentinel.\n",
"fig, ax = plt.subplots()\n",
"ax.hist(label_sample.numpy(), bins=[k + 0.5 for k in range(MAX_VTXS + 1)])\n",
"ax.set_title(\"Original-task labels (distance from vertex 1 to vertex 2)\")\n",
"ax.set_xlabel(\"distance (MAX_VTXS means no path)\")\n",
"ax.set_ylabel(\"count\");"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1516., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 420., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 94., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 16., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 2.]),\n",
" array([1. , 1.0625, 1.125 , 1.1875, 1.25 , 1.3125, 1.375 , 1.4375,\n",
" 1.5 , 1.5625, 1.625 , 1.6875, 1.75 , 1.8125, 1.875 , 1.9375,\n",
" 2. , 2.0625, 2.125 , 2.1875, 2.25 , 2.3125, 2.375 , 2.4375,\n",
" 2.5 , 2.5625, 2.625 , 2.6875, 2.75 , 2.8125, 2.875 , 2.9375,\n",
" 3. , 3.0625, 3.125 , 3.1875, 3.25 , 3.3125, 3.375 , 3.4375,\n",
" 3.5 , 3.5625, 3.625 , 3.6875, 3.75 , 3.8125, 3.875 , 3.9375,\n",
" 4. , 4.0625, 4.125 , 4.1875, 4.25 , 4.3125, 4.375 , 4.4375,\n",
" 4.5 , 4.5625, 4.625 , 4.6875, 4.75 , 4.8125, 4.875 , 4.9375,\n",
" 5. ]),\n",
" <BarContainer object of 64 artists>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoW0lEQVR4nO3df1RU94H//9cgAsY4g2iZcTaotE1VjJpEEjLmd2VFpe56SjehYa3bstrNQhqjMcJpNZqkxZisRlMqsW2CZ2tOfuxZ3cYkGKpVWkMQUValhprUKKkZyB5kRshHRLnfP771noyCghmEN3k+zrnnZO59z9z3u++e8Mw4jA7LsiwBAAAYJKK3JwAAANBdBAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA40T29gR6Snt7u06cOKEhQ4bI4XD09nQAAEAXWJalU6dOyev1KiKi8/dZ+m3AnDhxQgkJCb09DQAAcAXq6up03XXXdXq93wbMkCFDpL/9D+B0Ont7OgAAoAuCwaASEhLsn+Od6bcBc/6PjZxOJwEDAIBhLvfxDz7ECwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA40T29gRMNDrvzU6vfbQy/arOBQCALyPegQEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxuh0wZWVlmjVrlrxerxwOh7Zs2dLp2H/7t3+Tw+HQc889F3K+sbFRWVlZcjqdio2NVXZ2tpqbm0PGHDhwQHfeeadiYmKUkJCgVatWdXeqAACgn+p2wLS0tGjSpEkqLCy85LjNmzfrvffek9frvehaVlaWampqVFpaqq1bt6qsrEzz58+3rweDQU2bNk2jRo1SVVWVnnnmGS1fvlwbNmzo7nQBAEA/FNndJ8yYMUMzZsy45Ji//vWveuihh7Rt2zalp6eHXDt8+LBKSkpUWVmp5ORkSdLzzz+vmTNn6tlnn5XX69WmTZt05swZvfjii4qKitL48eNVXV2t1atXh4QOAAD4cgr7Z2Da29s1Z84cLV68WOPHj7/oenl5uWJjY+14kaTU1FRFRESooqLCHnPXXXcpKirKHpOWlqba2lqdPHmyw/u2trYqGAyGHAAAoH8Ke8A8/fTTioyM1I9+9KMOr/v9fsXHx4eci4yMVFxcnPx+vz3G7XaHjDn/+PyYCxUUFMjlctlHQkJCmFYEAAD6mrAGTFVVldauXavi4mI5HI5wvvRl5efnKxAI2EddXd1VvT8AALh6whowf/jDH9TQ0KCRI0cqMjJSkZGROnbsmBYtWqTRo0dLkjwejxoaGkKed/bsWTU2Nsrj8dhj6uvrQ8acf3x+zIWio6PldDpDDgAA0D+FNWDmzJmjAwcOqLq62j68Xq8WL16sbdu2SZJ8Pp+amppUVVVlP2/Hjh1qb29XSkqKPaasrExtbW32mNLSUo0ZM0ZDhw4N55QBAICBuv1bSM3Nzfrggw/sx0ePHlV1dbXi4uI0cuRIDRs2LGT8wIED5fF4NGbMGEnSuHHjNH36dM2bN09FRUVqa2tTbm6uMjMz7V+5fuCBB7RixQplZ2dryZIlOnTokNauXas1a9Z88RUDAADjdTtg9u7dq3vvvdd+vHDhQknS3LlzVVxc3KXX2LRpk3JzczV16lRFREQoIyND69ats6+7XC698847ysnJ0eTJkzV8+HAtW7aMX6EGAACSJIdlWVZvT6InBINBuVwuBQKBsH8eZnTem51e+2hleqfXAADApXX15zd/FxIAADAOAQMAAIxDwA
AAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjdDtgysrKNGvWLHm9XjkcDm3ZssW+1tbWpiVLlmjChAkaPHiwvF6vvve97+nEiRMhr9HY2KisrCw5nU7FxsYqOztbzc3NIWMOHDigO++8UzExMUpISNCqVau+yDoBAEA/0u2AaWlp0aRJk1RYWHjRtc8++0z79u3T0qVLtW/fPv33f/+3amtr9Q//8A8h47KyslRTU6PS0lJt3bpVZWVlmj9/vn09GAxq2rRpGjVqlKqqqvTMM89o+fLl2rBhw5WuEwAA9CMOy7KsK36yw6HNmzdr9uzZnY6prKzUrbfeqmPHjmnkyJE6fPiwkpKSVFlZqeTkZElSSUmJZs6cqY8//lher1fr16/Xj3/8Y/n9fkVFRUmS8vLytGXLFr3//vtdmlswGJTL5VIgEJDT6bzSJXZodN6bnV77aGV6WO8FAMCXSVd/fvf4Z2ACgYAcDodiY2MlSeXl5YqNjbXjRZJSU1MVERGhiooKe8xdd91lx4skpaWlqba2VidPnuzpKQMAgD4usidf/PTp01qyZIm++93v2hXl9/sVHx8fOonISMXFxcnv99tjEhMTQ8a43W772tChQy+6V2trq1pbW+3HwWCwR9YEAAB6X4+9A9PW1qb77rtPlmVp/fr1PXUbW0FBgVwul30kJCT0+D0BAEDv6JGAOR8vx44dU2lpacifYXk8HjU0NISMP3v2rBobG+XxeOwx9fX1IWPOPz4/5kL5+fkKBAL2UVdX1wMrAwAAfUHYA+Z8vBw5ckS/+93vNGzYsJDrPp9PTU1Nqqqqss/t2LFD7e3tSklJsceUlZWpra3NHlNaWqoxY8Z0+MdHkhQdHS2n0xlyAACA/qnbAdPc3Kzq6mpVV1dLko4eParq6modP35cbW1t+s53vqO9e/dq06ZNOnfunPx+v/x+v86cOSNJGjdunKZPn6558+Zpz5492r17t3Jzc5WZmSmv1ytJeuCBBxQVFaXs7GzV1NTo1Vdf1dq1a7Vw4cJwrx8AABio279GvXPnTt17770XnZ87d66WL19+0Ydvz/v973+ve+65R/rbF9nl5ubqjTfeUEREhDIyMrRu3Tpde+219vgDBw4oJydHlZWVGj58uB566CEtWbKky/Pk16gBADBPV39+f6HvgenLCBgAAMzTZ74HBgAAINwIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMbpdsCUlZVp1qxZ8nq9cjgc2rJlS8h1y7K0bNkyjRgxQoMGDVJqaqqOHDkSMqaxsVFZWVlyOp2KjY1Vdna2mpubQ8YcOHBAd955p2JiYpSQkKBVq1Zd6RoBAEA/0+2AaWlp0aRJk1RYWNjh9VWrVmndunUqKipSRUWFBg8erLS0NJ0+fdoek5WVpZqaGpWWlmrr1q0qKyvT/Pnz7evBYFDTpk3TqFGjVFVVpWeeeUbLly/Xhg0brnSdAACgH3FYlmVd8ZMdDm3evFmzZ8+W/vbui9fr1aJFi/Too49KkgKBgNxut4qLi5WZmanDhw8rKSlJlZ
WVSk5OliSVlJRo5syZ+vjjj+X1erV+/Xr9+Mc/lt/vV1RUlCQpLy9PW7Zs0fvvv9+luQWDQblcLgUCATmdzitdYodG573Z6bWPVqaH9V4AAHyZdPXnd1g/A3P06FH5/X6lpqba51wul1JSUlReXi5JKi8vV2xsrB0vkpSamqqIiAhVVFTYY+666y47XiQpLS1NtbW1OnnyZIf3bm1tVTAYDDkAAED/FNaA8fv9kiS32x1y3u1229f8fr/i4+NDrkdGRiouLi5kTEev8fl7XKigoEAul8s+EhISwrgyAADQl/Sb30LKz89XIBCwj7q6ut6eEgAA6CFhDRiPxyNJqq+vDzlfX19vX/N4PGpoaAi5fvbsWTU2NoaM6eg1Pn+PC0VHR8vpdIYcAACgfwprwCQmJsrj8Wj79u32uWAwqIqKCvl8PkmSz+dTU1OTqqqq7DE7duxQe3u7UlJS7DFlZWVqa2uzx5SWlmrMmDEaOnRoOKcMAAAM1O2AaW5uVnV1taqrq6W/fXC3urpax48fl8Ph0IIFC/TUU0/pt7/9rQ4ePKjvfe978nq99m8qjRs3TtOnT9e8efO0Z88e7d69W7m5ucrMzJTX65UkPfDAA4qKilJ2drZqamr06quvau3atVq4cGG41w8AAAwU2d0n7N27V/fee6/9+HxUzJ07V8XFxXrsscfU0tKi+fPnq6mpSXfccYdKSkoUExNjP2fTpk3Kzc3V1KlTFRERoYyMDK1bt86+7nK59M477ygnJ0eTJ0/W8OHDtWzZspDvigEAAF9eX+h7YPoyvgcGAADz9Mr3wAAAAFwNBAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACME/aAOXfunJYuXarExEQNGjRIX/va1/Tkk0/Ksix7jGVZWrZsmUaMGKFBgwYpNTVVR44cCXmdxsZGZWVlyel0KjY2VtnZ2Wpubg73dAEAgIHCHjBPP/201q9fr5///Oc6fPiwnn76aa1atUrPP/+8PWbVqlVat26dioqKVFFRocGDBystLU2nT5+2x2RlZammpkalpaXaunWrysrKNH/+/HBPFwAAGMhhff6tkTD41re+JbfbrV//+tf2uYyMDA0aNEi/+c1vZFmWvF6vFi1apEcffVSSFAgE5Ha7VVxcrMzMTB0+fFhJSUmqrKxUcnKyJKmkpEQzZ87Uxx9/LK/Xe9l5BINBuVwuBQIBOZ3OcC5Ro/Pe7PTaRyvTw3ovAAC+TLr68zvs78BMmTJF27dv15///GdJ0v/+7//qj3/8o2bMmCFJOnr0qPx+v1JTU+3nuFwupaSkqLy8XJJUXl6u2NhYO14kKTU1VREREaqoqAj3lAEAgGEiw/2CeXl5CgaDGjt2rAYMGKBz587ppz/9qbKysiRJfr9fkuR2u0Oe53a77Wt+v1/x8fGhE42MVFxcnD3mQq2trWptbbUfB4PBcC8NAAD0EWF/B+a1117Tpk2b9PLLL2vfvn3auHGjnn32WW3cuDHctwpRUFAgl8tlHwkJCT16PwAA0HvCHjCLFy9WXl6eMjMzNWHCBM2ZM0ePPPKICgoKJEkej0eSVF9fH/K8+vp6+5rH41FDQ0PI9bNnz6qxsdEec6H8/HwFAgH7qKurC/fSAABAHxH2gPnss88UERH6sgMGDFB7e7skKTExUR6PR9u3b7evB4NBVVRUyOfzSZJ8Pp+amppUVVVlj9mxY4fa29uVkpLS4X2jo6PldDpDDgAA0D+F/TMws2bN0k9/+lONHDlS48eP1/79+7V69W
r94Ac/kCQ5HA4tWLBATz31lK6//nolJiZq6dKl8nq9mj17tiRp3Lhxmj59uubNm6eioiK1tbUpNzdXmZmZXfoNJAAA0L+FPWCef/55LV26VP/+7/+uhoYGeb1e/fCHP9SyZcvsMY899phaWlo0f/58NTU16Y477lBJSYliYmLsMZs2bVJubq6mTp2qiIgIZWRkaN26deGeLgAAMFDYvwemr+B7YAAAME+vfQ8MAABATyNgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGKdHAuavf/2r/vmf/1nDhg3ToEGDNGHCBO3du9e+blmWli1bphEjRmjQoEFKTU3VkSNHQl6jsbFRWVlZcjqdio2NVXZ2tpqbm3tiugAAwDBhD5iTJ0/q9ttv18CBA/X222/rT3/6k/7jP/5DQ4cOtcesWrVK69atU1FRkSoqKjR48GClpaXp9OnT9pisrCzV1NSotLRUW7duVVlZmebPnx/u6QIAAAM5LMuywvmCeXl52r17t/7whz90eN2yLHm9Xi1atEiPPvqoJCkQCMjtdqu4uFiZmZk6fPiwkpKSVFlZqeTkZElSSUmJZs6cqY8//lher/ey8wgGg3K5XAoEAnI6neFcokbnvdnptY9Wpof1XgAAfJl09ed32N+B+e1vf6vk5GT90z/9k+Lj43XTTTfpl7/8pX396NGj8vv9Sk1Ntc+5XC6lpKSovLxcklReXq7Y2Fg7XiQpNTVVERERqqio6PC+ra2tCgaDIQcAAOifwh4wf/nLX7R+/Xpdf/312rZtmx588EH96Ec/0saNGyVJfr9fkuR2u0Oe53a77Wt+v1/x8fEh1yMjIxUXF2ePuVBBQYFcLpd9JCQkhHtpAACgjwh7wLS3t+vmm2/Wz372M910002aP3++5s2bp6KionDfKkR+fr4CgYB91NXV9ej9AABA7wl7wIwYMUJJSUkh58aNG6fjx49LkjwejySpvr4+ZEx9fb19zePxqKGhIeT62bNn1djYaI+5UHR0tJxOZ8gBAAD6p7AHzO23367a2tqQc3/+8581atQoSVJiYqI8Ho+2b99uXw8Gg6qoqJDP55Mk+Xw+NTU1qaqqyh6zY8cOtbe3KyUlJdxTBgAAhokM9ws+8sgjmjJlin72s5/pvvvu0549e7RhwwZt2LBBkuRwOLRgwQI99dRTuv7665WYmKilS5fK6/Vq9uzZ0t/esZk+fbr9R09tbW3Kzc1VZmZml34DCQAA9G9hD5hbbrlFmzdvVn5+vp544gklJibqueeeU1ZWlj3mscceU0tLi+bPn6+mpibdcccdKikpUUxMjD1m06ZNys3N1dSpUxUREaGMjAytW7cu3NMFAAAGCvv3wPQVfA8MAADm6bXvgQEAAOhpBAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACM0+MBs3LlSjkcDi1YsMA+d/r0aeXk5GjYsGG69tprlZGRofr6+pDnHT9+XOnp6brmmmsUHx+vxYsX6+zZsz09XXxJjM57s9MDAND39WjAVFZW6oUXXtDEiRNDzj
/yyCN644039Prrr2vXrl06ceKEvv3tb9vXz507p/T0dJ05c0bvvvuuNm7cqOLiYi1btqwnpwsAAAzRYwHT3NysrKws/fKXv9TQoUPt84FAQL/+9a+1evVqffOb39TkyZP10ksv6d1339V7770nSXrnnXf0pz/9Sb/5zW904403asaMGXryySdVWFioM2fO9NSUAQCAIXosYHJycpSenq7U1NSQ81VVVWpraws5P3bsWI0cOVLl5eWSpPLyck2YMEFut9sek5aWpmAwqJqamg7v19raqmAwGHIAAID+KbInXvSVV17Rvn37VFlZedE1v9+vqKgoxcbGhpx3u93y+/32mM/Hy/nr5691pKCgQCtWrAjjKgAAQF8V9ndg6urq9PDDD2vTpk2KiYkJ98t3Kj8/X4FAwD7q6uqu2r0BAMDVFfaAqaqqUkNDg26++WZFRkYqMjJSu3bt0rp16xQZGSm3260zZ86oqakp5Hn19fXyeDySJI/Hc9FvJZ1/fH7MhaKjo+V0OkMOAADQP4U9YKZOnaqDBw+qurraPpKTk5WVlWX/88CBA7V9+3b7ObW1tTp+/Lh8Pp8kyefz6eDBg2poaLDHlJaWyul0KikpKdxTBgAAhgn7Z2CGDBmiG264IeTc4MGDNWzYMPt8dna2Fi5cqLi4ODmdTj300EPy+Xy67bbbJEnTpk1TUlKS5syZo1WrVsnv9+snP/mJcnJyFB0dHe4pAwAAw/TIh3gvZ82aNYqIiFBGRoZaW1uVlpamX/ziF/b1AQMGaOvWrXrwwQfl8/k0ePBgzZ07V0888URvTBcAAPQxVyVgdu7cGfI4JiZGhYWFKiws7PQ5o0aN0ltvvXUVZgcAAEzD34UEAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACMQ8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwTtgDpqCgQLfccouGDBmi+Ph4zZ49W7W1tSFjTp8+rZycHA0bNkzXXnutMjIyVF9fHzLm+PHjSk9P1zXXXKP4+HgtXrxYZ8+eDfd0AQCAgcIeMLt27VJOTo7ee+89lZaWqq2tTdOmTVNLS4s95pFHHtEbb7yh119/Xbt27dKJEyf07W9/275+7tw5paen68yZM3r33Xe1ceNGFRcXa9myZeGeLgAAMJDDsiyrJ2/w6aefKj4+Xrt27dJdd92lQCCgr3zlK3r55Zf1ne98R5L0/vvva9y4cSovL9dtt92mt99+W9/61rd04sQJud1uSVJRUZGWLFmiTz/9VFFRUZe9bzAYlMvlUiAQkNPpDOuaRue92em1j1amh/Ve6BnsIQD0TV39+d3jn4EJBAKSpLi4OElSVVWV2tralJqaao8ZO3asRo4cqfLycklSeXm5JkyYYMeLJKWlpSkYDKqmpqbD+7S2tioYDIYcAACgf+rRgGlvb9eCBQt0++2364YbbpAk+f1+RUVFKTY2NmSs2+2W3++3x3w+Xs5fP3+tIwUFBXK5XPaRkJDQQ6sCAAC9rUcDJicnR4cOHdIrr7zSk7eRJOXn5ysQCNhHXV1dj98TAAD0jsieeuHc3Fxt3bpVZWVluu666+zzHo9HZ86cUVNTU8i7MPX19fJ4PPaYPXv2hLze+d9SOj/mQtHR0YqOju6h1QAAgL4k7O/AWJal3Nxcbd68WTt27FBiYmLI9cmTJ2vgwIHavn27fa62tlbHjx+Xz+eTJPl8Ph08eFANDQ32mNLSUjmdTiUlJYV7ygAAwDBhfwcmJydHL7/8sv7nf/5HQ4YMsT+z4nK5NGjQILlcLmVnZ2
vhwoWKi4uT0+nUQw89JJ/Pp9tuu02SNG3aNCUlJWnOnDlatWqV/H6/fvKTnygnJ4d3WQAAQPgDZv369ZKke+65J+T8Sy+9pH/5l3+RJK1Zs0YRERHKyMhQa2ur0tLS9Itf/MIeO2DAAG3dulUPPvigfD6fBg8erLlz5+qJJ54I93QBAICBwh4wXflamZiYGBUWFqqwsLDTMaNGjdJbb70V5tkBAID+gL8LCQAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYJ7K3JwAAV2J03pudXvtoZfpVnQuAq493YAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcQgYAABgHAIGAAAYh4ABAADGIWAAAIBxCBgAAGAcAgYAABiHgAEAAMYhYAAAgHEIGAAAYBwCBgAAGIeAAQAAxiFgAACAcfp0wBQWFmr06NGKiYlRSkqK9uzZ09tTAgAAfUBkb0+gM6+++qoWLlyooqIipaSk6LnnnlNaWppqa2sVHx/f29MDAHxBo/Pe7PTaRyvTr+pcYJ4++w7M6tWrNW/ePH3/+99XUlKSioqKdM011+jFF1/s7akBAIBe1iffgTlz5oyqqqqUn59vn4uIiFBqaqrKy8s7fE5ra6taW1vtx4FAQJIUDAbDPr/21s86vdYT90P4sYfmYw/Nxx6iI+f33rKsS47rkwHzf//3fzp37pzcbnfIebfbrffff7/D5xQUFGjFihUXnU9ISOixeXbE9dxVvR16AHtoPvbQfOwhTp06JZfL1en1PhkwVyI/P18LFy60H7e3t6uxsVHDhg2Tw+EI232CwaASEhJUV1cnp9MZttftS/r7Glmf+fr7Gvv7+vQlWCPru3KWZenUqVPyer2XHNcnA2b48OEaMGCA6uvrQ87X19fL4/F0+Jzo6GhFR0eHnIuNje2xOTqdzn75f8rP6+9rZH3m6+9r7O/r05dgjazvylzqnZfz+uSHeKOiojR58mRt377dPtfe3q7t27fL5/P16twAAEDv65PvwEjSwoULNXfuXCUnJ+vWW2/Vc889p5aWFn3/+9/v7akBAIBe1mcD5v7779enn36qZcuWye/368Ybb1RJSclFH+y92qKjo/X4449f9MdV/Ul/XyPrM19/X2N/X5++BGtkfT3PYV3u95QAAAD6mD75GRgAAIBLIWAAAIBxCBgAAGAcAgYAABiHgLlAWVmZZs2aJa/XK4fDoS1btlz2OTt37tTNN9+s6Ohoff3rX1dxcfFVmeuV6O76du7cKYfDcdHh9/uv2py7o6CgQLfccouGDBmi+Ph4zZ49W7W1tZd93uuvv66xY8cqJiZGEyZM0FtvvXVV5nslrmSNxcXFF+1hTEzMVZtzd6xfv14TJ060vyDL5/Pp7bffvuRzTNq/7q7PpL3ryMqVK+VwOLRgwYJLjjNpDy/UlTWatI/Lly+/aK5jx4695HN6Y/8ImAu0tLRo0qRJKiws7NL4o0ePKj09Xffee6+qq6u1YMEC/eu//qu2bdvW43O9Et1d33m1tbX65JNP7CM+Pr7H5vhF7Nq1Szk5OXrvvfdUWlqqtrY2TZs2TS0tLZ0+591339V3v/tdZWdna//+/Zo9e7Zmz56tQ4cOXdW5d9WVrFF/+8bMz+/hsWPHrtqcu+O6667TypUrVVVVpb179+qb3/ym/vEf/1E1NTUdjjdt/7q7Phm0dxeqrKzUCy+8oIkTJ15ynGl7+HldXaMM28fx48eHzPWPf/xjp2N7bf8sdEqStXnz5kuOeeyxx6zx48eHnLv//vuttLS0Hp7dF9eV9f3+97+3JFknT568avMKp4aGBkuStWvXrk7H3HfffVZ6enrIuZSUFOuHP/zhVZjhF9eVNb700kuWy+W6qvMKp6FDh1q/+tWvOrxm+v5Zl1mfqXt36t
Qp6/rrr7dKS0utu+++23r44Yc7HWvqHnZnjSbt4+OPP25NmjSpy+N7a/94B+YLKi8vV2pqasi5tLQ0lZeX99qcesKNN96oESNG6O///u+1e/fu3p5OlwUCAUlSXFxcp2NM38OurFGSmpubNWrUKCUkJFz2v/j7inPnzumVV15RS0tLp3+NiMn715X1ydC9y8nJUXp6+kV70xFT97A7a5Rh+3jkyBF5vV599atfVVZWlo4fP97p2N7avz77Tbym8Pv9F307sNvtVjAY1P/7f/9PgwYN6rW5hcOIESNUVFSk5ORktba26le/+pXuueceVVRU6Oabb+7t6V1Se3u7FixYoNtvv1033HBDp+M628O++jmfz+vqGseMGaMXX3xREydOVCAQ0LPPPqspU6aopqZG11133VWdc1ccPHhQPp9Pp0+f1rXXXqvNmzcrKSmpw7Em7l931mfa3knSK6+8on379qmysrJL403cw+6u0aR9TElJUXFxscaMGaNPPvlEK1as0J133qlDhw5pyJAhF43vrf0jYHBJY8aM0ZgxY+zHU6ZM0Ycffqg1a9boP//zP3t1bpeTk5OjQ4cOXfLPbk3X1TX6fL6Q/8KfMmWKxo0bpxdeeEFPPvnkVZhp94wZM0bV1dUKBAL6r//6L82dO1e7du3q9Ie8abqzPtP2rq6uTg8//LBKS0v77IdUv6grWaNJ+zhjxgz7nydOnKiUlBSNGjVKr732mrKzs3t1bp9HwHxBHo9H9fX1Iefq6+vldDqNf/elM7feemufj4Lc3Fxt3bpVZWVll/2vm8720OPx9PAsv5jurPFCAwcO1E033aQPPvigx+b3RURFRenrX/+6JGny5MmqrKzU2rVr9cILL1w01sT96876LtTX966qqkoNDQ0h79CeO3dOZWVl+vnPf67W1lYNGDAg5Dmm7eGVrPFCfX0fPy82Nlbf+MY3Op1rb+0fn4H5gnw+n7Zv3x5yrrS09JJ/nm266upqjRgxoren0SHLspSbm6vNmzdrx44dSkxMvOxzTNvDK1njhc6dO6eDBw/22X28UHt7u1pbWzu8Ztr+deRS67tQX9+7qVOn6uDBg6qurraP5ORkZWVlqbq6usMf7Kbt4ZWs8UJ9fR8/r7m5WR9++GGnc+21/evRjwgb6NSpU9b+/fut/fv3W5Ks1atXW/v377eOHTtmWZZl5eXlWXPmzLHH/+Uvf7GuueYaa/Hixdbhw4etwsJCa8CAAVZJSUkvrqJz3V3fmjVrrC1btlhHjhyxDh48aD388MNWRESE9bvf/a4XV9G5Bx980HK5XNbOnTutTz75xD4+++wze8ycOXOsvLw8+/Hu3butyMhI69lnn7UOHz5sPf7449bAgQOtgwcP9tIqLu1K1rhixQpr27Zt1ocffmhVVVVZmZmZVkxMjFVTU9NLq+hcXl6etWvXLuvo0aPWgQMHrLy8PMvhcFjvvPOOZfWD/evu+kzau85c+Bs6pu9hRy63RpP2cdGiRdbOnTuto0ePWrt377ZSU1Ot4cOHWw0NDZbVh/aPgLnA+V8bvvCYO3euZVmWNXfuXOvuu+++6Dk33nijFRUVZX31q1+1XnrppV6a/eV1d31PP/209bWvfc2KiYmx4uLirHvuucfasWNHL67g0jpam6SQPbn77rvt9Z732muvWd/4xjesqKgoa/z48dabb77ZC7PvmitZ44IFC6yRI0daUVFRltvttmbOnGnt27evl1ZwaT/4wQ+sUaNGWVFRUdZXvvIVa+rUqfYPd6sf7F9312fS3nXmwh/upu9hRy63RpP28f7777dGjBhhRUVFWX/3d39n3X///dYHH3xgX+8r++ew/v9/IQIAABiDz8AAAADjEDAAAMA4BAwAADAOAQMAAIxDwAAAAOMQMAAAwDgEDAAAMA4BAwAAjEPAAAAA4xAwAADAOAQMAAAwDgEDAACM8/8BbhuvIZklPe0AAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Histogram of mktunebatch's second output over 2048 samples (presumably the labels,\n",
"# mirroring mkbatch's (src, labels, padding) return order - TODO confirm).\n",
"fig, ax = plt.subplots()\n",
"ax.hist(mktunebatch(2048)[1].to(torch.float32).cpu(), bins=64)\n",
"ax.set(title='mktunebatch label distribution', xlabel='Label', ylabel='Count')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q3Cg_8UQep8g"
},
"source": [
"# Step 2: Define Transformer Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_state": "idle",
"metadata": {
"id": "tLOWhg_CeWzH"
},
"outputs": [],
"source": [
"class TransformerModel(nn.Module):\n",
"    \"\"\"Encoder-only transformer that reads out one output_dim vector per sequence.\n",
"\n",
"    Consecutive tokens (src[:, 2k], src[:, 2k+1]) are embedded separately into\n",
"    model_dim // 2 dims each and concatenated, so the encoder sees one model_dim-wide\n",
"    vector per token pair. No positional encoding is added.\n",
"\n",
"    Args:\n",
"        input_dim: vocabulary size of the token embedding.\n",
"        model_dim: transformer width; must be even (two half-width embeddings per pair).\n",
"        output_dim: size of the linear read-out head.\n",
"        num_heads, num_layers, dropout: forwarded to the encoder layers.\n",
"        seq_len: unused; kept so existing call sites keep working.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, model_dim, output_dim, num_heads, num_layers, seq_len, dropout):\n",
"        super().__init__()\n",
"        if model_dim % 2:\n",
"            # An odd width would make the concatenated pair embedding model_dim - 1 wide\n",
"            # and fail later inside the encoder with a much less obvious shape error.\n",
"            raise ValueError(f\"model_dim must be even, got {model_dim}\")\n",
"        self.model_dim = model_dim\n",
"        self.embedding = nn.Embedding(input_dim, model_dim // 2)\n",
"        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,\n",
"                                                   dim_feedforward=model_dim*4,\n",
"                                                   dropout=dropout, batch_first=True)\n",
"        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)\n",
"        self.fc_out = nn.Linear(model_dim, output_dim)\n",
"\n",
"    def forward(self, src, key_padding_mask):\n",
"        \"\"\"Append the target token, encode token pairs, and read out the last pair.\n",
"\n",
"        src is (batch, L) token ids; after the target token is appended, L + 1 must be\n",
"        even so the even/odd halves line up. key_padding_mask is subsampled to every\n",
"        other column and passed as src_key_padding_mask, so each pair uses the flag of\n",
"        its first token. Returns a (batch, output_dim) tensor.\n",
"        \"\"\"\n",
"        # Build the target column on src's own device/dtype rather than the global\n",
"        # `device`, so the model also works after being moved (e.g. to CPU).\n",
"        target = src.new_full((src.size(0), 1), MAX_VTXS + 1)\n",
"        src = torch.cat((src, target), dim=1)\n",
"        # Even positions -> first half of each vector, odd positions -> second half.\n",
"        embed = torch.cat((self.embedding(src[:, ::2]), self.embedding(src[:, 1::2])), dim=2)\n",
"        output = self.transformer_encoder(embed, src_key_padding_mask=key_padding_mask[:, ::2])\n",
"        return self.fc_out(output[:, -1, :])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bpIeg86S-hBb"
},
"source": [
"# Step 3: Make Model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_state": "idle",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kWXvJRDYgFVP",
"outputId": "c13adb9d-6565-43b5-8437-20cef3dc0d16"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training data: 1048576K\n",
"Trainable parameters in the model: 500449\n"
]
}
],
"source": [
"# PARAMS\n",
"VOCAB_SIZE = MAX_VTXS + 2  # pad token + max number of vertices + target token\n",
"MODEL_DIM = 64  # width of both the embeddings and the transformer\n",
"NEPOCHS = 1000  # one pickled data shard (data/{epoch}.pickle) per epoch\n",
"BSZ = 2**15  # batch size\n",
"NHEADS = 2\n",
"NLAYERS = 10\n",
"DROPOUT = 0  # 0.2\n",
"\n",
"model = torch.compile(\n",
"    TransformerModel(\n",
"        input_dim=VOCAB_SIZE,\n",
"        model_dim=MODEL_DIM,\n",
"        output_dim=1,\n",
"        num_heads=NHEADS,\n",
"        num_layers=NLAYERS,\n",
"        seq_len=SEQ_LEN,\n",
"        dropout=DROPOUT,\n",
"    ).to(device)\n",
")\n",
"\n",
"trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)\n",
"print(f\"Training data: {NEPOCHS * EPOCH_SIZE // 1000}K\")\n",
"print(f\"Trainable parameters in the model: {trainable_params}\")\n",
"\n",
"# Metric histories, appended to once per epoch by the training loop.\n",
"train_err, len1, len2, len3, leninf = [], [], [], [], []\n",
"epoch = 0\n",
"\n",
"# Truncate the per-epoch loss log.\n",
"with open('loss', 'w'):\n",
"    pass"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# To resume from saved weights instead of training from scratch, rebuild `model`\n",
"# exactly as in the model-construction cell and then run:\n",
"#     model.load_state_dict(torch.load('model.pth', weights_only=True))\n",
"\n",
"LR = 5e-4  # learning rate for the first training phase\n",
"\n",
"criterion = nn.MSELoss()  # labels are regressed as real numbers\n",
"optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f8Zn33m7CxL5"
},
"source": [
"# Step 4: Train the Model for the first task"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 486
},
"id": "pvTfzGmCeXU4",
"outputId": "0d3a20f3-23be-4c19-9eb6-46bfe11a48b1"
},
"outputs": [],
"source": [
"def _label_mse(output, labels, value):\n",
"    \"\"\"MSE restricted to rows whose label equals `value`; nan if there are none.\"\"\"\n",
"    selected = labels == value\n",
"    if not selected.any():\n",
"        return float('nan')\n",
"    return criterion(output[selected].squeeze(1), labels[selected]).item()\n",
"\n",
"\n",
"while epoch < NEPOCHS:\n",
"    model.train()\n",
"    # Each epoch is one pre-generated shard. pickle.load can run arbitrary code, so\n",
"    # only point this at files you generated yourself.\n",
"    with open(f\"data/{epoch}.pickle\", \"rb\") as f:\n",
"        pickled_stuff = pickle.load(f)\n",
"    data = pickled_stuff[\"data\"].to(device)\n",
"    label = pickled_stuff[\"labels\"].to(device).to(torch.float32)\n",
"    padding = pickled_stuff[\"padding\"].to(device)\n",
"    dataset = TensorDataset(data, label, padding)\n",
"    loader = DataLoader(dataset, batch_size=BSZ)\n",
"    for batch_src, batch_labels, batch_padding_mask in tqdm(loader):\n",
"        optimizer.zero_grad()\n",
"        output = model(batch_src, batch_padding_mask)\n",
"        loss = criterion(output.squeeze(1), batch_labels)\n",
"        train_loss = loss.item()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # All logged numbers come from the LAST batch of the epoch only. The per-label\n",
"    # breakdown uses a detached copy so no extra autograd graph is built for it.\n",
"    train_err.append(train_loss)\n",
"    last_output = output.detach()\n",
"    len1.append(_label_mse(last_output, batch_labels, 1))\n",
"    len2.append(_label_mse(last_output, batch_labels, 2))\n",
"    len3.append(_label_mse(last_output, batch_labels, 3))\n",
"    # leninf tracks rows labelled MAX_VTXS (presumably the 'unreachable' sentinel).\n",
"    leninf.append(_label_mse(last_output, batch_labels, MAX_VTXS))\n",
"\n",
"    with open('loss', 'a') as f:\n",
"        f.write(f\"{train_loss} {len1[-1]} {len2[-1]} {len3[-1]} {leninf[-1]}\\n\")\n",
"    print(f\"Epoch {epoch}/{NEPOCHS} \\t Train Err: {train_err[-1]} {len1[-1]} {len2[-1]} {len3[-1]} {leninf[-1]}\")\n",
"\n",
"    epoch += 1\n",
"    if epoch % 10 == 0:\n",
"        torch.save(model.state_dict(), f\"checkpoints/model_weights_{epoch}.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# \"\"\"\n",
"# Now let's figure out what it's doing. \n",
"\n",
"# step 1: figure out what people are attending to \n",
"# \"\"\"\n",
"\n",
"# example_graph, answer, padding = mkbatch(1)\n",
"# sentance_embeddings = model.full_embedding(example_graph)[0,:,:][example_graph.flatten() != 0]\n",
"# WQ,WK,WV = torch.split(model.transformer_encoder.layers[0].self_attn.in_proj_weight, (MODEL_DIM, MODEL_DIM, MODEL_DIM))\n",
"\n",
"# Q = sentance_embeddings@WQ\n",
"# K = sentance_embeddings@WK\n",
"\n",
"# raw_scores = Q @ K.T / sqrt(MODEL_DIM)\n",
"# soft = torch.softmax(raw_scores, dim=-1).detach().cpu().to(float).numpy()\n",
"# plt.imshow(soft)\n",
"# plt.show()\n",
"\n",
"# print(example_graph)\n",
"\n",
"# print(Q)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"# Plot 1: first 500 entries of a previously saved 'training-loss' file (whitespace-\n",
"# separated floats). Read into its own name so the in-memory `train_err`, which is\n",
"# aligned entry-for-entry with len1/len2/len3/leninf, is not overwritten.\n",
"with open('training-loss') as f:\n",
"    saved_train_err = [float(tok) for tok in f.read().split()]\n",
"plt.suptitle('Log MSE vs Epochs')\n",
"plt.plot(torch.log(torch.tensor(saved_train_err)[:500]), label='Train', color='blue')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Log MSE')\n",
"plt.show()\n",
"\n",
"# Plot 2: per-epoch histories recorded by the training loop in this session.\n",
"# leninf holds the loss on rows labelled MAX_VTXS (there is no `len15` list).\n",
"curves = [\n",
"    (train_err, 'Train', 'blue'),\n",
"    (len1, 'Len 1', 'red'),\n",
"    (len2, 'Len 2', 'green'),\n",
"    (len3, 'Len 3', 'yellow'),\n",
"    (leninf, f'Len {MAX_VTXS}', 'magenta'),\n",
"]\n",
"plt.suptitle('Log MSE vs Epochs')\n",
"for history, curve_label, color in curves:\n",
"    plt.plot(torch.log(torch.tensor(history)), label=curve_label, color=color)\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Log MSE')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_state": "idle",
"metadata": {
"id": "LoGEmM5lH7_A"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhIAAAGlCAYAAACvGh/vAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqQElEQVR4nO3df3RU9Z3/8ddkIgPFJBpYEgYTSVsUA5goBFbwu8Kab/mmCGV7/NUvhRTPwVYHAdNjge0CWpWIntJUmy+IexTbryju2UJdPKWyWSS6VQhJ47fuVn6cRoywSWRbMxCXgHPv9w/JrJEL3Nx7585k8nyc8zk9c3PvfN5XK3nz/vwKmKZpCgAAwIGMZAcAAAD6LxIJAADgGIkEAABwjEQCAAA4RiIBAAAcI5EAAACOkUgAAADHSCQAAIBjJBIAAMAxEgkAAOAYiQQAAHCMRAIAgDS2Y8cOXX311RozZoz+/u//3vPvD3BoFwAA6enTTz9VcXGxdu/erZycHE2cOFG//e1vNWzYMM/6oCIBAECa2rdvn8aNG6dRo0bp0ksvVUVFhV577TVP+yCRAAAgRdXX12v27NkKh8MKBALavn37OffU1tZq9OjRGjx4sKZMmaJ9+/bFf3bs2DGNGjUq/nnUqFE6evSopzGSSAAAkKK6urpUUlKi2tpay59v3bpVVVVVWrNmjZqamlRSUqKZM2eqo6PDtxgzfevJJsMwdOzYMWVlZSkQCCQ7HABACjNNUydOnFA4HFZGRuL+bnzq1CmdPn3a9feYpnnO77ZQKKRQKGR5f0VFhSoqKs77fevXr9eiRYu0cOFCSdLGjRv16quv6tlnn9WKFSsUDod7VSCOHj2qyZMnu36PXswU09raakqi0Wg0Gs12a21tTdjvpf/6r/8y80cEPYnz0ksvPefamjVrbMUhydy2bVv8c3d3txkMBntdM03TXLBggTlnzhzTNE3zzJkz5le/+lXzww8/NE+cOGFeddVV5vHjxz3955NyFYmsrCxJUmtrq7Kzs5MdDgDAIaP9uoT3ET1p6Mrr34//7kiE06dPq60jppbGK5Wd5bzqET1hqGjikXN+v52vGnExx48fVywWU15eXq/reXl5eu+99yRJmZmZ+vGPf6wZM2bIMAz94Ac/8HTFhlJxaKOn5JOdnU0iAQD9mPFJ0Le+/BgKz87KcJVIxL/H599vc+bM0Zw5cxL2/SmXSAAAkIpipqGY6e55Lw0fPlzBYFDt7e29rre3tys/P9/Tvi6EVRsAANhgyHTdvDRo0CBNnDhRdXV1/x2jYaiurk433HCDp31dCBUJAABsMGTITU2h5+mysjIFg0FFIhFFIpELPnPy5EkdPnw4/rmlpUXNzc3Kzc1VYWGhqqqqVFlZqUmTJmny5MmqqalRV1dXfBWHH0gkAADwUUNDg+05Evv379eMGTPin6uqqiRJlZWV2rx5s+644w599NFHWr16tdra2lRaWqqdO3eeMwEzkVLurI1oNKqcnBx1dnYy2RIAcEF+/M7o6aP1vVGuV20UjD2adr/fqEgAAGCD23kOXs+RSBVMtgQAAI5RkQAAwAZDpmJUJM5BRQIAABu8Wv5ZVlam4uLi8x7E1d9QkQAAwEd9WbXRH5BIAABgQ8w0FXOx0NHNs6mMRAIAABuMs83N8+mIORIAAMAxKhIAANgQc7lqw82zqYyKBAAANsRM902s2gAAYGDyao5Euq3aoCIBAAAcoyIBAIANhgKKKeDq+XREIgEAgA2G+Vlz83w66vPQRn19vWbPnq1wOKxAIKDt27ef997vfe97CgQCqqmpcRsnAABIQX1OJLq6ulRSUnLR2abbtm3T22+/rXA47CY+AABSQuzs0Iablo76PLRRUVGhioqKC95z9OhR3XffffrNb36jWbNmuYkPAICU4DYZSNdEwvNVG4ZhaP78+XrggQc0bty4i97f3d2taDTaqwEAkK7YR+Ii1q1bp8zMTC1Zss
TW/dXV1XrooYe8DgMAAE8ZZkCG6WLVxtln2UfiAhobG/XTn/5UmzdvViBg7x/2ypUr1dnZGW+tra1ehgQAgCeYI2HN00TijTfeUEdHhwoLC5WZmanMzEwdOXJE3//+9zV69GjLZ0KhkLKzs3s1AADQP3g6tDF//nyVl5f3ujZz5kzNnz9fCxcu9LIrAAB8FVOGYi7+/h3zNJrU0edE4uTJkzp8+HD8c0tLi5qbm5Wbm6vCwkINGzas1/2XXHKJ8vPzdfXVV3sTMQAASWC6nCNhung2lfU5kdi/f79mzJgR/1xVVSVJqqys1ObNm72NDgCAFMHyT2t9TiSmT58u07S/z+f777/f1y4AAEA/wVkbAADYEDMzFDNdzJHgrA0AAAYuQwEZynDRPhvaYEMqAADgWLptSEUiAQCADUy2tEYiAQCADe7nSKTnJAnmSAAAAMeoSAAAYMNnky1dHNrF0AYAAAOX4XKLbEMMbQAAAPRCRQIAABuYbGmNRAIAABt6NpZy/jyJBAAAA1bMDCjm4gRPN8+mMuZIAADgI7bIBgBgAIq5XLUROzu0wRbZAAAMQIaZIcPFZEsjTSdbMrQBAAAcoyIBAIANXg1tpBsSCQAAbDBcrrwwPI0mdTC0AQAAHKMiAQCADe43pErPv7uTSAAAYIP7LbLTM5FIz7cCAAC+oCIBAIANhgIy5GayZXpukU0iAQCADQxtWCORAADABvf7SKRnIpGebwUAAHxBRQIAABsMMyDDzYZUHCMOAMDAZZwd2nDaevaR4BhxAADgGMeIAwAwALk/Rjw9BwFIJAAAsCGmgGIu9oJw82wqS8/0CAAA+IKKBAAANjC0YY1EAgAAG2IuhydinkaTOtIzPQIAAL6gIgEAgA0MbVgjkQAAwAYO7bJGIgEAgA2my2PETZZ/AgAA9EZFAgAAGxjasEYiAQCADZz+aS090yMAAOCLPicS9fX1mj17tsLhsAKBgLZv3x7/2ZkzZ7R8+XJNmDBBQ4cOVTgc1oIFC3Ts2DGv4wYAwFdujhDvaemoz2/V1dWlkpISy3PUP/nkEzU1NWnVqlVqamrSL3/5Sx04cEBz5szxKl4AAJKiZ2jDTUtHfZ4jUVFRoYqKCsuf5eTkaNeuXb2u/exnP9PkyZP1wQcfqLCw0HmkAAAg5SR8smVnZ6cCgYAuu+wyy593d3eru7s7/jkajSY6JAAA+sxQhgwXwxNunk1lCX2rU6dOafny5frWt76l7Oxsy3uqq6uVk5MTbwUFBYkMCQAAR2JmwHVLRwlLJM6cOaPbb79dpmlqw4YN571v5cqV6uzsjLfW1tZEhQQAQNKVlZWpuLjYcq5hf5SQoY2eJOLIkSP6l3/5l/NWIyQpFAopFAolIgwAADzj1T4SDQ0NF/y92N94nkj0JBGHDh3S7t27NWzYMK+7AADAd6bL0z9Ndrb8zMmTJ3X48OH455aWFjU3Nys3N1cjR47UrbfeqqamJu3YsUOxWExtbW2SpNzcXA0aNMjb6AEA8ElMAcVcHLzl5tlU1udEYv/+/ZoxY0b8c1VVlSSpsrJSDz74oF555RVJUmlpaa/ndu/erenTp7uPGAAApIw+JxLTp0+XaZrn/fmFfgYAQH9lmO7OyzDS9Ncjh3YBAGCD4XKOhJtnU1l6vhUAAPAFFQkAAGwwFJDhYsKkm2dTGYkEAAA2uN2dkp0tAQAAvoCKBAAANjDZ0hqJBAAANhhyuUV2ms6RSM/0CAAA+IKKBAAANpguV22YaVqRIJEAAMAGr07/TDckEgAA2MBkS2vp+VYAAMAXVCQAALCBoQ1rJBIAANjAFtnWGNoAAACOUZEAAMAGhjaskUgAAGADiYQ1hjYAAIBjVCQAALCBioQ1EgkAAGwgkbDG0AYAAHCMigQAADaYLveCMD2NJnVQkQAAwIaeoQ03LRX9zd/8jS6//HLdeuutjp4nkQAAwIZ0TSSWLl2qn//8546fJ5EAAGAAmz59urKyshw/TyIBAI
ANyahI1NfXa/bs2QqHwwoEAtq+ffs599TW1mr06NEaPHiwpkyZon379nn0xvYw2RIAABuSsfyzq6tLJSUluuuuu/TNb37znJ9v3bpVVVVV2rhxo6ZMmaKamhrNnDlTBw4c0IgRIyRJpaWl+vTTT8959rXXXlM4HHb4Nv+NRAIAAB9Fo9Fen0OhkEKhkOW9FRUVqqioOO93rV+/XosWLdLChQslSRs3btSrr76qZ599VitWrJAkNTc3exr/FzG0AQCADaYZcN0kqaCgQDk5OfFWXV3tKJ7Tp0+rsbFR5eXl8WsZGRkqLy/XW2+95dl7XwwVCQAAbDAUcLWPRM+zra2tys7Ojl8/XzXiYo4fP65YLKa8vLxe1/Py8vTee+/Z/p7y8nK988476urq0hVXXKF/+Id/0A033GD7eRIJAAB8lJ2d3SuRSLZ//ud/dvU8iQQAADak2lkbw4cPVzAYVHt7e6/r7e3tys/P97SvC2GOBAAANng1R8IrgwYN0sSJE1VXVxe/ZhiG6urq+jQ04RYVCQAAfFRWVqZgMKhIJKJIJHLBe0+ePKnDhw/HP7e0tKi5uVm5ubkqLCxUVVWVKisrNWnSJE2ePFk1NTXq6uqKr+LwA4kEAAA2eDW00dDQYHuOxP79+zVjxoz456qqKklSZWWlNm/erDvuuEMfffSRVq9erba2NpWWlmrnzp3nTMBMJBIJAABscDs84eTZ6dOnyzQvfG7o4sWLtXjxYsdxuUUiAQCADabLioTXcyRSBZMtAQCAY1QkAACwwZR0kVGGiz6fjqhIAABgQ8/Olm6azq7aKC4uVm1tbbJfyRNUJAAA8FFfVm30B32uSFzsbHTTNLV69WqNHDlSQ4YMUXl5uQ4dOuRlzAAA+C7VNqRKFX1OJHrORj9fSebxxx/Xk08+qY0bN2rv3r0aOnSoZs6cqVOnTnkRLwAASdGzj4Sblo76PLRxobPRTdNUTU2N/u7v/k7f+MY3JEk///nPlZeXp+3bt+vOO+90HzEAAEgZnk62bGlpUVtbW6+z0XNycjRlyhRfz0YHAMBrpum+pSNPJ1u2tbVJZ89C/7y8vLz4z76ou7tb3d3d8c/RaNTLkAAA8EQydrbsD5K+/LO6ulo5OTnxVlBQkOyQAABIGJZ/XkDP+eft7e0aOXJk/Hp7e7tKS0stn1m5cmX8EBKdrUiQTAD9j9F2lW99ZeQf9K0voIdXFYkBv/zzQoqKipSfn9/rbPRoNKq9e/ee92z0UCik7OzsXg0AgFTDqg1rfa5IXOxs9GXLlumRRx7RmDFjVFRUpFWrVikcDmvu3Llexw4AgG/cTphksuVZFzsb/Qc/+IG6urp099136+OPP9aNN96onTt3avDgwd5GDgAAkq7PicTFzkYPBAL60Y9+pB/96EduYwMAIGV8VpFwM0fC03BSBmdtAABgA8s/rSV9+ScAAAMJyz8BwIKfSzL9WmrKMlN8nnm2uXleabj8k0QCAAAbGNqwxtAGAABwjIoEAAB2eDW2kWZIJAAAsMPl0IbSdGiDRAIAABvY2dIacyQAAIBjVCQAALCBVRvWSCQA9Dvs74CkMAPu5jmkaSLB0AYAAD5iZ0sAAAYgryZbsrMlAAADEftIWGJoAwAAOEZFAgAAG1i1YY1EAgAAu9J0eMINEgkgzb3x/ld96ed/jD7sSz8AUguJBAAANjC0YY1EAgAAO1i1YYlEAgAAWwJnm5vn0w/LPwEAgGMkEgAA2GF60NgiGwCAAcqjORJskQ2gX2FZJoBEIpEAAMAOjhG3RCIBAIANXp3+mW6YbAkAAByjIgEAgB1sSGWJRAIAADuYI2GJoQ0AAOAYFQkAAGwImJ81N8+nIxIJIM09+m+3+NLPD8ft8KUfIGmYI2GJRAIAADuYI2GJORIAAMAxKhIAANjB0IYlKhIAANjB6Z+WqEgAAOAjTv8EAGAgYmjDEokE8DljH/yJL/3cdftrvvTzmaCPfQFpjFUblpgjAQAAHK
MiAQCADexsac3zikQsFtOqVatUVFSkIUOG6Ctf+Yoefvhhmel6EDsAYGDwaNVGuvG8IrFu3Tpt2LBBzz//vMaNG6f9+/dr4cKFysnJ0ZIlS7zuDgAAJJHnicRvf/tbfeMb39CsWbMkSaNHj9aLL76offv2ed0VAABIMs+HNqZOnaq6ujodPHhQkvTOO+/ozTffVEVFheX93d3dikajvRoAAKkm8Ll5Eo5asl8gQTyvSKxYsULRaFRjx45VMBhULBbTo48+qnnz5lneX11drYceesjrMAAA8BbLPy15nki8/PLLeuGFF7RlyxaNGzdOzc3NWrZsmcLhsCorK8+5f+XKlaqqqop/jkajKigo8DoswJb3Hrzfl36ePrDdl34k6btX7/GtLwADj+eJxAMPPKAVK1bozjvvlCRNmDBBR44cUXV1tWUiEQqFFAqFvA4DAABvsbOlJc8TiU8++UQZGb2nXgSDQRmG4XVXAAD4h0TCkueJxOzZs/Xoo4+qsLBQ48aN0+9+9zutX79ed911l9ddAQCAJPM8kXjqqae0atUq3Xvvvero6FA4HNZ3v/tdrV692uuuAADwDTtbWvM8kcjKylJNTY1qamq8/moAAJKHoQ1LHNoFAAAc49Au4HNG/+IxX/q5d+KXfOkHgIeoSFgikQAAwAbmSFhjaAMAAB+VlZWpuLhYtbW1yQ7FE1QkAACww6MtshsaGpSdne1dXElGIgEAgB3MkbBEIgEAgA3MkbDGHAkAAOAYFQkgCf7tZDjZIQDoK4Y2LJFIAABgh8uhjXRNJBjaAAAAjlGRAADADoY2LJFIAABgB4mEJYY2AACAY1QkAACwgX0krFGRAAAAjlGRAD7n/9z4f33p5+tF7/rSDwAkGokEAAB2MNnSEokEAAA2MEfCGokEAAB2pWky4AaTLQEAgGNUJAAAsIM5EpZIJAAAsIE5EtZIJIDPaTtzWbJDAIB+hUQCAAA7GNqwRCIBAIANDG1YY9UGAABwjIoEAAB2MLRhiUQCAAA7SCQsMbQBAAAcoyIBAIANTLa0RiKBfuGa1T/xpZ+8m/7Sl37uusqXbgB4KQ2HNlpbWzV//nx1dHQoMzNTq1at0m233dan7yCRAADAjjRMJDIzM1VTU6PS0lK1tbVp4sSJ+vrXv66hQ4fa/46ERggAAFLWyJEjNXLkSElSfn6+hg8frj/96U99SiSYbAkAgA09cyTctL6qr6/X7NmzFQ6HFQgEtH379nPuqa2t1ejRozV48GBNmTJF+/btc/R+jY2NisViKigo6NNzJBIAANhhetD6qKurSyUlJaqtrbX8+datW1VVVaU1a9aoqalJJSUlmjlzpjo6OuL3lJaWavz48ee0Y8eOxe/505/+pAULFmjTpk19jpGhDQAAUlRFRYUqKirO+/P169dr0aJFWrhwoSRp48aNevXVV/Xss89qxYoVkqTm5uYL9tHd3a25c+dqxYoVmjp1ap9jpCIBAIANXg1tRKPRXq27u9tRPKdPn1ZjY6PKy8vj1zIyMlReXq633nrL1neYpqnvfOc7+uu//mvNnz/fURxUJNA/XBf1pZv3D+T70o9u9qcbAB7yaNXGF+cgrFmzRg8++GCfv+748eOKxWLKy8vrdT0vL0/vvfeere/413/9V23dulXXXnttfP7FL37xC02YMMF2HCQSAAD4qLW1VdnZ2fHPoVAoabHceOONMgzD1XeQSAAAYIdHFYns7OxeiYRTw4cPVzAYVHt7e6/r7e3tys/3qbrKHAkAAOwJeNC8NGjQIE2cOFF1dXXxa4ZhqK6uTjfccIPHvZ1fQhKJo0eP6tvf/raGDRumIUOGaMKECdq/f38iugIAIG2dPHlSzc3N8ZUXLS0tam5u1gcffCBJqqqq0jPPPKPnn39ef/jDH3TPPfeoq6srvorDD54Pbfz5z3/WtGnTNGPGDP3617/WX/zFX+jQoUO6/PLLve4KAAD/eDS0UVZWpmAwqEgkokgkcsFH9u
/frxkzZsQ/V1VVSZIqKyu1efNm3XHHHfroo4+0evVqtbW1qbS0VDt37jxnAmYieZ5IrFu3TgUFBXruuefi14qKirzuBgAAX3l1+mdDQ4PtORLTp0+XaV6408WLF2vx4sXOA3PJ80TilVde0cyZM3Xbbbdpz549GjVqlO69914tWrTI8v7u7u5ea2ijUX+W+aF/2f+Xz/rSz9CRR3zpB0A/lIaHdnnB8zkSf/zjH7VhwwaNGTNGv/nNb3TPPfdoyZIlev755y3vr66uVk5OTrz1dY9vAACQPJ4nEoZh6Prrr9fatWt13XXX6e6779aiRYu0ceNGy/tXrlypzs7OeGttbfU6JAAAvOHjORv9hedDGyNHjlRxcXGva9dcc43+8R//0fL+UCiU1M04AACww6s5EunG84rEtGnTdODAgV7XDh48qCuvvNLrrgAA6HfKyspUXFx83hM9+xvPKxL333+/pk6dqrVr1+r222/Xvn37tGnTJkdHkwIAkDI8mmzZl1Ub/YHnFYmysjJt27ZNL774osaPH6+HH35YNTU1mjdvntddAQDgG69O/0w3CTlr45ZbbtEtt9ySiK8GAAAphEO74Nj0/7XOt77e/9/3+tPPAl+6AdAfsY+EJRIJAABsYNWGNU7/BADAR6zaAABgIGLVhiUSCQAA7GCOhCUSCQAAbGCOhDXmSAAAAMeoSMCxzi8P8q2v9xfc71tfAGCJoQ1LJBIAANgQME0FTOfZgJtnUxlDGwAAwDESCQAA7DA9aOwjAQDAwOTVqo1020eCigQAAHCMigQAAHawasMSiQQca1jt3/jetT6t/vx/P2GZKQBrbEhljaENAADgGBUJAADsYGjDEokEAAA2MLRhjUQCAAA7qEhYYo4EAAA+YkMqAAAGKC+GJ9JtQyoSCQAA7DDNz5qb59MQiUQa+p8Zt/nSzwcPRXzpR5IOsL8DAKQkEgkAAGxg1YY1EgkAAOxg1YYlVm0AAADHqEgAAGBDwPisuXk+HZFIAABgB0MblhjaAAAAjlGRSEMZ14/zpZ+il4/70o8kaZV/XQGAFVZtWKMiAQCAHT0bUrlpbJENAMDA5FVFIt22yKYiAQAAHKMiAQCAHazasEQiAQCADUy2tMbQBgAAcIyKBAAAdnCMuCUSiTT06aWDfOnnku4zvvQDAKmAoQ1rDG0AAADHqEgAAGAHqzYskUgAAGADQxvWEj608dhjjykQCGjZsmWJ7goAAPgsoRWJhoYGPf3007r22msT2Q0AAIlnmJ81N8+noYRVJE6ePKl58+bpmWee0eWXX56obgAA8IfpQUtDCatIRCIRzZo1S+Xl5XrkkUcS1Q0s7Hppsy/9ZOQf9KUfAEgFAZfzHAJn/7esrEzBYFCRSESRSMSr8JImIYnESy+9pKamJjU0NFz03u7ubnV3d8c/R6PRRIQEAEBK4PTPi2htbdXSpUv1wgsvaPDgwRe9v7q6Wjk5OfFWUFDgdUgAALjXs7Olm5aGPE8kGhsb1dHRoeuvv16ZmZnKzMzUnj179OSTTyozM1OxWKzX/StXrlRnZ2e8tba2eh0SAACu9Sz/dNPSkedDGzfffLN+//vf97q2cOFCjR07VsuXL1cwGOz1s1AopFAo5HUYAADAB54nEllZWRo/fnyva0OHDtWwYcPOuQ4AQL/BzpaW2NkSAAAbAqapgIt5Dm6eTWW+JBKvv/66H93grJnhEl/62WX40g0AIIVRkQAAwA7jbHPzfBoikQAAwAaGNqwl/NAuAACQvqhIAABgB6s2LJFIAABgh9vdKdN0aINEAgAAG9zuTpmuO1syRwIAADhGRQIAADsY2rBEIgEAgA0B47Pm5vl0xNAGAABwjIoEAAB2MLRhiUQCAAA72EfCEkMbAAD4qKysTMXFxaqtrU12KJ6gIgEAgA1enbXR0NCg7OxsDyNLLhIJnxhtV/nW1y7joG99AcCAwRwJSwxtAAAAx6hIAABghynJzV4Q6VmQIJEAAMAOr+
ZIpBsSCQAA7DBdznNIzzyCORIAAMA5KhIAANjBqg1LJBIAANhhSAq4fD4NkUj4JCOfvR0AAOmHRAIAABtYtWGNRAIAADuYI2GJVRsAAMAxKhIAANhBRcISiQQAAHaQSFhiaAMAADhGRQIAADvYR8ISiQQAADaw/NMaiQQAAHYwR8IScyQAAIBjVCQAALDDMKWAi6qCkZ4VCRIJAADsYGjDEkMbAADAMSoSAADY4rIiofSsSJBIAABgB0MblhjaAAAAjlGRAADADsN0NzzBqg0AAAYw0/isuXk+DTG0AQDAAPXxxx9r0qRJKi0t1fjx4/XMM8/0+Ts8TySqq6tVVlamrKwsjRgxQnPnztWBAwe87gYAAH/1TLZ001JMVlaW6uvr1dzcrL1792rt2rX6z//8zz59h+eJxJ49exSJRPT2229r165dOnPmjL72ta+pq6vL664AAPCPYbpvKSYYDOpLX/qSJKm7u1umacrsY8Lj+RyJnTt39vq8efNmjRgxQo2Njfqrv/orr7tzzWi7ypd+MvIP+tIPACBBkrD8s76+Xk888YQaGxv1H//xH9q2bZvmzp3b657a2lo98cQTamtrU0lJiZ566ilNnjzZdh8ff/yxbrrpJh06dEhPPPGEhg8f3qcYEz5HorOzU5KUm5tr+fPu7m5Fo9FeDQAASF1dXSopKVFtba3lz7du3aqqqiqtWbNGTU1NKikp0cyZM9XR0RG/p2f+wxfbsWPHJEmXXXaZ3nnnHbW0tGjLli1qb2/vU4wJXbVhGIaWLVumadOmafz48Zb3VFdX66GHHkpkGAAAuGe63FTq7KNf/AtzKBRSKBSyfKSiokIVFRXn/cr169dr0aJFWrhwoSRp48aNevXVV/Xss89qxYoVkqTm5mZb4eXl5amkpERvvPGGbr31VrtvldiKRCQS0bvvvquXXnrpvPesXLlSnZ2d8dba2prIkAAAcMajyZYFBQXKycmJt+rqakfhnD59Wo2NjSovL49fy8jIUHl5ud566y1b39He3q4TJ05IZ0cQ6uvrdfXVV/cpjoRVJBYvXqwdO3aovr5eV1xxxXnvu1AmBgBAumltbVV2dnb8s9PfgcePH1csFlNeXl6v63l5eXrvvfdsfceRI0d09913xydZ3nfffZowYUKf4vA8kegJZNu2bXr99ddVVFTkdRcAAPjPMCS52FTK+OzZ7OzsXolEMk2ePNn20Mf5eJ5IRCIRbdmyRb/61a+UlZWltrY2SVJOTo6GDBnidXcAAPgjxQ7tGj58uILB4DmTI9vb25Wfn+9pXxfieSKxYcMGSdL06dN7XX/uuef0ne98x+vuXGNZJgCgPxo0aJAmTpyourq6+JJQwzBUV1enxYsX+xZHQoY2AABIOx5VJMrKyhQMBhWJRBSJRC74yMmTJ3X48OH455aWFjU3Nys3N1eFhYWqqqpSZWWlJk2apMmTJ6umpkZdXV3xVRx+4NAuAADs8Oj0z4aGBttzJPbv368ZM2bEP1dVVUmSKisrtXnzZt1xxx366KOPtHr1arW1tam0tFQ7d+48ZwJmIpFIAACQoqZPn37RSv/ixYt9Hcr4IhIJAABsME1DpoujwN08m8pIJAAAsMN0efBWms4hJJEAAMAO0+UciTRNJBJ+aBcAAPhvZWVlKi4uPu9BXP0NFQkAAOwwDCngYp7D2TkSfVm10R+QSAAAYAdDG5YY2gAAAI5RkQAAwAbTMGS6GNpg+ScAAAMZQxuWGNoAAACOkUgAAGCHYbpvLP/0j9F+nYxPggnvh2PEAQC2mKYkN8s/+35oV39ARQIAADiWshUJAABSiWmYMgPOJ0xe7BTP/opEAgAAO0zD5dAGyz8BABiwqEhYY44EAABwLOUqEj0ZW/SkPyWgjC9FfekHAOC9aPSzP8P9+Nv+p2a3q+GJT3VGOrv8MxgMKhKJKBKJeBhhcgTMFKu1fPjhhyooKEh2GACAfqS1tVVXXHFFQr771KlTKioqUltbm+
vvys/PV0tLiwYPHuxJbKkg5RIJwzB07NgxZWVlKRAI+NJnNBpVQUGBWltb02ptb490fz/xjmkh3d9PvGNCmKapEydOKBwOKyMjcaP1p06d0unTp11/z6BBg9IqiVAqDm1kZGQkLKu8mOzs7LT9j1sD4P3EO6aFdH8/8Y6ey8nJSXgfgwcPTrsEwCtMtgQAAI6RSAAAAMdIJCSFQiGtWbNGoVAo2aEkRLq/n3jHtJDu7yfeEWkq5SZbAgCA/oOKBAAAcIxEAgAAOEYiAQAAHCORAAAAjg3YRKK6ulplZWXKysrSiBEjNHfuXB04cCDZYSXUY489pkAgoGXLliU7FE8dPXpU3/72tzVs2DANGTJEEyZM0P79+5MdlidisZhWrVqloqIiDRkyRF/5ylf08MMP9+tTBOvr6zV79myFw2EFAgFt3769189N09Tq1as1cuRIDRkyROXl5Tp06FDS4nXiQu945swZLV++XBMmTNDQoUMVDoe1YMECHTt2LKkx99XF/j1+3ve+9z0FAgHV1NT4GiP8MWATiT179igSiejtt9/Wrl27dObMGX3ta19TV1dXskNLiIaGBj399NO69tprkx2Kp/785z9r2rRpuuSSS/TrX/9a//7v/64f//jHuvzyy5MdmifWrVunDRs26Gc/+5n+8Ic/aN26dXr88cf11FNPJTs0x7q6ulRSUqLa2lrLnz/++ON68skntXHjRu3du1dDhw7VzJkzderUKd9jdepC7/jJJ5+oqalJq1atUlNTk375y1/qwIEDmjNnTlJidepi/x57bNu2TW+//bbC4bBvscFnJkzTNM2Ojg5Tkrlnz55kh+K5EydOmGPGjDF37dpl3nTTTebSpUuTHZJnli9fbt54443JDiNhZs2aZd511129rn3zm980582bl7SYvCTJ3LZtW/yzYRhmfn6++cQTT8Svffzxx2YoFDJffPHFJEXpzhff0cq+fftMSeaRI0d8i8tL53vHDz/80Bw1apT57rvvmldeeaX5k5/8JCnxIbEGbEXiizo7OyVJubm5yQ7Fc5FIRLNmzVJ5eXmyQ/HcK6+8okmTJum2227TiBEjdN111+mZZ55JdliemTp1qurq6nTw4EFJ0jvvvKM333xTFRUVyQ4tIVpaWtTW1tbr/6s5OTmaMmWK3nrrraTGlkidnZ0KBAK67LLLkh2KZwzD0Pz58/XAAw9o3LhxyQ4HCZRyh3Ylg2EYWrZsmaZNm6bx48cnOxxPvfTSS2pqalJDQ0OyQ0mIP/7xj9qwYYOqqqr0t3/7t2poaNCSJUs0aNAgVVZWJjs811asWKFoNKqxY8cqGAwqFovp0Ucf1bx585IdWkL0HNOcl5fX63peXp4nRzinolOnTmn58uX61re+lVYHea1bt06ZmZlasmRJskNBgpFInP0b+7vvvqs333wz2aF4qrW1VUuXLtWuXbvS9tQ6wzA0adIkrV27VpJ03XXX6d1339XGjRvTIpF4+eWX9cILL2jLli0aN26cmpubtWzZMoXD4bR4v4HuzJkzuv3222WapjZs2JDscDzT2Nion/70p2pqalIgEEh2OEiwAT+0sXjxYu3YsUO7d+9O2vHlidLY2KiOjg5df/31yszMVGZmpvbs2aMnn3xSmZmZisViyQ7RtZEjR6q4uLjXtWuuuUYffPBB0mLy0gMPPKAVK1bozjvv1IQJEzR//nzdf//9qq6uTnZoCZGfny9Jam9v73W9vb09/rN00ZNEHDlyRLt27UqrasQbb7yhjo4OFRYWxv/sOXLkiL7//e9r9OjRyQ4PHhuwFQnTNHXfffdp27Ztev3111VUVJTskDx388036/e//32vawsXLtTYsWO1fPlyBYPBpMXmlWnTpp2zbPfgwYO68sorkxaTlz755BNlZPTO94PBoAzDSFpMiVRUVKT8/HzV1dWptLRUkhSNRrV3717dc889yQ7PMz1JxKFDh7R7924NGzYs2SF5av78+efMyZo5c6
bmz5+vhQsXJi0uJMaATSQikYi2bNmiX/3qV8rKyoqPv+bk5GjIkCHJDs8TWVlZ58z5GDp0qIYNG5Y2c0Huv/9+TZ06VWvXrtXtt9+uffv2adOmTdq0aVOyQ/PE7Nmz9eijj6qwsFDjxo3T7373O61fv1533XVXskNz7OTJkzp8+HD8c0tLi5qbm5Wbm6vCwkItW7ZMjzzyiMaMGaOioiKtWrVK4XBYc+fOTWrcfXGhdxw5cqRuvfVWNTU1aceOHYrFYvE/f3JzczVo0KAkRm7fxf49fjE5uuSSS5Sfn6+rr746CdEioZK9bCRZJFm25557LtmhJVS6Lf80TdP8p3/6J3P8+PFmKBQyx44da27atCnZIXkmGo2aS5cuNQsLC83BgwebX/7yl80f/vCHZnd3d7JDc2z37t2W/+1VVlaa5tkloKtWrTLz8vLMUChk3nzzzeaBAweSHXafXOgdW1pazvvnz+7du5Mdum0X+/f4RSz/TF8cIw4AABwb8JMtAQCAcyQSAADAMRIJAADgGIkEAABwjEQCAAA4RiIBAAAcI5EAAACOkUgAAADHSCQAAIBjJBIAAMAxEgkAAOAYiQQAAHDs/wM2kAk69+6pegAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mkbatch(BSZ))\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    output = model(batch_src, batch_padding_mask)\n",
"x = batch_labels.detach().to(torch.uint8)\n",
"y = output.detach()\n",
"# Weight each sample by 1 / (count of its label) so every label carries the same total\n",
"# mass. Vectorized to avoid a per-sample `.item()` (GPU sync) inside a Python loop;\n",
"# `.long()` is required because indexing with a uint8 tensor would act as a mask.\n",
"cnts = torch.bincount(x)\n",
"weights = (1.0 / cnts.double())[x.long()].cpu().numpy()\n",
"fig, ax = plt.subplots()\n",
"h = ax.hist2d(x.cpu().numpy().flatten(), y.to(torch.float16).cpu().numpy().flatten(), weights=weights, bins=[15,50], norm=mpl.colors.LogNorm())\n",
"ax.set(title='Model output vs. true label (label-count normalized)', xlabel='True label', ylabel='Model output')\n",
"fig.colorbar(h[3], ax=ax)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/sipb/.venv/lib64/python3.12/site-packages/torch/nn/modules/transformer.py:871: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)\n",
" return torch._transformer_encoder_layer_fwd(\n",
"/home/sipb/.venv/lib64/python3.12/site-packages/torch/nn/modules/transformer.py:871: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)\n",
" return torch._transformer_encoder_layer_fwd(\n",
"/home/sipb/.venv/lib64/python3.12/site-packages/torch/_inductor/compile_fx.py:167: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"0.0005554668023250997"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def evaluate(batch_size=None):\n",
"    \"\"\"Return the `criterion` loss of `model` on one freshly generated mkbatch batch.\n",
"\n",
"    batch_size defaults to BSZ (looked up at call time). Labels are cast to float32,\n",
"    matching how the training loop prepares them. Leaves the model in eval mode.\n",
"    \"\"\"\n",
"    if batch_size is None:\n",
"        batch_size = BSZ\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mkbatch(batch_size))\n",
"        output = model(batch_src, batch_padding_mask)\n",
"        loss = criterion(output.squeeze(1), batch_labels.to(torch.float32))\n",
"    return loss.item()\n",
"\n",
"evaluate()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LC6Xv3YfC0Rm"
},
"source": [
"# Step 5: Fine Tune"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_state": "idle",
"metadata": {},
"outputs": [],
"source": [
"N_TUNE_EPOCHS = 100\n",
"TUNE_LR = 1e-5\n",
"TUNE_WD = 0 # 1e-5\n",
"\n",
"tune_criterion = nn.MSELoss()\n",
"tune_optimizer = torch.optim.Adam(model.parameters(), lr=TUNE_LR, weight_decay=TUNE_WD)\n",
"\n",
"tune_train_err = []\n",
"\n",
"# clear loss file\n",
"open('tune_loss', 'w').close()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/100 \t Train Err: 0.4553\n",
"Epoch 1/100 \t Train Err: 0.2545\n",
"Epoch 2/100 \t Train Err: 0.2368\n",
"Epoch 3/100 \t Train Err: 0.2138\n",
"Epoch 4/100 \t Train Err: 0.1898\n",
"Epoch 5/100 \t Train Err: 0.1931\n",
"Epoch 6/100 \t Train Err: 0.1948\n",
"Epoch 7/100 \t Train Err: 0.1847\n",
"Epoch 8/100 \t Train Err: 0.1638\n",
"Epoch 9/100 \t Train Err: 0.1490\n",
"Epoch 10/100 \t Train Err: 0.1457\n",
"Epoch 11/100 \t Train Err: 0.1509\n",
"Epoch 12/100 \t Train Err: 0.1537\n",
"Epoch 13/100 \t Train Err: 0.1472\n",
"Epoch 14/100 \t Train Err: 0.1369\n",
"Epoch 15/100 \t Train Err: 0.1304\n",
"Epoch 16/100 \t Train Err: 0.1257\n",
"Epoch 17/100 \t Train Err: 0.1201\n",
"Epoch 18/100 \t Train Err: 0.1160\n",
"Epoch 19/100 \t Train Err: 0.1064\n",
"Epoch 20/100 \t Train Err: 0.1026\n",
"Epoch 21/100 \t Train Err: 0.0964\n",
"Epoch 22/100 \t Train Err: 0.1165\n",
"Epoch 23/100 \t Train Err: 0.0990\n",
"Epoch 24/100 \t Train Err: 0.0691\n",
"Epoch 25/100 \t Train Err: 0.0633\n",
"Epoch 26/100 \t Train Err: 0.0552\n",
"Epoch 27/100 \t Train Err: 0.0505\n",
"Epoch 28/100 \t Train Err: 0.0429\n",
"Epoch 29/100 \t Train Err: 0.0376\n",
"Epoch 30/100 \t Train Err: 0.0309\n",
"Epoch 31/100 \t Train Err: 0.0252\n",
"Epoch 32/100 \t Train Err: 0.0233\n",
"Epoch 33/100 \t Train Err: 0.0219\n",
"Epoch 34/100 \t Train Err: 0.0196\n",
"Epoch 35/100 \t Train Err: 0.0186\n",
"Epoch 36/100 \t Train Err: 0.0181\n",
"Epoch 37/100 \t Train Err: 0.0189\n",
"Epoch 38/100 \t Train Err: 0.0174\n",
"Epoch 39/100 \t Train Err: 0.0141\n",
"Epoch 40/100 \t Train Err: 0.0138\n",
"Epoch 41/100 \t Train Err: 0.0136\n",
"Epoch 42/100 \t Train Err: 0.0124\n",
"Epoch 43/100 \t Train Err: 0.0121\n",
"Epoch 44/100 \t Train Err: 0.0109\n",
"Epoch 45/100 \t Train Err: 0.0107\n",
"Epoch 46/100 \t Train Err: 0.0105\n",
"Epoch 47/100 \t Train Err: 0.0103\n",
"Epoch 48/100 \t Train Err: 0.0088\n",
"Epoch 49/100 \t Train Err: 0.0089\n",
"Epoch 50/100 \t Train Err: 0.0092\n",
"Epoch 51/100 \t Train Err: 0.0078\n",
"Epoch 52/100 \t Train Err: 0.0075\n",
"Epoch 53/100 \t Train Err: 0.0068\n",
"Epoch 54/100 \t Train Err: 0.0066\n",
"Epoch 55/100 \t Train Err: 0.0065\n",
"Epoch 56/100 \t Train Err: 0.0055\n",
"Epoch 57/100 \t Train Err: 0.0056\n",
"Epoch 58/100 \t Train Err: 0.0052\n",
"Epoch 59/100 \t Train Err: 0.0055\n",
"Epoch 60/100 \t Train Err: 0.0048\n",
"Epoch 61/100 \t Train Err: 0.0050\n",
"Epoch 62/100 \t Train Err: 0.0049\n",
"Epoch 63/100 \t Train Err: 0.0051\n",
"Epoch 64/100 \t Train Err: 0.0047\n",
"Epoch 65/100 \t Train Err: 0.0045\n",
"Epoch 66/100 \t Train Err: 0.0041\n",
"Epoch 67/100 \t Train Err: 0.0038\n",
"Epoch 68/100 \t Train Err: 0.0036\n",
"Epoch 69/100 \t Train Err: 0.0035\n",
"Epoch 70/100 \t Train Err: 0.0035\n",
"Epoch 71/100 \t Train Err: 0.0032\n",
"Epoch 72/100 \t Train Err: 0.0032\n",
"Epoch 73/100 \t Train Err: 0.0030\n",
"Epoch 74/100 \t Train Err: 0.0031\n",
"Epoch 75/100 \t Train Err: 0.0026\n",
"Epoch 76/100 \t Train Err: 0.0030\n",
"Epoch 77/100 \t Train Err: 0.0024\n",
"Epoch 78/100 \t Train Err: 0.0026\n",
"Epoch 79/100 \t Train Err: 0.0028\n",
"Epoch 80/100 \t Train Err: 0.0026\n",
"Epoch 81/100 \t Train Err: 0.0023\n",
"Epoch 82/100 \t Train Err: 0.0022\n",
"Epoch 83/100 \t Train Err: 0.0022\n",
"Epoch 84/100 \t Train Err: 0.0021\n",
"Epoch 85/100 \t Train Err: 0.0023\n",
"Epoch 86/100 \t Train Err: 0.0023\n",
"Epoch 87/100 \t Train Err: 0.0020\n",
"Epoch 88/100 \t Train Err: 0.0020\n",
"Epoch 89/100 \t Train Err: 0.0021\n",
"Epoch 90/100 \t Train Err: 0.0020\n",
"Epoch 91/100 \t Train Err: 0.0019\n",
"Epoch 92/100 \t Train Err: 0.0017\n",
"Epoch 93/100 \t Train Err: 0.0019\n",
"Epoch 94/100 \t Train Err: 0.0019\n",
"Epoch 95/100 \t Train Err: 0.0015\n",
"Epoch 96/100 \t Train Err: 0.0016\n",
"Epoch 97/100 \t Train Err: 0.0017\n",
"Epoch 98/100 \t Train Err: 0.0015\n",
"Epoch 99/100 \t Train Err: 0.0015\n"
]
}
],
"source": [
"for epoch in range(N_TUNE_EPOCHS):\n",
" model.train()\n",
" train_loss = 0\n",
" batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ))\n",
" optimizer.zero_grad()\n",
" output = model(batch_src, batch_padding_mask)\n",
" loss = criterion(output.squeeze(1), batch_labels)\n",
" train_loss = loss.item()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" tune_train_err.append(train_loss)\n",
" with open('tune_loss', 'a') as f:\n",
" f.write(f\"{train_loss}\\n\")\n",
" print(f\"Epoch {epoch}/{N_TUNE_EPOCHS} \\t Train Err: {train_loss:.4f}\")\n",
"\n",
" if epoch % 10 == 9:\n",
" torch.save(model.state_dict(), f\"tune_model_weights_{epoch + 1}.pth\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHgCAYAAABZ+0ykAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABCC0lEQVR4nO3deXRU9f3/8dckYZJAFpaQQCCsoih7g0BABTTIJlaLgpRKxBVQxPLTKlpB7ZeCuBQFBaXiziJWqVLBYsQFyhoMoCKuLIoJIJCELYHM5/fHZSaMAUxgZm5y83ycc8/M3Ll35j23al79bNdljDECAABwiDC7CwAAAAgkwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0AAHAUwg0ABJHL5dIdd9xhdxlAlUK4ASqxl156SS6XSy6XS8uXLy/1vjFGKSkpcrlcuuKKK/zeO3DggCZMmKDWrVurRo0aqlOnjtq3b68xY8Zo586dvuMeeugh33ecbMvJyQnJbz2V09U2YsQIW2sDYI8IuwsAcPaioqI0Z84cXXTRRX77P/74Y/3444+KjIz023/06FFdcskl+uqrr5SRkaHRo0frwIED+uKLLzRnzhxdffXVSk5O9jtnxowZiomJKfXdNWvWDNKvKrtevXpp2LBhpfafe+65ttQDwF6EG8AB+vXrpwULFujpp59WRETJv9Zz5sxRamqq9uzZ43f8woUL9dlnn+n111/XH//4R7/3jhw5oqKiolLfcc011yghISGIv+LMnXvuufrTn/5kdxkAKgi6pQAHGDJkiH755RctXbrUt6+oqEhvvvlmqfAiSd99950kqVu3bqXei4qKUlxcXEDqat26tXr27Flqv8fjUYMGDXTNNdf49s2bN0+pqamKjY1VXFyc2rRpo6eeeiogdUhSjx491Lp1a2VlZalr166Kjo5W06ZNNXPmzFLH7tq1SzfddJOSkpIUFRWldu3a6eWXXz7p73jqqafUpk0bRUVFqW7duurTp4/WrVtX6tiFCxeqdevWioyMVKtWrbRkyRK/9wsKCnTXXXepSZMmioyMVGJionr16qX169cH7BoAVQXhBnCAJk2aKC0tTXPnzvXtW7x4sfLy8nTdddeVOr5x48aSpFdeeUXGmDJ9x969e7Vnzx6/bf/+/ac9Z/Dgwfrkk09KjctZvny5du7c6att6dKlGjJkiGrVqqVHH31UkydPVo8ePbRixYoy1XbkyJFSte3Zs6dUC9S+ffvUr18/paamasqUKWrYsKFGjhyp2bNn+445fPiwevTooVdffVVDhw7VY489pvj4eN1www2lwtZNN92ku+66SykpKXr00Ud13333KSoqSqtWrSr1e0eNGqXrrrtOU6ZM0ZEjRzRw4ED98ssvvmNGjBihGTNmaODAgXr22Wd19913Kzo6Wps3by7TNQBwAgOg0nrxxReNJLN27Vozffp0Exsbaw4dOmSMMebaa681PXv2NMYY07hxY9O/f3/feYcOHTLnnXeekWQaN25sbrjhBvPCCy+Y3NzcUt8xYcIEI+mk23nnnXfa+rZs2WIkmWnTpvntHzVqlImJifHVOmbMGBMXF2eOHTtW7mtwqtokmblz5/qO6969u5FknnjiCd++wsJC0759e5OYmGiKioqMMcZMnTrVSDKvvfaa77iioiKTlpZmYmJiTH5+vjHGmA8//NBIMnfeeWepmjwej199brfbfPvtt759GzZsKHVd4uPjze23317u3w+gNFpuAIcYNGiQDh8+rEWLFqmgoECLFi06aZeUJEVHR2v16tW65557pOOzrm666SbVr19fo0ePVmFhYalz/vWvf2np0qV+24svvnjams4991y1b99e8+fP9+0rLi7Wm2++qQEDBig6Olo6Pij54MGDft1q5fH73/++VG1Lly4t1SUWERGh2267zffa7Xbrtttu065du5SVlSVJeu+991
SvXj0NGTLEd1y1atV055136sCBA/r4449918PlcmnChAml6nG5XH6v09PT1bx5c9/rtm3bKi4uTt9//71vX82aNbV69Wq/mWoAzgwDigGHqFu3rtLT0zVnzhwdOnRIxcXFfmNafi0+Pl5TpkzRlClTtG3bNmVmZurxxx/X9OnTFR8fr//7v//zO/6SSy45owHFgwcP1v3336+ffvpJDRo00EcffaRdu3Zp8ODBvmNGjRqlN954Q3379lWDBg10+eWXa9CgQerTp0+ZvqNhw4ZKT0//zeOSk5NVo0YNv33eGVVbt25Vly5dtG3bNrVo0UJhYf7/3+/888+XJG3btk06Pm4pOTlZtWvX/s3vbdSoUal9tWrV0r59+3yvp0yZooyMDKWkpCg1NVX9+vXTsGHD1KxZs9/8fAD+aLkBHOSPf/yjFi9erJkzZ6pv375lnqbduHFj3XjjjVqxYoVq1qyp119/PWA1DR48WMYYLViwQJL0xhtvKD4+3i+4JCYmKjs7W++8846uvPJKLVu2TH379lVGRkbA6rBTeHj4SfefON5p0KBB+v777zVt2jQlJyfrscceU6tWrbR48eIQVgo4A+EGcJCrr75aYWFhWrVq1Sm7pE6nVq1aat68uX7++eeA1dS0aVN16tRJ8+fP17Fjx/TWW2/pqquuKrX2jtvt1oABA/Tss8/qu+++02233aZXXnlF3377bcBq2blzpw4ePOi37+uvv5aOD8rW8aD3zTffyOPx+B331Vdf+d6XpObNm2vnzp3au3dvwOqrX7++Ro0apYULF+qHH35QnTp1NHHixIB9PlBVEG4AB4mJidGMGTP00EMPacCAAac8bsOGDaXWvtHxLpcvv/xS5513XkDrGjx4sFatWqXZs2drz549fl1SkvxmDUlSWFiY2rZtK0knHf9zpo4dO6bnnnvO97qoqEjPPfec6tatq9TUVOn4mkE5OTl+44SOHTumadOmKSYmRt27d5ckDRw4UMYYPfzww6W+p6wz0LyKi4uVl5fnty8xMVHJyckB/f1AVcGYG8BhytKVs3TpUk2YMEFXXnmlunTpopiYGH3//feaPXu2CgsL9dBDD5U658033zzpCsW9evVSUlLSab9v0KBBuvvuu3X33Xerdu3apcbH3Hzzzdq7d68uvfRSNWzYUNu2bdO0adPUvn1731iX0/n666/12muvldqflJSkXr16+V4nJyfr0Ucf1datW3Xuuedq/vz5ys7O1vPPP69q1apJkm699VY999xzuuGGG5SVlaUmTZrozTff1IoVKzR16lTFxsZKknr27Knrr79eTz/9tL755hv16dNHHo9Hn376qXr27Fmu+0kVFBSoYcOGuuaaa9SuXTvFxMTogw8+0Nq1a/XEE0+U+XMAHGf3dC0AZ+7EqeCn8+up4N9//70ZP3686dKli0lMTDQRERGmbt26pn///ubDDz/0O/d0U8ElmWXLlpWp1m7duhlJ5uabby713ptvvmkuv/xyk5iYaNxut2nUqJG57bbbzM8///ybn3u62rp37+47rnv37qZVq1Zm3bp1Ji0tzURFRZnGjRub6dOnl/rM3NxcM3z4cJOQkGDcbrdp06aNefHFF0sdd+zYMfPYY4+Zli1bGrfbberWrWv69u1rsrKy/Oo72RTvxo0bm4yMDGOOT0m/5557TLt27UxsbKypUaOGadeunXn22Wd/8/cDKM1lytt+CgCVUI8ePbRnzx59/vnndpcCIMgYcwMAAByFcAMAAByFcAMAAByFMTcAAMBRaLkBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOQrgBAACOEm
F3AaHm8Xi0c+dOxcbGyuVy2V0OAAAoA2OMCgoKlJycrLCw07fNVLlws3PnTqWkpNhdBgAAOAM7duxQw4YNT3tMlQs3sbGx0vGLExcXZ3c5AACgDPLz85WSkuL7O346VS7ceLui4uLiCDcAAFQyZRlSwoBiAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKIQbAADgKFXuxpnBUlQk5eZKHo/UuLHd1QAAUHXRchMga9ZIjRpJvXrZXQkAAFUb4SZAoqKsxyNH7K4EAICqjXATIN5wU1hodyUAAFRthJsAoeUGAICKgXATIJGR1iPhBgAAexFuAsTbclNUZM2YAgAA9iDcBIg33IhxNwAA2IpwEyAnhhu6pgAAsA/hJkAiIqSw41eTlhsAAOxDuAkQl4sZUwAAVASEmwBixhQAAPYj3AQQLTcAANiPcBNAhBsAAOxHuAkgwg0AAPYj3AQQ95cCAMB+hJsAouUGAAD7EW4CiNlSAADYj3ATQLTcAABgP8JNABFuAACwH+EmgAg3AADYj3ATQIQbAADsR7gJIKaCAwBgP8JNANFyAwCA/Qg3AcRUcAAA7Ee4CSBabgAAsB/hJoAINwAA2I9wE0CEGwAA7Ee4CSBmSwEAYD/CTQDRcgMAgP0INwHEbCkAAOxHuAkgWm4AALAf4SaACDcAANiPcBNAhBsAAOxHuAkgZksBAGA/wk0A0XIDAID9CDcBxGwpAADsR7gJIFpuAACwH+EmgAg3AADYr0KEm2eeeUZNmjRRVFSUOnfurDVr1pTpvHnz5snlcumqq64Keo1l4Q03xcXSsWN2VwMAQNVke7iZP3++xo4dqwkTJmj9+vVq166devfurV27dp32vK1bt+ruu+/WxRdfHLJaf4s33IjWGwAAbGN7uHnyySd1yy23aPjw4brgggs0c+ZMVa9eXbNnzz7lOcXFxRo6dKgefvhhNWvWLKT1no53QLGYDg4AgG1sDTdFRUXKyspSenp6SUFhYUpPT9fKlStPed4jjzyixMRE3XTTTb/5HYWFhcrPz/fbgiU8XIqIsJ7TcgMAgD1sDTd79uxRcXGxkpKS/PYnJSUpJyfnpOcsX75cL7zwgmbNmlWm75g0aZLi4+N9W0pKSkBqPxUGFQMAYC/bu6XKo6CgQNdff71mzZqlhISEMp0zbtw45eXl+bYdO3YEtUbCDQAA9oqw88sTEhIUHh6u3Nxcv/25ubmqV69eqeO/++47bd26VQMGDPDt83g8kqSIiAht2bJFzZs39zsnMjJSkScOhgkywg0AAPayteXG7XYrNTVVmZmZvn0ej0eZmZlKS0srdXzLli21adMmZWdn+7Yrr7xSPXv2VHZ2dtC7nMqCcAMAgL1sbbmRpLFjxyojI0MdO3ZUp06dNHXqVB08eFDDhw+XJA0bNkwNGjTQpEmTFBUVpdatW/udX7NmTUkqtd8u3DwTAAB72R5uBg8erN27d2v8+PHKyclR+/bttWTJEt8g4+3btyssrPIMDaLlBgAAe7mMMcbuIkIpPz9f8fHxysvLU1xcXMA//6KLpBUrpH/9S/rDHwL+8QAAVEnl+ftdeZpEKglabgAAsBfhJsAINwAA2ItwE2CEGwAA7EW4CTBmSwEAYC/CTYDRcgMAgL0INwHmXQyZcAMAgD0INwFGyw0AAPYi3AQY4QYAAHsRbgKMcAMAgL0INwFGuAEAwF6EmwBjKjgAAPYi3AQYs6UAALAX4SbA6JYCAMBehJsAI9wAAGAvwk2AEW4AALAX4SbACDcAANiLcBNgzJYCAMBehJsAY7YUAAD2ItwEGN1SAADYi3ATYIQbAADsRbgJMMINAAD2ItwE2Inhxhi7qwEAoOoh3ASYN9xI0tGjdlYCAEDVRLgJsBPDDV1TAACEHuEmwNzukueEGwAAQo9wE2AuF2vdAABgJ8JNEDBjCgAA+xBugoBwAw
CAfQg3QUC4AQDAPoSbIODmmQAA2IdwEwQMKAYAwD6EmyCgWwoAAPsQboKAcAMAgH0IN0FAuAEAwD6EmyAg3AAAYB/CTRAwWwoAAPsQboKA2VIAANiHcBMEdEsBAGAfwk0QEG4AALAP4SYICDcAANiHcBMEhBsAAOxDuAkCZksBAGAfwk0Q0HIDAIB9CDdBwFRwAADsQ7gJAlpuAACwD+EmCAg3AADYh3ATBIQbAADsQ7gJAsINAAD2IdwEAVPBAQCwD+EmCJgtBQCAfQg3QUC3FAAA9iHcBAHhBgAA+xBugoBwAwCAfQg3QUC4AQDAPoSbIPCGm6NHJY/H7moAAKhaCDdB4J0tJaaDAwAQcoSbIPC23IiuKQAAQo5wEwQREVLY8StLuAEAILQIN0HgcjGoGAAAuxBugoRwAwCAPQg3QUK4AQDAHoSbIPHOmGK2FAAAoUW4CRJabgAAsEeFCDfPPPOMmjRpoqioKHXu3Flr1qw55bFvvfWWOnbsqJo1a6pGjRpq3769Xn311ZDWWxaEGwAA7GF7uJk/f77Gjh2rCRMmaP369WrXrp169+6tXbt2nfT42rVr64EHHtDKlSu1ceNGDR8+XMOHD9f7778f8tpPh3ADAIA9bA83Tz75pG655RYNHz5cF1xwgWbOnKnq1atr9uzZJz2+R48euvrqq3X++eerefPmGjNmjNq2bavly5eHvPbTIdwAAGAPW8NNUVGRsrKylJ6eXlJQWJjS09O1cuXK3zzfGKPMzExt2bJFl1xyyUmPKSwsVH5+vt8WCoQbAADsYWu42bNnj4qLi5WUlOS3PykpSTk5Oac8Ly8vTzExMXK73erfv7+mTZumXr16nfTYSZMmKT4+3relpKQE/HecjDfcMFsKAIDQsr1b6kzExsYqOztba9eu1cSJEzV27Fh99NFHJz123LhxysvL8207duwISY3eqeC03AAAEFoRdn55QkKCwsPDlZub67c/NzdX9erVO+V5YWFhOueccyRJ7du31+bNmzVp0iT16NGj1LGRkZGKPPE23SFCtxQAAPawteXG7XYrNTVVmZmZvn0ej0eZmZlKS0sr8+d4PB4VVrD+H8INAAD2sLXlRpLGjh2rjIwMdezYUZ06ddLUqVN18OBBDR8+XJI0bNgwNWjQQJMmTZKOj6Hp2LGjmjdvrsLCQr333nt69dVXNWPGDJt/iT/CDQAA9rA93AwePFi7d+/W+PHjlZOTo/bt22vJkiW+Qcbbt29XWFhJA9PBgwc1atQo/fjjj4qOjlbLli312muvafDgwTb+itIINwAA2MNljDF2FxFK+fn5io+PV15enuLi4oL2PQ8/LD30kDRypPTss0H7GgAAqoTy/P2ulLOlKgNmSwEAYA/CTZDQLQUAgD0IN0FCuAEAwB6EmyAh3AAAYA/CTZAQbgAAsAfhJkgINwAA2INwEyTe2VIVbOFkAAAcj3ATJLTcAABgD8JNkBBuAACwB+EmSAg3AADYg3ATJIQbAADsQbgJEsINAAD2INwECbOlAACwB+EmSLwtN8XF0rFjdlcDAEDVQbgJEm+4EV1TAACEFOEmSLzdUiLcAAAQUoSbIAkPl6pVs54TbgAACB3CTRAxYwoAgNAj3ASRN9wwYwoAgNAh3ASRd9wNLTcAAIQO4SaI6JYCACD0CDdBRLgBACD0CDdBRLgBACD0CDdBRLgBACD0CDdBRLgBACD0CDdBxM0zAQAIPcJNENWoYT3u2mV3JQAAVB2EmyC65BLr8a237K4EAICqg3ATRNdea91jau1a6dtv7a4GAICqgXATRImJ0mWXWc/nzrW7GgAAqgbCTZD98Y/W49y5kjF2VwMAgPMRboLs6qutWVObN0sbN9pdDQAAzke4CbK4OKl/f+v5nDl2VwMAgPMRbkJgyBDrcd48yeOxuxoAAJyNcBMC/ftLsbHS9u3SypV2VwMAgLMRbkIgOtoaeyO6pgAACLpyhZspU6bo8OHDvtcrVqxQ4Q
n3FigoKNCoUaMCW6FDeLumFiyQjh2zuxoAAJzLZUzZJyiHh4fr559/VmJioiQpLi5O2dnZatasmSQpNzdXycnJKi4uDl7FZyk/P1/x8fHKy8tTXFxcyL736FGpQQNp925pyRKpd++QfTUAAJVeef5+l6vl5tc5qBy5qMqrVs1asVgs6AcAQFAx5iaEvF1Tb79tteQAAIDAI9yEUNeuUp06Un6+tGaN3dUAAOBMEeU94Z///KdiYmIkSceOHdNLL72khIQE6fiAYpxaWJh1r6k33pCWLpW6dbO7IgAAnKdcA4qbNGkil8v1m8f98MMPZ1tX0Ng1oNjrn/+UbrnFCjbLl4f86wEAqJTK8/e7XC03W7duPdvaqrz0dOtx1Sqre8qGfAUAgKMx5ibEmjSRmjeXiouljz+2uxoAAJynXOFm5cqVWrRokd++V155RU2bNlViYqJuvfVWv0X9cHK9elmPH3xgdyUAADhPucLNI488oi+++ML3etOmTbrpppuUnp6u++67T++++64mTZoUjDodxds1tXSp3ZUAAOA85Qo32dnZuuyyy3yv582bp86dO2vWrFkaO3asnn76ab3xxhvBqNNRLr1UcrmkzZuln36yuxoAAJylXOFm3759SkpK8r3++OOP1bdvX9/rCy+8UDt27AhshQ5Uq5bUsaP1nK4pAAACq1zhJikpyTfNu6ioSOvXr1eXLl187xcUFKhatWqBr9KBGHcDAEBwlCvc9OvXT/fdd58+/fRTjRs3TtWrV9fFF1/se3/jxo1q3rx5MOp0HO+4mw8+kLhFFwAAgVOucPO3v/1NERER6t69u2bNmqXnn39ebrfb9/7s2bN1+eWXB6NOx+naVYqOlnJypBPGaAMAgLNUrkX8EhIS9MknnygvL08xMTEKDw/3e3/BggWKjY0NdI2OFBkpXXKJ9P77VutN69Z2VwQAgDOUK9zceOONZTpu9uzZZ1pPlZKeboWbpUulu+6yuxoAAJyhXOHmpZdeUuPGjdWhQweV45ZUOAXvoOKPP5aKiqQTevgAAMAZKle4GTlypObOnasffvhBw4cP15/+9CfVrl07eNU5XJs2Ut260u7d0urV0gljswEAwBkq14DiZ555Rj///LP+8pe/6N1331VKSooGDRqk999/n5acMxAWJnnXRGS1YgAAAqPcN86MjIzUkCFDtHTpUn355Zdq1aqVRo0apSZNmujAgQPBqdLBvOFm2TK7KwEAwBnO6q7gYWFhcrlcMsaouLg4cFVVIZdeaj2uWiUdPGh3NQAAVH7lDjeFhYWaO3euevXqpXPPPVebNm3S9OnTtX37dsXExASnSgdr2lRq3Fg6dkxavtzuagAAqPzKFW5GjRql+vXra/Lkybriiiu0Y8cOLViwQP369VNY2Fk1AlVZLpfUs6f1nK4pAADOnsuUYyRwWFiYGjVqpA4dOsjlcp3yuLfeeitQ9QVcfn6+4uPjlZeXp7i4OLvLkSS9+qo0bJh04YXSmjV2VwMAQMVTnr/f5ZoKPmzYsNOGGpwZb8tNVpaUlyfFx9tdEQAAlVe5F/ELhmeeeUaPPfaYcnJy1K5dO02bNk2dOnU66bGzZs3SK6+8os8//1ySlJqaqr///e+nPL4yaNhQatFC+uYb6ZNPpAED7K4IAIDKy/aBMvPnz9fYsWM1YcIErV+/Xu3atVPv3r21a9eukx7/0UcfaciQIVq2bJlWrlyplJQUXX755frpp59CXnsgeWdNffih3ZUAAFC5lWvMTTB07txZF154oaZPny5J8ng8SklJ0ejRo3Xffff95vnFxcWqVauWpk+frmHDhv3m8RVxzI0kzZ8vXXed1K6dlJ1tdzUAAFQs5fn7bWvLTVFRkbKyspSenl5SUFiY0tPTtXLlyjJ9xqFDh3T06NFT3gaisLBQ+fn5fltF1KOH9bhhg7Rnj93VAABQedkabvbs2aPi4mIlJSX57U9KSlJOTk6ZPuPee+9VcnKyX0A60aRJkx
QfH+/bUlJSAlJ7oCUlSa1bW88/+sjuagAAqLxsH3NzNiZPnqx58+bp7bffVlRU1EmPGTdunPLy8nzbjh07Ql5nWbHeDQAAZ8/WcJOQkKDw8HDl5ub67c/NzVW9evVOe+7jjz+uyZMn67///a/atm17yuMiIyMVFxfnt1VU5RlU/NNP0rZtQS8JAIBKx9Zw43a7lZqaqszMTN8+j8ejzMxMpaWlnfK8KVOm6G9/+5uWLFmijh07hqja4Ove3Vqx+KuvpJ07/d8rLJQyM6V77pHatLGmj597rrRunV3VAgBQMdneLTV27FjNmjVLL7/8sjZv3qyRI0fq4MGDGj58uHR84cBx48b5jn/00Uf14IMPavbs2WrSpIlycnKUk5PjiDuS16oldehgPfeOu9m3T7r/fikhQUpPlx5/XDq+xI+KiqSRIyXuWQoAQIlyLeIXDIMHD9bu3bs1fvx45eTkqH379lqyZIlvkPH27dv97ls1Y8YMFRUV6ZprrvH7nAkTJuihhx4Kef2Bduml0vr10qJF0tat0pQp1qrFklSvntSnj7W1aSOlpVktN7NmSSNG2F05AAAVg+3r3IRaRV3nxmvxYqlfP/99rVtLEydaKxefePeLadOkO++UataUtmyREhNDXi4AACFRada5QWkXXSRFR1vPmzWTXnvNWtTvyiv9g41kdUl16CDt3y/de68t5QIAUOEQbiqY2Fjp/felefOkzZuloUOl8PCTHxsRIc2YYYWel16Sli8PdbUAAFQ8hJsK6OKLpcGDJbf7t4/t3Fm65Rbr+ciR0tGjQS8PAIAKjXDjAH//u1SnjjWL6vgtugAAqLIINw5Qp44VcCSrmwoAgKqMcOMQQ4ZI1apJ33wjff213dUAAGAfwo1DxMaW3Fn8P/+xuxoAAOxDuHGQ/v2tx0WL7K4EAAD7EG4c5IorrMdPPpHy8+2uBgAAexBuHKR5c+m886Rjx6T//tfuagAAsAfhxmG8rTeMuwEAVFWEG4fxjrt57z3J47G7GgAAQo9w4zAXXSTFxUm7dll3DAcAoKoh3DhMtWpS797Wc2ZNAQCqIsKNAzHuBgBQlRFuHKhvX+tO4evXSzt32l0NAAChRbhxoLp1rbuF6/jAYgAAqhLCjUOxWjEAoKoi3DiUd9zNBx9IR47YXQ0AAKFDuHGodu2kBg2kgwdZrRgAULUQbhzK5ZKuu856/txzdlcDAEDoEG4cbMQI63HxYun77+2uBgCA0CDcONg551gL+hkjzZxpdzUAAIQG4cbhbr/denzhBenwYburAQAg+Ag3Dtevn9SokbR3rzR/vt3VAAAQfIQbhwsPLxl78+yzdlcDAEDwEW6qgJtuktxuae1aayuP/fulKVOkyy+XPv00WBUCABA4hJsqIDFRuvZa63lZW2+2bZP+/GcpJUW6915p6VLrnlWrVgW1VAAAzhrhporwDiyeN0/65ZdTH3fggDR8uNS8uTR1qvW6VSspLc1aELBvXyk7O2RlAwBQboSbKqJLF6l9e+tWDC+8cPJjtm+XunWTXnpJKi6WLrvMWiNn0yar5aZbN6ubqlcvafPmUP8CAADKhnBTRbhcJa0348ZJt94q7dxZ8v7q1VKnTtLGjVJSkvTJJ9Z9qfr0sc6tUUP6z3+k1FRpzx4r+Hz3nW0/BwCAU3IZY4zdRYRSfn6+4uPjlZeXp7i4OLvLCanCQun666UFC6zX0dHWuJpzzpFGjrTeb9tWevdda/r4yfzyi9Sjh/T551LTptKGDVJsbEh/BgCgCirP329abqqQyEjpjTesWU9du1qL+v3979KNN1rBZsAAafnyUwcbSapTx+qiatxY+uEHaeLEUP4CAAB+G+GmCrroIivELFwotWxp7ft//096++2ytcLUqydNn249f/JJacuW4NYLAEB50C1VxRUXS7t3W4GlvK64whqH07u3NfDY5QpGhQAA0C2FcggPP7NgI1lTxd1u6f33pXfeCXRlAACcGcINztg550
h33209v+subswJAKgYCDc4K/ffLzVsKG3dat2mAQAAuxFucFZq1JCeeMJ6PnmyFXIAALAT4QZn7dprpZ49rdWPBw2S8vPtrggAUJURbnDWXC5p5kxrDZy1a6X+/a37UAEAYAfCDQLi3HOl//5Xio+31tC56iqrJQcAgFAj3CBgfvc7a72bGjWs+1Jde61UVGR3VQCAqoZwg4BKS5MWLZKioqzHoUNpwQEAhBbhBgHXo4d1K4dq1aQ337Ru9/DDD3ZXBQCoKgg3CIo+faxbM9SpI2VlWV1W775rd1UAgKqAcIOg6dVL+uwzqUsXaf9+6corpfvuk44ds7syAICTEW4QVCkp0scfS2PGWK8ffdSaKn7ggN2VAQCcinCDoHO7rZtsvvGGVL26NWX80kulPXvsrgwA4ESEG4TMtddKH35YstjfRRdJ27bZXRUAwGkINwipzp2tRf5SUqQtW6Ru3aQvvrC7KgCAkxBuEHItW0r/+590wQXSTz9JF18sLVhgd1UAAKcg3MAWDRtKn34qde0q7dtn3XDz2mulXbvsrgwAUNkRbmCb2rWlZcuk8eOliAhrwb9WrQLXirNrl5STE5jPAgBUHoQb2Mrtlh5+WFqzRmrTxppBNWiQNHr02X3ukSPWwoFt2lhr7AAAqg7CDSqEDh2kdeukBx+UwsKk6dOtm3CeqQ8/tMbz7NkjvfVWICsFAFR0hBtUGG639Mgj0l13Wa9vu00qKDizz3rnnZLnc+YEpj4AQOVAuEGF88gjUpMm0o4d0l//Wv7zPR7/cLNsmfTzzwEtEQBQgRFuUOHUqCE995z1fNo0adWq8p2flWWFmZgYKTXVCjtvvBGUUgEAFRDhBhXS5ZdLGRmSMdLNN0tFRWU/99//th779JFuuMF6TtcUAFQdhBtUWE88IdWta61gPHly2c/zdkn9/vfW2jnh4dZsrG+/DVqpAIAKhHCDCqtOHenpp63n//d//uNoTuWHH6RNm6xA06+flJQkpadb782dG9x6AQAVA+EGFdrgwdKVV0pHj1otMQMHSj/+eOrjvQHo4outRQIl6Y9/tB7nzLG6uQAAzmZ7uHnmmWfUpEkTRUVFqXPnzlqzZs0pj/3iiy80cOBANWnSRC6XS1OnTg1prQg9l8tqcbn3Xqs15q23pPPPt1p0iotLH+8db/P735fsu+oqKSpK+uoracOG0NUOALCHreFm/vz5Gjt2rCZMmKD169erXbt26t27t3ad4gZDhw4dUrNmzTR58mTVq1cv5PXCHtWrW2Nu1q+XunSRDhyQxoyRevSQ8vNLjtu3T/rkE+v5lVeW7I+Lk664wnrOwGIAcD5bw82TTz6pW265RcOHD9cFF1ygmTNnqnr16po9e/ZJj7/wwgv12GOP6brrrlNkZGTI64W92raVVqyQZs60Asvy5daMKG/Aee89qzWndWupWTP/c71dU3PnWlPDAQDOZVu4KSoqUlZWltK9oz0lhYWFKT09XStXrgzY9xQWFio/P99vQ+UVFmatXLxsmVSrlrRypdS3r7WSsXe8zYmtNl59+0rx8dZ4neXLQ142ACCEbAs3e/bsUXFxsZKSkvz2JyUlKSeAt3KeNGmS4uPjfVtKSkrAPhv2+d3vpA8+kGrWlP73P6sFx3svqhPH23hFRVmDkSUW9AMAp7N9QHGwjRs3Tnl5eb5tx44ddpeEAPl1wCkokOrXlzp2PPnxffpYj+Vd8RgAULnYFm4SEhIUHh6u3Nxcv/25ubkBHSwcGRmpuLg4vw3OkZoqLV1qBRxJGjDA6ro6mQsvtB43bpQKC0NXIwAgtGwLN263W6mpqcrMzPTt83g8yszMVFpaml1loRLq2FH66CNp5EjpgQdOfVzjxtbCgEePMiUcAJzM1m6psWPHatasWXr55Ze1efNmjRw5UgcPHtTw4cMlScOGDdO4ceN8xxcVFSk7O1vZ2dkqKirSTz/9pOzsbH3LuvpVXrt20rPPSo0anfoYl6uk9Wbdup
CVBgAIsQg7v3zw4MHavXu3xo8fr5ycHLVv315LlizxDTLevn27wk7oY9i5c6c6dOjge/3444/r8ccfV/fu3fXRRx/Z8htQuXTsKC1ZIq1da3clAIBgcRlTtRakz8/PV3x8vPLy8hh/UwW98441m6pVK+nzz+2uBgBQVuX5++342VLAibwzqTZvtlY6BgA4D+EGVUpysrV5PNJnn9ldDQAgGAg3qHK8g4oZdwMAzkS4QZXDjCkAcDbCDaoc77gbWm4AwJkIN6hyvOHm22+lffvsrgYAEGiEG1Q5depIzZpZz7Oy7K4GABBohBtUSXRNAYBzEW5QJTFjCgCci3CDKokZUwDgXIQbVEm/+511I80dO6TcXLurAQAEEuEGVVJsrNSypfWcrikAcBbCDaosuqYAwJkIN6iymDEFAM5EuEGVdeKMKWPsrgYAECiEG1RZ7dpJERHS7t3Se+/ZXQ0AIFAIN6iyoqOlIUOs51dfLf3rX3ZXBAAIBMINqrR//lMaNEg6etR6fPlluysCAJwtwg2qNLdbmjNHuvFGyeORbrhBmj7d7qoAAGeDcIMqLzxcmjVLGjPGej16tDR1qt1VAQDOFOEGkBQWJv3jH9KDD1qv/9//k5Yts7sqAMCZINwAx7lc0sMPSxkZVhfVdddJO3faXRUAoLwIN8AJXC7p2Weltm2lXbusgHP0qN1VAQDKg3AD/Er16tKbb1r3n/r0U+n+++2uCABQHoQb4CRatJBeesl6/vjj0ttv210RAKCsCDfAKfzhD9LYsdbzG26wWnEAABUf4QY4jcmTpYsvlvLzpZ49pSef5D5UAFDREW6A06hWTVq82LpNQ3GxNUV88GCpoMDuygAAp0K4AX5DjRrS669L06ZZN9pcsEDq1En66iu7KwMAnAzhBigDl0u64w7pk0+k5GQr2PTsKf34o92VAQB+jXADlENamvTZZ1Lr1lJOjnU38cOH7a4KAHAiwg1QTomJ0jvvSHXqSOvWSTffzCBjAKhICDfAGWja1FroLyLCuqv4lCl2VwQA8CLcAGeoRw/pqaes5+PGSf/5j90VAQBEuAHOzsiR0m23Wd1SQ4ZY3VQAAHsRboCz4HJJTz8tXXKJtfZN9+7Sv/9td1UAULURboCz5HZL774r9e4tHTpkzaDydlcBAEKPcAMEQFycFXBuvdXqorrrLunOO61VjQEAoUW4AQKkWjVp5sySmVPTpknXXCMdPWp3ZQBQtRBugAByuaR77rFu0RAZKS1cKN10k+Tx2F0ZAFQdhBsgCK65RvrXv6TwcOnVV6W772ahPwAIFcINECT9+0svvmg9/8c/pMmT7a4IAKoGwg0QRNdfLz35pPX8/vulWbPsrggAnI9wAwTZn/9srWAsSSNGSK+/bndFAOBshBsgBCZOtKaJezxWa84zz9hdEQA4F+EGCAGXS5oxQxo92hpYfMcd0t/+xiBjAAgGwg0QImFh1srFEyZYr8ePt7qsmCYOAIFFuAFCyOWSHnqo5PYMTz1l3a5hyRKpqMju6gDAGQg3gA3uvFN65RVrHZx33pH69pXq1pWGDrXWx9m3z+4KAaDychlTtXr98/PzFR8fr7y8PMXFxdldDqq4tWuttXDeflvKyfF/74ILpIsukrp1k3r0kBo1sqtKALBfef5+E26ACsDjkVavlt56S1q0SPrqq9LHdO8u3XCDtfpxTIwdVQKAfQg3p0G4QWWwe7f0v/9Jy5db2+rVJTOrqle3Ak7PnlLLltZWs6bdFQNAcBFuToNwg8poxw7rHlUvvyx9/XXp9+vVk845R0pIkGrXlurUsR737ZO2bi3ZCgqkTp2sbq4ePaQuXaSoKDt+EQCUD+HmNAg3qMyMkVatkt54Q9q0Sdq8Wdq588w/LzLSugfWI49IrVoFslIACCzCzWkQbuA0+fnWGJ2tW6VffpH27i15jI+XmjQp2dxuacUK6aOPpGXLSgYxu1zWyskPP2wdBwAVDeHmNAg3gMUYaeNGa6
Xkf/3L2letmnTzzVKbNlKNGiffYmKk2FiJf30AhBLh5jQIN0Bpa9ZYN/f88MOyn5OUJLVrV7Kdc44VeGJjrQAUEyNFRASzagBVCeHmNAg3wKl98IE0Z46UlycdPGhtBw6UPPe+Li4u2+edc47UsWPJ1qqVFXoiI62uMAAoK8LNaRBugLNjjBVyvvhC2rChZNu+3dpfUCAdPXr6z3C5rCnt0dElW1RUyeOvt6ZNpYsvtmZ3Va8eql8KoCIpz99vGo0BlIvLZbW+dO5sbSdTVGRNQ9+4UVq3rmTbvt163xuQDh4s33dHREipqdb3ut1WiPJuhw9bg6sLCqzHI0esVZ67drW29u2tcwA4Hy03AELGG0IOHSrZjhyx9nkfDx+WCgut10eOWMds3Ch9+qn0449n/t1RUSULHsbFWTPJataUkpOtW1t4t/r1rYHVACoWWm4AVEjVqlnbmfz/CmOkbdukTz6xwo7LVfJ51apZ4SUurmRQc3i49Nln1krPK1da0+Ozs8v2XWFh1rgg71atmtVqFB5uPVavLp17rtUydP751mPt2tZYJO8mlQSo8PDS3+ENerGxjD8CAo2WGwCOZ4z0zTfSd99ZXVZ5eda2f7/VGrR9u7Xt2PHb44XORFycFX7Cwkq6zgoLrffcbqu1qEEDa4uPt97zbkVF1nknBjnJv/Xr0CErcJ24OnV8fMmYJW9Ii4uz3vdutWvTVYfKo9K13DzzzDN67LHHlJOTo3bt2mnatGnq1KnTKY9fsGCBHnzwQW3dulUtWrTQo48+qn79+oW0ZgCVh8tltbSce+7pjysutkLPieGisFA6dsx/y8uzFk788ktrlegvv7TGD4WHW1tYmBWoDh+2Pjc/39pOpqjIapHati3wv7ss3O6S9Ytq1LBCkjcQeR+9rVZhYSW/8devo6Otc73biTPivI+Rkf6DyN1u63p6x00dO2bVkJAg1a1rPUZGWtcoL88ax7V/vxXw6tWzjjndcgPGWJ9bVGRtbrf1vSdrSYOz2B5u5s+fr7Fjx2rmzJnq3Lmzpk6dqt69e2vLli1KTEwsdfz//vc/DRkyRJMmTdIVV1yhOXPm6KqrrtL69evVunVrW34DAGcID7daM8piwIDfPqaoyPpjvG+ftWK0MVY3lHcRxKgo6yapP/1Ush044N8l5nZbd433hoBjx6zPOTFIREVZQWrv3pIVqvfv9w9oR45YAeGXX6xt3z7rc7x/+PftO+vLFxSRkSWtXL8WFmYFnLp1reviHbPlHbdVVHTy87zdmN5Qd+Lmdpd0P54Y4lwuawsLK725XP7X2RuIvf0i3sfoaP8FMd3u0l2Sxlj/e3s3HR9I793Cw0u6SSMiSlrzTvxnJjLSOvfXLYBSye9wufy7cuPirN9vjBXyPR7/blbva4/H//vcbuv1r39HZKQVQO1ie7dU586ddeGFF2r69OmSJI/Ho5SUFI0ePVr33XdfqeMHDx6sgwcPatGiRb59Xbp0Ufv27TVz5szf/D66pQCgpJXqxHWMDhwoGdx94qDuX/+BO9nmHfzt3byB5MQ/8IWFJeHj0CHrD673D7X38cABac8ea/v1ekregeBHj0q7dpX88UfFk5ZmjXcLpErTLVVUVKSsrCyNGzfOty8sLEzp6elauXLlSc9ZuXKlxo4d67evd+/eWrhw4UmPLywsVOEJsT//VG3DAFCFeFupytpSFWoej9X6lJ9f0rJwYhdUcbEVgHJyrNavatX810v6detXtWpWmDpxVt6Joa6gwHr0to6dGNy8rSm/fvSGPWP8u/G8g9C9XC7r2CNHrO87dMh6/HXLkjGlW4W8LSneLtGjR63X3jqPHbM+59ctR78eFH/i2CpjrO3IkZIu0/x86/e7XCfvdvS+DguzvvvEFqGTjVOLjAzGPxVlZ2u42bNnj4qLi5WUlOS3PykpSV999dVJz8nJyTnp8TneOwD+yqRJk/Twww8HsGoAQLCFhZ0+fIWHW7cA+dWfg9OqVs
3qDoLzhdldQLCNGzdOeXl5vm3Hjh12lwQAAILI1pabhIQEhYeHKzc3129/bm6u6p1iJFK9evXKdXxkZKQi7W4fAwAAIWNry43b7VZqaqoyMzN9+zwejzIzM5WWlnbSc9LS0vyOl6SlS5ee8ngAAFC12D4VfOzYscrIyFDHjh3VqVMnTZ06VQcPHtTw4cMlScOGDVODBg00adIkSdKYMWPUvXt3PfHEE+rfv7/mzZundevW6fnnn7f5lwAAgIrA9nAzePBg7d69W+PHj1dOTo7at2+vJUuW+AYNb9++XWFhJQ1MXbt21Zw5c/TXv/5V999/v1q0aKGFCxeyxg0AAJAqwjo3ocY6NwAAVD7l+fvt+NlSAACgaiHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAAR7F9heJQ865ZmJ+fb3cpAACgjLx/t8uy9nCVCzcFBQWSpJSUFLtLAQAA5VRQUKD4+PjTHlPlbr/g8Xi0c+dOxcbGyuVyBfSz8/PzlZKSoh07dnBrhyDjWocO1zp0uNahw7UOnUBda2OMCgoKlJyc7HfPyZOpci03YWFhatiwYVC/Iy4ujn9ZQoRrHTpc69DhWocO1zp0AnGtf6vFxosBxQAAwFEINwAAwFEINwEUGRmpCRMmKDIy0u5SHI9rHTpc69DhWocO1zp07LjWVW5AMQAAcDZabgAAgKMQbgAAgKMQbgAAgKMQbgAAgKMQbgLkmWeeUZMmTRQVFaXOnTtrzZo1dpdU6U2aNEkXXnihYmNjlZiYqKuuukpbtmzxO+bIkSO6/fbbVadOHcXExGjgwIHKzc21rWanmDx5slwul+666y7fPq514Pz000/605/+pDp16ig6Olpt2rTRunXrfO8bYzR+/HjVr19f0dHRSk9P1zfffGNrzZVRcXGxHnzwQTVt2lTR0dFq3ry5/va3v/ndm4hrfeY++eQTDRgwQMnJyXK5XFq4cKHf+2W5tnv37tXQoUMVFxenmjVr6qabbtKBAwfOvjiDszZv3jzjdrvN7NmzzRdffGFuueUWU7NmTZObm2t3aZVa7969zYsvvmg+//xzk52dbfr162caNWpkDhw44DtmxIgRJiUlxWRmZpp169aZLl26mK5du9pad2W3Zs0a06RJE9O2bVszZswY336udWDs3bvXNG7c2Nxwww1m9erV5vvvvzfvv/+++fbbb33HTJ482cTHx5uFCxeaDRs2mCuvvNI0bdrUHD582NbaK5uJEyeaOnXqmEWLFpkffvjBLFiwwMTExJinnnrKdwzX+sy999575oEHHjBvvfWWkWTefvttv/fLcm379Olj2rVrZ1atWmU+/fRTc84555ghQ4acdW2EmwDo1KmTuf32232vi4uLTXJyspk0aZKtdTnNrl27jCTz8ccfG2OM2b9/v6lWrZpZsGCB75jNmzcbSWblypU2Vlp5FRQUmBYtWpilS5ea7t27+8IN1zpw7r33XnPRRRed8n2Px2Pq1atnHnvsMd++/fv3m8jISDN37twQVekM/fv3NzfeeKPfvj/84Q9m6NChxnCtA+rX4aYs1/bLL780kszatWt9xyxevNi4XC7z008/nVU9dEudpaKiImVlZSk9Pd23LywsTOnp6Vq5cqWttTlNXl6eJKl27dqSpKysLB09etTv2rds2VKNGjXi2p+h22+/Xf379/e7puJaB9Q777yjjh076tprr1ViYqI6dOigWbNm+d7/4YcflJOT43et4+Pj1blzZ651OXXt2lWZmZn6+uuvJUkbNmzQ8uXL1bdvX4lrHVRlubYrV65UzZo11bFjR98x6enpCgsL0+rVq8/q+6vcjTMDbc+ePSouLlZSUpLf/qSkJH311Ve21eU0Ho9Hd911l7p166bWrVtLknJycuR2u1WzZk2/Y5OSkpSTk2NTpZXXvHnztH79eq1du7bUe1zrwPn+++81Y8YMjR07Vvfff7/Wrl2rO++8U263Wx
kZGb7rebL/pnCty+e+++5Tfn6+WrZsqfDwcBUXF2vixIkaOnSodPyfa3Gtg6Is1zYnJ0eJiYl+70dERKh27dpnff0JN6gUbr/9dn3++edavny53aU40o4dOzRmzBgtXbpUUVFRdpfjaB6PRx07dtTf//53SVKHDh30+eefa+bMmcrIyLC7PEd544039Prrr2vOnDlq1aqVsrOzdddddyk5OZlr7XB0S52lhIQEhYeHl5o1kpubq3r16tlWl5PccccdWrRokZYtW6aGDRv69terV09FRUXav3+/3/Fc+/LLysrSrl279Lvf/U4RERGKiIjQxx9/rKeffloRERFKSkriWgdI/fr1dcEFF/jtO//887V9+3bp+D/XOn5tT8S1Lr977rlH9913n6677jq1adNG119/vf785z9r0qRJEtc6qMpybevVq6ddu3b5vX/s2DHt3bv3rK8/4eYsud1upaamKjMz07fP4/EoMzNTaWlpttZW2RljdMcdd+jtt9/Whx9+qKZNm/q9n5qaqmrVqvld+y1btmj79u1c+3K67LLLtGnTJmVnZ/u2jh07aujQob7nXOvA6NatW6klDb7++ms1btxYktS0aVPVq1fP71rn5+dr9erVXOtyOnTokMLC/P/MhYeHy+PxSFzroCrLtU1LS9P+/fuVlZXlO+bDDz+Ux+NR586dz66AsxqODGOOTwWPjIw0L730kvnyyy/NrbfeamrWrGlycnLsLq1SGzlypImPjzcfffSR+fnnn33boUOHfMeMGDHCNGrUyHz44Ydm3bp1Ji0tzaSlpdlat1OcOFvKcK0DZs2aNSYiIsJMnDjRfPPNN+b111831atXN6+99prvmMmTJ5uaNWuaf//732bjxo3m97//PdOTz0BGRoZp0KCBbyr4W2+9ZRISEsxf/vIX3zFc6zNXUFBgPvvsM/PZZ58ZSebJJ580n332mdm2bZsxZby2ffr0MR06dDCrV682y5cvNy1atGAqeEUybdo006hRI+N2u02nTp3MqlWr7C6p0pN00u3FF1/0HXP48GEzatQoU6tWLVO9enVz9dVXm59//tnWup3i1+GGax047777rmndurWJjIw0LVu2NM8//7zf+x6Pxzz44IMmKSnJREZGmssuu8xs2bLFtnorq/z8fDNmzBjTqFEjExUVZZo1a2YeeOABU1hY6DuGa33mli1bdtL/RmdkZBhTxmv7yy+/mCFDhpiYmBgTFxdnhg8fbgoKCs66Npc5calGAACASo4xNwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwAAwFEINwCqJJfLpYULF9pdBoAgINwACLkbbrhBLper1NanTx+7SwPgABF2FwCgaurTp49efPFFv32RkZG21QPAOWi5AWCLyMhI1atXz2+rVauWdLzLaMaMGerbt6+io6PVrFkzvfnmm37nb9q0SZdeeqmio6NVp04d3XrrrTpw4IDfMbNnz1arVq0UGRmp+vXr64477vB7f8+ePbr66qtVvXp1tWjRQu+8847vvX379mno0KGqW7euoqOj1aJFi1JhDEDFRLgBUCE9+OCDGjhwoDZs2KChQ4fquuuu0+bNmyVJBw8eVO/evVWrVi2tXbtWCxYs0AcffOAXXmbMmKHbb79dt956qzZt2qR33nlH55xzjt93PPzwwxo0aJA2btyofv36aejQodq7d6/v+7/88kstXrxYmzdv1owZM5SQkBDiqwDgjJz1fcUBoJwyMjJMeHi4qVGjht82ceJEY4wxksyIESP8zuncubMZOXKkMcaY559/3tSqVcscOHDA9/5//vMfExYWZnJycowxxiQnJ5sHHnjglDVIMn/96199rw8cOGAkmcWLFxtjjBkwYIAZPnx4gH85gFBgzA0AW/Ts2VMzZszw21e7dm3f87S0NL/30tLSlJ2dLUnavHmz2rVrpxo1avje79atmzwej7Zs2SKXy6WdO3fqss
suO20Nbdu29T2vUaOG4uLitGvXLknSyJEjNXDgQK1fv16XX365rrrqKnXt2vUsfzWAUCDcALBFjRo1SnUTBUp0dHSZjqtWrZrfa5fLJY/HI0nq27evtm3bpvfee09Lly7VZZddpttvv12PP/54UGoGEDiMuQFQIa1atarU6/PPP1+SdP7552vDhg06ePCg7/0VK1YoLCxM5513nmJjY9WkSRNlZmaeVQ1169ZVRkaGXnvtNU2dOlXPP//8WX0egNCg5QaALQoLC5WTk+O3LyIiwjdod8GCBerYsaMuuugivf7661qzZo1eeOEFSdLQoUM1YcIEZWRk6KGHHtLu3bs1evRoXX/99UpKSpIkPfTQQxoxYoQSExPVt29fFRQUaMWKFRo9enSZ6hs/frxSU1PVqlUrFRYWatGiRb5wBaBiI9wAsMWSJUtUv359v33nnXeevvrqK+n4TKZ58+Zp1KhRql+/vubOnasLLrhAklS9enW9//77GjNmjC688EJVr15dAwcO1JNPPun7rIyMDB05ckT/+Mc/dPfddyshIUHXXHNNmetzu90aN26ctm7dqujoaF188cWaN29ewH4/gOBxGWvWAABUGC6XS2+//bauuuoqu0sBUAkx5gYAADgK4QYAADgKY24AVDj0lgM4G7TcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAARyHcAAAAR/n/66M0wFecaLoAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.suptitle('MSE vs Epochs')\n",
"plt.plot(tune_train_err, label='Train', color='blue')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('MSE')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.001733972690999508"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def tune_evaluate():\n",
" model.eval()\n",
" test_loss = 0\n",
" with torch.no_grad():\n",
" batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ))\n",
" output = model(batch_src, batch_padding_mask)\n",
" loss = criterion(output.squeeze(1), batch_labels)\n",
" return loss.item()\n",
"\n",
"tune_evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[6.0290e+03, 1.8597e+04, 8.4000e+01, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 1.0000e+00]]),\n",
" array([1. , 1.1 , 1.2 , 1.3 , 1.4 , 1.5 , 1.6 , 1.699, 1.8 ,\n",
" 1.9 , 2. , 2.1 , 2.2 , 2.3 , 2.398, 2.5 , 2.6 , 2.7 ,\n",
" 2.8 , 2.898, 3. , 3.1 , 3.2 , 3.299, 3.398, 3.5 , 3.6 ,\n",
" 3.7 , 3.799, 3.898, 4. , 4.1 , 4.2 , 4.297, 4.4 , 4.5 ,\n",
" 4.6 , 4.7 , 4.797, 4.9 , 5. , 5.098, 5.2 , 5.3 , 5.4 ,\n",
" 5.5 , 5.598, 5.7 , 5.797, 5.9 , 6. ], dtype=float16),\n",
" array([0.9014, 0.998 , 1.094 , 1.19 , 1.287 , 1.384 , 1.48 , 1.576 ,\n",
" 1.673 , 1.77 , 1.865 , 1.962 , 2.059 , 2.156 , 2.25 , 2.348 ,\n",
" 2.445 , 2.541 , 2.637 , 2.734 , 2.83 , 2.926 , 3.023 , 3.121 ,\n",
" 3.215 , 3.312 , 3.41 , 3.504 , 3.602 , 3.7 , 3.793 , 3.89 ,\n",
" 3.988 , 4.082 , 4.18 , 4.277 , 4.375 , 4.47 , 4.566 , 4.664 ,\n",
" 4.758 , 4.855 , 4.953 , 5.05 , 5.145 , 5.242 , 5.34 , 5.434 ,\n",
" 5.53 , 5.63 , 5.723 ], dtype=float16),\n",
" <matplotlib.collections.QuadMesh at 0x7f201dae35c0>)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAhsAAAGdCAYAAAC7JrHlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVOklEQVR4nO3df4yV9Z3o8c8wI0cWZ0ZhYYUyTDSuUKEQf/QSpFrrr2RiJvYf6xpaiTXZtBkVatyYaXKj5FYH/2gTTc1UvY3c3EhoaYq23ouUJgLpZqkwXm5A11asiaNg3CV1hqHpWWXO/WNxbtGic4b5zDNzzuuVnD/O+J3zfJJjnLffc57naahUKpUAAEgyregBAIDaJjYAgFRiAwBIJTYAgFRiAwBIJTYAgFRiAwBIJTYAgFRNE33A4eHhOHz4cDQ3N0dDQ8NEHx4AGINKpRLHjh2L+fPnx7Rp1e1VTHhsHD58ONra2ib6sADAOOjv748FCxZU9TsTHhvNzc0RJ4dtaWmZ6MMDAB9zc+vtn7nmw/ggfhP/e+TveDUmPDY++uikpaVFbADAJNDUcNZnLzp5J7WxfAXCF0QBgFRiAwBIJTYAgFRiAwBINeFfEAUAJpcdw1s+c83g4GC0traO6fXtbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqZqKHgAAKNYN0275zDUfVj4Y8+vb2QAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCVW8wDMGnduPK/jWrdr/7lv6bPUst2DG/5zDWDg4PR2to6pte3swEApBIbAEAqsQEApBIbAEAqsQEApBIbAECqqk59ffDBB2P9+vWn/GzRokXx2muvjfdcABDvXzyz6BEYB1VfZ2PJkiXx61//+v+/QJNLdQAAp1d1KTQ1NcX555+fMw0AUHOq/s7G66+/HvPnz48LL7wwVq9eHW+99danri+XyzE4OHjKAwCoH1XFxooVK2Ljxo3xwgsvRG9vb7z55ptx1VVXxbFjx077Oz09PdHa2jryaGtrG4+5AYApoqFSqVTG+svvv/9+tLe3xw9+8IO48847/+qacrkc5XJ55Png4GC0tbXFwMBAtLS0jPXQANSB/7LmB6Na99L/uDd9lnr30b1RxvL3+4y+3XnuuefGxRdfHIcOHTrtmlKpFKVS6UwOAwBMYWd0nY2hoaF44403Yt68eeM3EQBQU6ra2bjvvvuis7Mz2tvb4/Dhw/HAAw9EY2Nj3HbbbXkTAlC3Zr/0b0WPwDioKjbefvvtuO222+Lo0aMxZ86c+NKXvhR79uyJOXPm5E0IAExpVcXG5s2b8yYBAGqSe6MAAKnEBgCQSmwAAKnEBgCQyi1bAZi0tv1rT9EjMA7sbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJBKbAAAqZqKHgBgsrlh2i2jWrdjeEv6LFAL7GwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnc9RXgY9zNFcaXnQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXrbAB8TMeCe0a1btvbj6XPArXAzgYAkEpsAACpxAYAkEpsAACpxAYAkEpsAACpnPoK8DF/+McLih4BaoqdDQAgldgAAFKdUWxs2LAhGhoaYt26deM3EQBQU8YcG3v37o0nnngili1bNr4TAQ
A1ZUyxMTQ0FKtXr46nnnoqzjvvvPGfCgCoGWOKja6urrjpppvi+uuv/8y15XI5BgcHT3kAAPWj6lNfN2/eHC+//HLs3bt3VOt7enpi/fr1Y5kNAKgBVe1s9Pf3x9q1a+OZZ56Js88+e1S/093dHQMDAyOP/v7+sc4KMCE+/JvKqB7A6FS1s9HX1xfvvfdeXHbZZSM/O3HiROzevTt++MMfRrlcjsbGxlN+p1QqRalUGr+JAYApparYuO666+LAgQOn/OyOO+6IxYsXx/333/+J0AAAqCo2mpubY+nSpaf8bObMmTF79uxP/BwAIFxBFADIdsY3Ytu5c+f4TAIA1CQ7GwBAKreYB/iYv/2/RU8AtcXOBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQynU2AD6m5Zl/Gd3C/5k9CdQGOxsAQCqxAQCkEhsAQCqxAQCkEhsAQCqxAQCkcuorwMfsGN5S9AhQU+xsAACpxAYAkEpsAACpxAYAkEpsAACpxAYAkMqprwAfc8O0W0a1zimyMDp2NgCAVGIDAEglNgCAVGIDAEglNgCAVGIDAEglNgCAVK6zAZNEx4J7RrVu29uPpc8CMJ7sbAAAqcQGAJBKbAAAqcQGAJBKbAAAqcQGAJDKqa8wSQz/3XlFj8BJTfPnFT0C1BQ7GwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKRy6itMEscuail6BIAUdjYAgFRiAwBIJTYAgFRiAwBIJTYAgFRiAwBIJTYAgFSuswGTxNDntD9Qm/zXDQBIJTYAgFRiAwBIJTYAgFRiAwBIJTYAgFROfYVJYnD5fxQ9Aidte/uxokeAmmJnAwBIJTYAgFRiAwBIJTYAgFRVxUZvb28sW7YsWlpaoqWlJVauXBnbtm3Lmw4AmPKqio0FCxbEhg0boq+vL/bt2xfXXntt3HzzzfHKK6/kTQgATGlVnfra2dl5yvOHHnooent7Y8+ePbFkyZLxng0AqAFjvs7GiRMnYsuWLXH8+PFYuXLladeVy+Uol8sjzwcHB8d6SKhpf9/+btEjAKSo+guiBw4ciHPOOSdKpVJ861vfiq1bt8Yll1xy2vU9PT3R2to68mhrazvTmQGAKaTq2Fi0aFHs378/fvvb38a3v/3tWLNmTbz66qunXd/d3R0DAwMjj/7+/jOdGQCYQqr+GGX69Olx0UUXRUTE5ZdfHnv37o1HH300nnjiib+6vlQqRalUOvNJAYAp6YyvszE8PHzKdzIAAP5SVTsb3d3d0dHREQsXLoxjx47Fpk2bYufOnbF9+/a8CQGAKa2q2Hjvvffi9ttvjyNHjkRra2ssW7Ystm/fHjfccEPehADAlFZVbPz4xz/Om4RC3DDtllGt2zG8JX2Wenfkfy0c3cJrsicBGF/ujQIApBIbAEAqsQEApBIbAEAqsQEApBIbAECqMd/1FRhfs373YdEjAKSwswEApBIbAEAqsQEApBIbAEAqsQEApBIbAEAqsQEApHKdjTrn1vGTx4x3hooeASCFnQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSOfUVJontfeuLHgEghZ0NACCV2AAAUokNACCV2AAAUokNACCV2AAAUokNACCV62zUuY4F94xq3ba3H0ufpd55L4BaZWcDAEglNgCAVGIDAEglNgCAVGIDAEglNgCAVE59rXMn2uYUPQInOaUVqFV2NgCAVGIDAEglNgCAVGIDAEglNgCAVGIDAEjl1Nc6Nzy9segROKnj892jWrftX3vSZwEYT3Y2AIBUYgMASCU2AIBUYgMASCU2AIBUYgMASCU2AIBUrrNR5/7t0hlFj8BJrp8B1Co7GwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKRy6mude3/ph0WPAECNs7MBAKQSGwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKSq6jobPT098fOf/zxee+21mDFjRlx55ZXxyCOPxKJFi/ImJNVj1z4zyp
X/lDwJN0y7ZVTrdgxvSZ8FYDxVtbOxa9eu6Orqij179sSOHTvigw8+iBtvvDGOHz+eNyEAMKVVtbPxwgsvnPJ848aNMXfu3Ojr64urr756vGcDAGrAGV2ufGBgICIiZs2addo15XI5yuXyyPPBwcEzOSQAMMWM+Quiw8PDsW7duli1alUsXbr0tOt6enqitbV15NHW1jbWQwIAU9CYY6OrqysOHjwYmzdv/tR13d3dMTAwMPLo7+8f6yEBgCloTB+j3HXXXfH888/H7t27Y8GCBZ+6tlQqRalUGut8AMAUV1VsVCqVuPvuu2Pr1q2xc+fOuOCCC/ImY0Ks/ed/GNW6my9MH6XuOaUVqFVVxUZXV1ds2rQpnnvuuWhubo533303IiJaW1tjxowZWTMCAFNYVd/Z6O3tjYGBgbjmmmti3rx5I4+f/OQneRMCAFNa1R+jAABUw71RAIBUYgMASCU2AIBUZ3S5cqa+psOugQJALjsbAEAqsQEApBIbAEAqsQEApBIbAEAqsQEApBIbAECqwq6zcXPr7dHUcNZp/7nbbU+Mme8UPQEAtc7OBgCQSmwAAKnEBgCQSmwAAKnEBgCQSmwAAKkKO/W16e8vjKZGtzcv2jlHhoseAYAaZ2cDAEglNgCAVGIDAEglNgCAVGIDAEglNgCAVGIDAEhV2HU2yvNa4kTT2UUdnpNmvv2nokcAoMbZ2QAAUokNACCV2AAAUokNACCV2AAAUokNACBVYae+lo4MRlNjuajDc1LDf3xY9AgA1Dg7GwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKQq7NTXf1/xt9E43V1fizb8f14tegQAapydDQAgldgAAFKJDQAgldgAAFKJDQAgldgAAFKJDQAgVWHX2YiGkw8KtWN4S9EjAFDj7GwAAKnEBgCQSmwAAKnEBgCQSmwAAKnEBgCQqrBTXwfbG6LxbOe+AkCts7MBAKQSGwBAKrEBAKQSGwBAKrEBAKQSGwBAKrEBAKQq7Dobfbf/92hpbvyUFd+ZwGkAgCx2NgCAVGIDAEhVdWzs3r07Ojs7Y/78+dHQ0BDPPvtszmQAQE2oOjaOHz8ey5cvj8cffzxnIgCgplT9BdGOjo7o6OjImQYAqDnpZ6OUy+Uol8sjzwcHB7MPCQBMIumx0dPTE+vXr//Ez7f/6ez4m8bTn/ramTwXADAx0s9G6e7ujoGBgZFHf39/9iEBgEkkfWejVCpFqVTKPgwAMEm5zgYAkKrqnY2hoaE4dOjQyPM333wz9u/fH7NmzYqFCxeO93wAwBRXdWzs27cvvvKVr4w8v/feeyMiYs2aNbFx48bxnQ4AmPKqjo1rrrkmKpXKmA/40e/+aejEp65ziiwATB4f/V0eSwNM+F1fjx49GhER3/zS7z5jZeuEzAMAjN7Ro0ejtbW6v9ETHhuzZs2KiIi33nqr6mEZX4ODg9HW1hb9/f3R0tJS9Dh1zXsxeXgvJg/vxeQyMDAQCxcuHPk7Xo0Jj41p0/7zBJjW1lb/8kwSLS0t3otJwnsxeXgvJg/vxeTy0d/xqn4nZRIAgJPEBgCQasJjo1QqxQMPPOCqopOA92Ly8F5MHt6LycN7MbmcyfvRUDmT81gBAD6Dj1EAgFRiAwBIJTYAgFRiAwBINWGxsXv37ujs7Iz58+dHQ0NDPPvssxN1aP5CT09PfPGLX4zm5uaYO3dufPWrX43f/e6zLh1Plt7e3li2bNnIRYtWrlwZ27ZtK3osImLDhg3R0NAQ69atK3qUuvPggw9GQ0PDKY/FixcXPVbdeuedd+LrX/96zJ49O2bMmBFf+MIXYt++fVW9xoTFxvHjx2P58uXx+OOPT9Qh+St27doVXV1dsWfPntixY0d88MEHceONN8bx48eLHq0uLViwIDZs2BB9fX2xb9++uPbaa+Pmm2+OV155pejR6trevXvjiSeeiGXLlhU9St1asmRJHDlyZOTxm9/8puiR6tIf//jHWLVqVZx11l
mxbdu2ePXVV+P73/9+nHfeeVW9zoRdrryjoyM6Ojom6nCcxgsvvHDK840bN8bcuXOjr68vrr766sLmqlednZ2nPH/ooYeit7c39uzZE0uWLClsrno2NDQUq1evjqeeeiq+973vFT1O3Wpqaorzzz+/6DHq3iOPPBJtbW3x9NNPj/zsggsuqPp1fGejzg0MDET8xQ3yKM6JEydi8+bNcfz48Vi5cmXR49Strq6uuOmmm+L6668vepS69vrrr8f8+fPjwgsvjNWrV8dbb71V9Eh16Re/+EVcccUVccstt8TcuXPj0ksvjaeeeqrq15nwG7ExeQwPD8e6deti1apVsXTp0qLHqVsHDhyIlStXxp///Oc455xzYuvWrXHJJZcUPVZd2rx5c7z88suxd+/eokepaytWrIiNGzfGokWL4siRI7F+/fq46qqr4uDBg9Hc3Fz0eHXlD3/4Q/T29sa9994b3/3ud2Pv3r1xzz33xPTp02PNmjWjfh2xUce6urri4MGDPgst2KJFi2L//v0xMDAQP/vZz2LNmjWxa9cuwTHB+vv7Y+3atbFjx444++yzix6nrv3lR+7Lli2LFStWRHt7e/z0pz+NO++8s9DZ6s3w8HBcccUV8fDDD0dExKWXXhoHDx6MH/3oR1XFho9R6tRdd90Vzz//fLz44ouxYMGCosepa9OnT4+LLrooLr/88ujp6Ynly5fHo48+WvRYdaevry/ee++9uOyyy6KpqSmamppi165d8dhjj0VTU1OcOHGi6BHr1rnnnhsXX3xxHDp0qOhR6s68efM+8T8+n//856v+WMvORp2pVCpx9913x9atW2Pnzp1j+qIPuYaHh6NcLhc9Rt257rrr4sCBA6f87I477ojFixfH/fffH42NjYXNVu+GhobijTfeiG984xtFj1J3Vq1a9YnLI/z+97+P9vb2ql5nwmJjaGjolCp98803Y//+/TFr1qxYuHDhRI1R97q6umLTpk3x3HPPRXNzc7z77rsREdHa2hozZswoery6093dHR0dHbFw4cI4duxYbNq0KXbu3Bnbt28verS609zc/InvLs2cOTNmz57tO00T7L777ovOzs5ob2+Pw4cPxwMPPBCNjY1x2223FT1a3fnOd74TV155ZTz88MPxta99LV566aV48skn48knn6zuhSoT5MUXX6xExCcea9asmagR+M87/P7Vx9NPP130aHXpm9/8ZqW9vb0yffr0ypw5cyrXXXdd5Ve/+lXRY3HSl7/85cratWuLHqPu3HrrrZV58+ZVpk+fXvnc5z5XufXWWyuHDh0qeqy69ctf/rKydOnSSqlUqixevLjy5JNPVv0abjEPAKTyBVEAIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABSiQ0AIJXYAABS/T/VxjQRzeYx5wAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ))\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    output = model(batch_src, batch_padding_mask)\n",
"print(criterion(output.squeeze(1), batch_labels).item())\n",
"# Cast to float32 (not float16): float16 rounds the values and makes the histogram bin edges uneven.\n",
"x = batch_labels.detach().float().cpu().numpy().flatten()\n",
"y = output.detach().float().cpu().numpy().flatten()\n",
"# Rows of the 2D histogram: true label (x) vs. model output (y); counts on a log color scale.\n",
"plt.hist2d(x, y, bins=50, norm=mpl.colors.LogNorm())\n",
"plt.xlabel('label')\n",
"plt.ylabel('model output')\n",
"plt.colorbar();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JtTLXn4zC1z_"
},
"source": [
"# Step 6: Test generalization"
]
},
{
"cell_type": "code",
"execution_count": 27,
"execution_state": "idle",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.20130713284015656\n"
]
},
{
"data": {
"text/plain": [
"(array([[1.7789e+04, 6.5000e+02, 4.8000e+02, ..., 0.0000e+00, 1.0000e+00,\n",
" 1.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" ...,\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00],\n",
" [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00,\n",
" 0.0000e+00]]),\n",
" array([ 1. , 1.18 , 1.36 , 1.54 , 1.721, 1.9 , 2.08 , 2.262,\n",
" 2.441, 2.621, 2.8 , 2.98 , 3.16 , 3.34 , 3.521, 3.701,\n",
" 3.88 , 4.062, 4.242, 4.42 , 4.6 , 4.78 , 4.96 , 5.14 ,\n",
" 5.32 , 5.5 , 5.68 , 5.863, 6.043, 6.223, 6.402, 6.582,\n",
" 6.76 , 6.94 , 7.12 , 7.3 , 7.48 , 7.66 , 7.844, 8.02 ,\n",
" 8.2 , 8.38 , 8.56 , 8.74 , 8.92 , 9.1 , 9.28 , 9.46 ,\n",
" 9.64 , 9.82 , 10. ], dtype=float16),\n",
" array([ 0.879, 1.156, 1.434, 1.71 , 1.987, 2.266, 2.541, 2.818,\n",
" 3.096, 3.373, 3.65 , 3.928, 4.203, 4.48 , 4.758, 5.035,\n",
" 5.312, 5.59 , 5.867, 6.145, 6.42 , 6.7 , 6.977, 7.254,\n",
" 7.527, 7.805, 8.08 , 8.36 , 8.64 , 8.92 , 9.19 , 9.47 ,\n",
" 9.75 , 10.016, 10.3 , 10.58 , 10.86 , 11.125, 11.41 , 11.69 ,\n",
" 11.97 , 12.234, 12.516, 12.8 , 13.08 , 13.34 , 13.625, 13.91 ,\n",
" 14.17 , 14.45 , 14.734], dtype=float16),\n",
" <matplotlib.collections.QuadMesh at 0x7f201daaff50>)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAfp0lEQVR4nO3de3BU9f3/8dcmIUuMyUr4Si6SYEAKym2gWH6AbaFE/eaHFKZTVCZiCnU67UQh0lpINSqDENHWKpUJYjuAF/DSacDyG6Qp5VJ+VQjBONALF6WYcp35DeyS8GUN2fP7o2OmEYTEPZvz3s3zMXP+2LOf7Hkdg+yLz57PWZ/jOI4AAAAMSfI6AAAAwOdRUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYk+J1gM+LRCI6fvy4MjIy5PP5vI4DAAA6wHEcnTt3Tnl5eUpKin7+w1xBOX78uPLz872OAQAAvoTGxkb17ds36tcxV1AyMjIkSbfpfytFPa44dkPwlS5KBQAAriQUCik/P7/tfTxa5grKZx/rpKiHUnxXLiiZmZldlAoAAHSEW5dncJEsAAAwh4ICAADMoaAAAABzKCgAAMAccxfJAug6tydN79C42sjbMc8CAP+JGRQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5vgcx3E68wM7duzQs88+q/r6ep04cUI1NTWaNm3aZcf+8Ic/1EsvvaRf/vKXKi8v79Drh0IhBQIBBYNBvgwQAIA44fb7d6dnUJqbmzVixAgtX778iuNqamr0/vvvKy8vL5p8AACgG+r0nWSLi4tVXFx8xTHHjh3TQw89pM2bN2vy5MnR5AMAAN2Q67e6j0Qimjlzph555BENGTLkquPD4bDC4XDb41Ao5HYkAAAQZ1y/SHbp0qVKSUnRnDlzOjS+qqpKgUCgbcvPz3c7EgAAiDOuFpT6+nq98MILWr16tXw+X4d+pqKiQsFgsG1rbGx0MxIAAIhDrhaUP//5zzp9+rQKCgqUkpKilJQUHT16VD/+8Y914403XvZn/H6/MjMz220AAKB7c/UalJkzZ6qoqKjdvjvvvFMzZ87UrFmz3DwUAABIYJ0uKE1NTTp8+HDb4yNHjqihoUFZWVkqKChQ7969243v0aOHcnJyNGjQIHcSAwCAhNfpgrJnzx5NnDix7fG8efMkSaWlpVq9erW76QAAQLfU6YIyYcIEdebms//85z87ewgAANDN8V08AADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAnBSvA3yRqYH7leLrccUxtZG3uywPAADoOsygAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADDH7H1QNgRfUWZmptcxAACAB5hBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJjT6YKyY8cOTZkyRXl5efL5fFq/fn3bcy0tLZo/f76GDRum9PR05eXl6f7779fx48fdzg0AABJYpwtKc3OzRowYoeXLl1/y3Pnz57V3715VVlZq7969+t3vfqcDBw7o29/+tlt5AQBAN+BzHMf50j/s86mmpkbTpk37wjF1dXX62te+pqNHj6qgoOCqrxkKhRQIBBQMBrnVPQAAccLt9++YX4MSDAbl8/l03XXXxfpQAAAgQcT0ywIvXLig+fPna8aMGV/YpsLhsMLhcNvjUCgUy0gAACAOxGwGpaWlRXfffbccx1F1dfUXjquqqlIgEGjb8vPzYxUJAADEiZgUlM/KydGjR1VbW3vFz6
IqKioUDAbbtsbGxlhEAgAAccT1j3g+KyeHDh3S1q1b1bt37yuO9/v98vv9bscAAABxrNMFpampSYcPH257fOTIETU0NCgrK0u5ubn67ne/q71792rjxo1qbW3VyZMnJUlZWVlKTU11Nz0AAEhInV5mvG3bNk2cOPGS/aWlpXryySdVWFh42Z/bunWrJkyYcNXXZ5kxAADxx+33707PoEyYMEFX6jRR3FYFAABA4rt4AACARRQUAABgTkxv1BaNqYH7leLrccUxtZG3uywPAADoOsygAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcs8uMNwRf4Vb3AAB0U8ygAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHM6XVB27NihKVOmKC8vTz6fT+vXr2/3vOM4evzxx5Wbm6u0tDQVFRXp0KFDbmYGAAAJrtMFpbm5WSNGjNDy5csv+/wzzzyjZcuWacWKFdq1a5fS09N155136sKFC27kBQAA3UBKZ3+guLhYxcXFl33OcRw9//zzeuyxxzR16lRJ0iuvvKLs7GytX79e9957b/SJAQBAwnP1GpQjR47o5MmTKioqatsXCAQ0ZswYvffee5f9mXA4rFAo1G4DAADdm6sF5eTJk5Kk7Ozsdvuzs7Pbnvu8qqoqBQKBti0/P9/NSAAAIA55voqnoqJCwWCwbWtsbPQ6EgAA8JirBSUnJ0eSdOrUqXb7T5061fbc5/n9fmVmZrbbAABA9+ZqQSksLFROTo62bNnSti8UCmnXrl0aO3asm4cCAAAJrNOreJqamnT48OG2x0eOHFFDQ4OysrJUUFCg8vJyPfXUUxo4cKAKCwtVWVmpvLw8TZs2ze3sAAAgQXW6oOzZs0cTJ05sezxv3jxJUmlpqVavXq2f/vSnam5u1g9+8AOdPXtWt912m95991317NnT3eQAACBh+RzHcbwO8Z9CoZACgYCCwSDXowAAECfcfv/u9AxKV5kauF8pvh5XHFMbebvL8qB7uD1peofG8WcPAGLL82XGAAAAn0dBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmcB8UAAAQNbffv5lBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDmuF5TW1lZVVlaqsLBQaWlpGjBggBYtWiTHcdw+FAAASFApbr/g0qVLVV1drTVr1mjIkCHas2ePZs2apUAgoDlz5rh9OAAAkIBcLyh/+ctfNHXqVE2ePFmSdOONN2rdunXavXu324cCAAAJyvWPeMaNG6ctW7bo4MGDkqQPP/xQO3fuVHFx8WXHh8NhhUKhdhsAAOjeXJ9BWbBggUKhkAYPHqzk5GS1trZq8eLFKikpuez4qqoqLVy40O0YAAAgjrk+g/LWW2/p9ddf19q1a7V3716tWbNGP//5z7VmzZrLjq+oqFAwGGzbGhsb3Y4EAADijM9xeXlNfn6+FixYoLKysrZ9Tz31lF577TX94x//uOrPh0IhBQIBBYNBZWZmuhkNAADEiNvv367PoJw/f15JSe1fNjk5WZFIxO1DAQCABO
X6NShTpkzR4sWLVVBQoCFDhuiDDz7Qc889p9mzZ7t9KAAAkKBc/4jn3LlzqqysVE1NjU6fPq28vDzNmDFDjz/+uFJTU6/683zEAwBA/HH7/dv1ghKtz05wgqYqxdfjimNrI293WS4AAPDFzF+DAgAAEC0KCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx/UbtbllQ/AV7oMCAEA3xQwKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMSfE6QCK7PWl6h8bVRt6OeRYAAOIJMygAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwBwKCgAAMIdlxjHE8mEAQCLoyG0zLjotrh6TGRQAAGAOBQUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmMMyYwAAcEUduW1GKBRSIBBw7ZjMoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAnJgUlGPHjum+++5T7969lZaWpmHDhmnPnj2xOBQAAEhArt8H5cyZMxo/frwmTpyoTZs26frrr9ehQ4fUq1cvtw8FAB36Gnh18D4O8aC7nS+6L9cLytKlS5Wfn69Vq1a17SssLHT7MAAAIIG5/hHPO++8o9GjR2v69Onq06ePRo4cqZdfftntwwAAgATmekH5+OOPVV1drYEDB2rz5s360Y9+pDlz5mjNmjWXHR8OhxUKhdptAACge3P9I55IJKLRo0dryZIlkqSRI0dq//79WrFihUpLSy8ZX1VVpYULF7odAwAAxDHXZ1Byc3N1yy23tNt3880365NPPrns+IqKCgWDwbatsbHR7UgAACDOuD6DMn78eB04cKDdvoMHD6pfv36XHe/3++X3+92OAQAA4pjrBeXhhx/WuHHjtGTJEt19993avXu3Vq5cqZUrV7p9KADodstpU/JyvY6Abqgjy9svOi2uHtP1j3huvfVW1dTUaN26dRo6dKgWLVqk559/XiUlJW4fCgAAJCjXZ1Ak6a677tJdd90Vi5cGAADdAN/FAwAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMickqHjdMDdyvFF+PK47pbvc/QOx1t6+yL+47p0PjNv1rWcyzfFl3fvWJDo3bXJ8gX6mRke51Ald05P+1RPn/LBF05HcRCoUUCARcOyYzKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx+wy40SQCEs4OypRludaz+e2f91b6HWEqDUWX+d1hC4VGtLb6wiu6G7/r6HzmEEBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDlmlxl/evsoRXr09DpGVC4W9PE6QpdpuWO01xFcMeG/l3Zo3LZ358c8S1dozne8jhC1Xv9o9TpCl/q/v/2J1xGALsEMCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMoaAAAABz7N4HJTNZkR7JXseIyqe9Ur2O0GUu/FcPryO4Ijig+/zOJKnnKZ/XEaLmP3PR6whd6vak6R0aVxt5O+ZZgFhiBgUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5phdZpzU4ihJ8f1V8OFeZv/zus7XGt+/q8+0XON1gq6VcsHrBNFL3lrvdYQuxfJhdBfMoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx+yNOq45cUEpZtN1jNON6l+kh8/rCK44P+p/vI7QpZI+9TpB9FIG3eR1hC51e9L0Do3jfimIdzF/C3366afl8/lUXl4e60MBAIAEEdOCUldXp5deeknDhw+P5WEAAECCiVlBaWpqUklJiV5++WX16tUrVocBAAAJKGYFpaysTJMnT1ZRUdEVx4XDYYVCoXYbAADo3mJyGeobb7yhvXv3qq6u7qpjq6qqtHDhwl
jEAAAAccr1GZTGxkbNnTtXr7/+unr27HnV8RUVFQoGg21bY2Oj25EAAECccX0Gpb6+XqdPn9aoUaPa9rW2tmrHjh168cUXFQ6HlZyc3Pac3++X3++/5HXODr5GyalXLziWnb8+MZbedkSoX2Kcq9OaGOfRUWn/L+J1BAC4LNcLyqRJk7Rv3752+2bNmqXBgwdr/vz57coJAADA5bheUDIyMjR06NB2+9LT09W7d+9L9gMAAFxON7rXKQAAiBddcjP5bdu2dcVhAABAgmAGBQAAmENBAQAA5pj9vuD/6e1Tsj++l3w63WjBUkum43UEV6Q0xvfS9s5KDbV6HQGdxLcUo7tgBgUAAJhDQQEAAOZQUAAAgDkUFAAAYA4FBQAAmENBAQAA5lBQAACAOWbvg+Kk/HuLZ5FUrxN0nYvXJcb9NHwRrxN0LV9r/N+/5uKBw15H6FK3J03v0Djul4J4xwwKAAAwh4ICAADMoaAAAABzKCgAAMAcCgoAADCHggIAAMwxu5C3NVWS3+sU0bnY0+sEXWfq6L1eR3CF/4zXCbpWUkv8LzPW/xrhdYIuxfJhdBfMoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwx+x9UC5mOIr0jO97NHz6X61eR+gy6+tHdWjcCyNjHiUq1x6PeB2hS/mPB72OELXavzzmdQR8CbcnTb/qGO750r0xgwIAAMyhoAAAAHMoKAAAwBwKCgAAMIeCAgAAzKGgAAAAc8wuM470apHSkr2OEZ30i14n6DL+k2b/KHXKxZ4+ryMA3QJLiHE1zKAAAABzKCgAAMAcCgoAADCHggIAAMyhoAAAAHMoKAAAwByza0P96WElXxPfSz4fG/N/OjhyQYyTxJ4vvr94us01p7rP0nBJunjgsNcRAOCymEEBAADmUFAAAIA5rheUqqoq3XrrrcrIyFCfPn00bdo0HThwwO3DAACABOZ6Qdm+fbvKysr0/vvvq7a2Vi0tLbrjjjvU3Nzs9qEAAECCcv0i2Xfffbfd49WrV6tPnz6qr6/XN77xDbcPBwAAElDMV/EEg0FJUlZW1mWfD4fDCofDbY9DoVCsIwEAAONiepFsJBJReXm5xo8fr6FDh152TFVVlQKBQNuWn58fy0gAACAOxHQGpaysTPv379fOnTu/cExFRYXmzZvX9jgUCik/P1/+1ItKTm2JZbyYO/5pL68jdJlrTnqdwB2pm+q8jtCl+Mp7AFbFrKA8+OCD2rhxo3bs2KG+fft+4Ti/3y+/3x+rGAAAIA65XlAcx9FDDz2kmpoabdu2TYWFhW4fAgAAJDjXC0pZWZnWrl2rDRs2KCMjQydP/nvuPxAIKC0tze3DAQCABOT6RbLV1dUKBoOaMGGCcnNz27Y333zT7UMBAIAEFZOPeAAAAKLBd/EAAABzYn6jti+r5WKyIheTvY4Rld+8//UOjXv88reIiSvXHm/1OoIrkkbe4nWELlXcd06Hxm3617KYZwGA/8QMCgAAMIeCAgAAzKGgAAAAcygoAADAHAoKAAAwh4ICAADMMbvMON3/qZL9Pq9jRKXnCbP/eV3Xc8MuryO4YnP9Qq8jdCmWDwOwihkUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOaYvVHHn4avV2ZG8lVGLemiNF/ODdvCHRtYGeskAADEF2ZQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA55lbxOI4jSQo1Ra46NumaUBck+vIuXrzQoXGhkO3z6IiLTkuHxiXCuQIALvXZ3++fvY9Hy+e49Uou+fjjjzVgwACvYwAAgC/ho48+Uv/+/aN+HXMzKFlZWZKkTz75RIFAwOs4rgqFQsrPz1djY6MyMzO9juO6RD4/zi0+cW7xiXOLT8FgUAUFBW3v49EyV1CSkv59WUwgEEi4X95nMjMzE/bclODnx7nFJ84tPnFu8emz9/GoX8eVVwEAAHARBQ
UAAJhjrqD4/X498cQT8vv9XkdxXSKfmxL8/Di3+MS5xSfOLT65fW7mVvEAAACYm0EBAACgoAAAAHMoKAAAwBwKCgAAMMdMQdmxY4emTJmivLw8+Xw+rV+/3utIrqmqqtKtt96qjIwM9enTR9OmTdOBAwe8juWK6upqDR8+vO2mQ2PHjtWmTZu8jhUTTz/9tHw+n8rLy72OErUnn3xSPp+v3TZ48GCvY7nm2LFjuu+++9S7d2+lpaVp2LBh2rNnj9exXHHjjTde8rvz+XwqKyvzOlrUWltbVVlZqcLCQqWlpWnAgAFatGiRa9/t4rVz586pvLxc/fr1U1pamsaNG6e6ujqvY3Xa1d6vHcfR448/rtzcXKWlpamoqEiHDh3q9HHMFJTm5maNGDFCy5cv9zqK67Zv366ysjK9//77qq2tVUtLi+644w41Nzd7HS1qffv21dNPP636+nrt2bNH3/rWtzR16lT99a9/9Tqaq+rq6vTSSy9p+PDhXkdxzZAhQ3TixIm2befOnV5HcsWZM2c0fvx49ejRQ5s2bdLf/vY3/eIXv1CvXr28juaKurq6dr+32tpaSdL06dO9jha1pUuXqrq6Wi+++KL+/ve/a+nSpXrmmWf0q1/9yutornjggQdUW1urV199Vfv27dMdd9yhoqIiHTt2zOtonXK19+tnnnlGy5Yt04oVK7Rr1y6lp6frzjvv1IULHfsC3TaOQZKcmpoar2PEzOnTpx1Jzvbt272OEhO9evVyfv3rX3sdwzXnzp1zBg4c6NTW1jrf/OY3nblz53odKWpPPPGEM2LECK9jxMT8+fOd2267zesYXWbu3LnOgAEDnEgk4nWUqE2ePNmZPXt2u33f+c53nJKSEs8yueX8+fNOcnKys3Hjxnb7R40a5Tz66KOe5YrW59+vI5GIk5OT4zz77LNt+86ePev4/X5n3bp1nXptMzMo3UkwGJT+44sRE0Vra6veeOMNNTc3a+zYsV7HcU1ZWZkmT56soqIir6O46tChQ8rLy1P//v1VUlKiTz75xOtIrnjnnXc0evRoTZ8+XX369NHIkSP18ssvex0rJj799FO99tprmj17tnw+n9dxojZu3Dht2bJFBw8elCR9+OGH2rlzp4qLi72OFrWLFy+qtbVVPXv2bLc/LS0tYWYvJenIkSM6efJku78vA4GAxowZo/fee69Tr2XuywITXSQSUXl5ucaPH6+hQ4d6HccV+/bt09ixY3XhwgVde+21qqmp0S233OJ1LFe88cYb2rt3b1x+TnwlY8aM0erVqzVo0CCdOHFCCxcu1Ne//nXt379fGRkZXseLyscff6zq6mrNmzdPP/vZz1RXV6c5c+YoNTVVpaWlXsdz1fr163X27Fl973vf8zqKKxYsWKBQKKTBgwcrOTlZra2tWrx4sUpKSryOFrWMjAyNHTtWixYt0s0336zs7GytW7dO7733nm666Sav47nm5MmTkqTs7Ox2+7Ozs9ue6ygKShcrKyvT/v37E6oxDxo0SA0NDQoGg/rtb3+r0tJSbd++Pe5LSmNjo+bOnava2tpL/tUT7/7zX6TDhw/XmDFj1K9fP7311lv6/ve/72m2aEUiEY0ePVpLliyRJI0cOVL79+/XihUrEq6g/OY3v1FxcbHy8vK8juKKt956S6+//rrWrl2rIUOGqKGhQeXl5crLy0uI392rr76q2bNn64YbblBycrJGjRqlGTNmqL6+3utoJvERTxd68MEHtXHjRm3dulV9+/b1Oo5rUlNTddNNN+mrX/2qqqqqNGLECL3wwgtex4pafX29Tp8+rVGjRiklJUUpKSnavn27li1bppSUFLW2tnod0TXXXXedvvKVr+jw4cNeR4labm7uJeX45ptvTpiPsD5z9OhR/fGPf9QDDzzgdRTXPPLII1qwYIHuvfdeDRs2TDNnztTDDz+sqqoqr6O5YsCAAdq+fbuamprU2Nio3bt3q6WlRf379/c6mmtycnIkSadOnW
q3/9SpU23PdRQFpQs4jqMHH3xQNTU1+tOf/qTCwkKvI8VUJBJROBz2OkbUJk2apH379qmhoaFtGz16tEpKStTQ0KDk5GSvI7qmqalJH330kXJzc72OErXx48dfsoz/4MGD6tevn2eZYmHVqlXq06ePJk+e7HUU15w/f15JSe3flpKTkxWJRDzLFAvp6enKzc3VmTNntHnzZk2dOtXrSK4pLCxUTk6OtmzZ0rYvFApp165dnb420cxHPE1NTe3+9XbkyBE1NDQoKytLBQUFnmaLVllZmdauXasNGzYoIyOj7XO4QCCgtLQ0r+NFpaKiQsXFxSooKNC5c+e0du1abdu2TZs3b/Y6WtQyMjIuuU4oPT1dvXv3jvvrh37yk59oypQp6tevn44fP64nnnhCycnJmjFjhtfRovbwww9r3LhxWrJkie6++27t3r1bK1eu1MqVK72O5ppIJKJVq1aptLRUKSlm/hqP2pQpU7R48WIVFBRoyJAh+uCDD/Tcc89p9uzZXkdzxebNm+U4jgYNGqTDhw/rkUce0eDBgzVr1iyvo3XK1d6vy8vL9dRTT2ngwIEqLCxUZWWl8vLyNG3atM4dyNX1RlHYunWrI+mSrbS01OtoUbvceUlyVq1a5XW0qM2ePdvp16+fk5qa6lx//fXOpEmTnD/84Q9ex4qZRFlmfM899zi5ublOamqqc8MNNzj33HOPc/jwYa9jueb3v/+9M3ToUMfv9zuDBw92Vq5c6XUkV23evNmR5Bw4cMDrKK4KhULO3LlznYKCAqdnz55O//79nUcffdQJh8NeR3PFm2++6fTv399JTU11cnJynLKyMufs2bNex+q0q71fRyIRp7Ky0snOznb8fr8zadKkL/Vn1eckyi36AABAwuAaFAAAYA4FBQAAmENBAQAA5lBQAACAORQUAABgDgUFAACYQ0EBAADmUFAAAIA5FBQAAGAOBQUAAJhDQQEAAOZQUAAAgDn/H/iu+7rgR1tVAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"batch_src, batch_labels, batch_padding_mask = map(lambda x: x.to(device), mktunebatch(BSZ, test=True))\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    output = model(batch_src, batch_padding_mask)\n",
"print(criterion(output.squeeze(1), batch_labels).item())\n",
"# Cast to float32 (not float16): float16 rounds the values and makes the histogram bin edges uneven.\n",
"x = batch_labels.detach().float().cpu().numpy().flatten()\n",
"y = output.detach().float().cpu().numpy().flatten()\n",
"# True label (x) vs. model output (y) on the test=True batch; counts on a log color scale.\n",
"plt.hist2d(x, y, bins=50, norm=mpl.colors.LogNorm())\n",
"plt.xlabel('label')\n",
"plt.ylabel('model output')\n",
"plt.colorbar();"
]
}
],
"metadata": {
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|