{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_state": "idle",
"id": "ecaeb29e-fbbe-4876-86ad-9fbadea989b0",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"assert device.type == \"cuda\", \"CUDA is not available. Please check your GPU setup.\""
]
},
{
"cell_type": "code",
"execution_count": 110,
"execution_state": "idle",
"id": "84a82827-8947-4a26-a485-56f5b1eadb4c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[4, 9, 4, 1, 8, 2],\n",
" [0, 6, 5, 9, 1, 4],\n",
" [4, 9, 1, 0, 5, 5],\n",
" [5, 2, 4, 9, 1, 8],\n",
" [2, 7, 6, 8, 0, 1]], device='cuda:0'),\n",
" tensor([[5, 7, 7],\n",
" [9, 7, 9],\n",
" [4, 4, 7],\n",
" [4, 4, 2],\n",
" [0, 8, 7]], device='cuda:0'))"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"NUM_LEN = 3\n",
"\n",
"def pad(a):\n",
" s = str(a)\n",
" if len(s) > NUM_LEN:\n",
" return s[-NUM_LEN:]\n",
" return s.zfill(NUM_LEN)\n",
"\n",
"def mkbatch_ltr(size):\n",
" data = []\n",
" labels = []\n",
" for i in range(size):\n",
" a = random.randrange(0, 10**NUM_LEN)\n",
" b = random.randrange(0, 10**NUM_LEN)\n",
" c = a + b\n",
" data.append(list(map(int, pad(a) + pad(b))))\n",
" labels.append(list(map(int, pad(c))))\n",
" return torch.tensor(data, device=device), torch.tensor(labels, device=device)\n",
"\n",
"def mkbatch_rtl(size):\n",
" data, labels = mkbatch_ltr(size)\n",
" return torch.flip(data, (1,)), torch.flip(labels, (1,))\n",
"\n",
"mkbatch_rtl(5)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"execution_state": "idle",
"id": "d50dce44-57b7-4d4d-895a-c2275c04234c",
"metadata": {},
"outputs": [],
"source": [
"class TransformerModel(nn.Module):\n",
" def __init__(self, input_dim, model_dim, output_dim, nheads, nenclayers, ndeclayers):\n",
" super().__init__()\n",
" self.emb = nn.Embedding(input_dim, model_dim - 1)\n",
" self.trans = nn.Transformer(d_model=model_dim, nhead=nheads, dim_feedforward=4 * model_dim,\n",
" num_encoder_layers=nenclayers, num_decoder_layers=ndeclayers,\n",
" dropout=0, batch_first=True)\n",
" self.output = nn.Linear(model_dim, output_dim)\n",
"\n",
" def forward(self, data, labels):\n",
" bsz = data.size(0)\n",
" data_pos = (torch.arange(2 * NUM_LEN, device=device) % NUM_LEN).expand(bsz, -1)\n",
" labels_pos = (torch.arange(NUM_LEN, device=device)).expand(bsz, -1)\n",
" data_emb = torch.cat((self.emb(data), data_pos.unsqueeze(2)), 2)\n",
" labels_emb = torch.cat((self.emb(labels), labels_pos.unsqueeze(2)), 2)\n",
" return self.output(self.trans(data_emb, labels_emb, tgt_mask=TGT_MASK, tgt_is_causal=True))"
]
},
{
"cell_type": "code",
"execution_count": 118,
"execution_state": "idle",
"id": "ddad4059-b06e-4eb3-a55a-5a4a842cdd7a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training data: 32768K\n",
"Trainable parameters in the model: 1251\n"
]
}
],
"source": [
"MODEL_DIM = 4 # Dimension of model\n",
"VOCAB_SIZE = 10\n",
"NEPOCHS = 1000\n",
"BSZ = 2**15 # Batch size\n",
"NHEADS = 1\n",
"NENCLAYERS = 2\n",
"NDECLAYERS = 2\n",
"\n",
"LR = 1e-2\n",
"\n",
"TGT_MASK = nn.Transformer.generate_square_subsequent_mask(NUM_LEN)\n",
"model = TransformerModel(VOCAB_SIZE + 1, MODEL_DIM, VOCAB_SIZE, NHEADS, NENCLAYERS, NDECLAYERS).to(device)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
"\n",
"train_err = []\n",
"open('loss', 'w').close()\n",
"\n",
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Training data: {NEPOCHS*BSZ//10**3}K\")\n",
"print(f\"Trainable parameters in the model: {trainable_params}\")"
]
},
{
"cell_type": "code",
"execution_count": 119,
"execution_state": "idle",
"id": "689f2e44-da84-43ea-b539-414d6f5c37e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0/1000 \t Train Err: 2.4793\n",
"Epoch 1/1000 \t Train Err: 2.4310\n",
"Epoch 2/1000 \t Train Err: 2.3800\n",
"Epoch 3/1000 \t Train Err: 2.3493\n",
"Epoch 4/1000 \t Train Err: 2.3288\n",
"Epoch 5/1000 \t Train Err: 2.3202\n",
"Epoch 6/1000 \t Train Err: 2.3171\n",
"Epoch 7/1000 \t Train Err: 2.3139\n",
"Epoch 8/1000 \t Train Err: 2.3095\n",
"Epoch 9/1000 \t Train Err: 2.3064\n",
"Epoch 10/1000 \t Train Err: 2.3040\n",
"Epoch 11/1000 \t Train Err: 2.3029\n",
"Epoch 12/1000 \t Train Err: 2.3030\n",
"Epoch 13/1000 \t Train Err: 2.3037\n",
"Epoch 14/1000 \t Train Err: 2.3047\n",
"Epoch 15/1000 \t Train Err: 2.3060\n",
"Epoch 16/1000 \t Train Err: 2.3067\n",
"Epoch 17/1000 \t Train Err: 2.3067\n",
"Epoch 18/1000 \t Train Err: 2.3068\n",
"Epoch 19/1000 \t Train Err: 2.3059\n",
"Epoch 20/1000 \t Train Err: 2.3060\n",
"Epoch 21/1000 \t Train Err: 2.3052\n",
"Epoch 22/1000 \t Train Err: 2.3044\n",
"Epoch 23/1000 \t Train Err: 2.3039\n",
"Epoch 24/1000 \t Train Err: 2.3039\n",
"Epoch 25/1000 \t Train Err: 2.3033\n",
"Epoch 26/1000 \t Train Err: 2.3032\n",
"Epoch 27/1000 \t Train Err: 2.3032\n",
"Epoch 28/1000 \t Train Err: 2.3032\n",
"Epoch 29/1000 \t Train Err: 2.3029\n",
"Epoch 30/1000 \t Train Err: 2.3028\n",
"Epoch 31/1000 \t Train Err: 2.3032\n",
"Epoch 32/1000 \t Train Err: 2.3031\n",
"Epoch 33/1000 \t Train Err: 2.3030\n",
"Epoch 34/1000 \t Train Err: 2.3031\n",
"Epoch 35/1000 \t Train Err: 2.3031\n",
"Epoch 36/1000 \t Train Err: 2.3031\n",
"Epoch 37/1000 \t Train Err: 2.3029\n",
"Epoch 38/1000 \t Train Err: 2.3033\n",
"Epoch 39/1000 \t Train Err: 2.3032\n",
"Epoch 40/1000 \t Train Err: 2.3031\n",
"Epoch 41/1000 \t Train Err: 2.3030\n",
"Epoch 42/1000 \t Train Err: 2.3027\n",
"Epoch 43/1000 \t Train Err: 2.3027\n",
"Epoch 44/1000 \t Train Err: 2.3027\n",
"Epoch 45/1000 \t Train Err: 2.3027\n",
"Epoch 46/1000 \t Train Err: 2.3026\n",
"Epoch 47/1000 \t Train Err: 2.3026\n",
"Epoch 48/1000 \t Train Err: 2.3027\n",
"Epoch 49/1000 \t Train Err: 2.3026\n",
"Epoch 50/1000 \t Train Err: 2.3027\n",
"Epoch 51/1000 \t Train Err: 2.3027\n",
"Epoch 52/1000 \t Train Err: 2.3027\n",
"Epoch 53/1000 \t Train Err: 2.3026\n",
"Epoch 54/1000 \t Train Err: 2.3028\n",
"Epoch 55/1000 \t Train Err: 2.3027\n",
"Epoch 56/1000 \t Train Err: 2.3026\n",
"Epoch 57/1000 \t Train Err: 2.3027\n",
"Epoch 58/1000 \t Train Err: 2.3026\n",
"Epoch 59/1000 \t Train Err: 2.3026\n",
"Epoch 60/1000 \t Train Err: 2.3027\n",
"Epoch 61/1000 \t Train Err: 2.3026\n",
"Epoch 62/1000 \t Train Err: 2.3026\n",
"Epoch 500/1000 \t Train Err: 2.3026\n",
"Epoch 501/1000 \t Train Err: 2.3025\n",
"Epoch 502/1000 \t Train Err: 2.3026\n",
"Epoch 503/1000 \t Train Err: 2.3026\n",
"Epoch 504/1000 \t Train Err: 2.3026\n",
"Epoch 505/1000 \t Train Err: 2.3026\n",
"Epoch 506/1000 \t Train Err: 2.3026\n",
"Epoch 507/1000 \t Train Err: 2.3026\n",
"Epoch 508/1000 \t Train Err: 2.3026\n",
"Epoch 509/1000 \t Train Err: 2.3026\n",
"Epoch 510/1000 \t Train Err: 2.3026\n",
"Epoch 511/1000 \t Train Err: 2.3026\n",
"Epoch 512/1000 \t Train Err: 2.3026\n",
"Epoch 513/1000 \t Train Err: 2.3026\n",
"Epoch 514/1000 \t Train Err: 2.3025\n",
"Epoch 515/1000 \t Train Err: 2.3027\n",
"Epoch 516/1000 \t Train Err: 2.3027\n",
"Epoch 517/1000 \t Train Err: 2.3027\n",
"Epoch 518/1000 \t Train Err: 2.3026\n",
"Epoch 519/1000 \t Train Err: 2.3026\n",
"Epoch 520/1000 \t Train Err: 2.3026\n",
"Epoch 521/1000 \t Train Err: 2.3025\n",
"Epoch 522/1000 \t Train Err: 2.3027\n",
"Epoch 523/1000 \t Train Err: 2.3027\n",
"Epoch 524/1000 \t Train Err: 2.3027\n",
"Epoch 525/1000 \t Train Err: 2.3026\n",
"Epoch 526/1000 \t Train Err: 2.3026\n",
"Epoch 527/1000 \t Train Err: 2.3026\n",
"Epoch 528/1000 \t Train Err: 2.3027\n",
"Epoch 529/1000 \t Train Err: 2.3028\n",
"Epoch 530/1000 \t Train Err: 2.3026\n",
"Epoch 531/1000 \t Train Err: 2.3026\n",
"Epoch 532/1000 \t Train Err: 2.3026\n",
"Epoch 533/1000 \t Train Err: 2.3026\n",
"Epoch 534/1000 \t Train Err: 2.3026\n",
"Epoch 535/1000 \t Train Err: 2.3026\n",
"Epoch 536/1000 \t Train Err: 2.3027\n",
"Epoch 537/1000 \t Train Err: 2.3027\n",
"Epoch 538/1000 \t Train Err: 2.3025\n",
"Epoch 539/1000 \t Train Err: 2.3026\n",
"Epoch 540/1000 \t Train Err: 2.3026\n",
"Epoch 541/1000 \t Train Err: 2.3026\n",
"Epoch 542/1000 \t Train Err: 2.3026\n",
"Epoch 543/1000 \t Train Err: 2.3026\n",
"Epoch 544/1000 \t Train Err: 2.3026\n",
"Epoch 545/1000 \t Train Err: 2.3026\n",
"Epoch 546/1000 \t Train Err: 2.3027\n",
"Epoch 547/1000 \t Train Err: 2.3026\n",
"Epoch 548/1000 \t Train Err: 2.3026\n",
"Epoch 549/1000 \t Train Err: 2.3026\n",
"Epoch 550/1000 \t Train Err: 2.3026\n",
"Epoch 551/1000 \t Train Err: 2.3026\n",
"Epoch 552/1000 \t Train Err: 2.3028\n",
"Epoch 553/1000 \t Train Err: 2.3028\n",
"Epoch 554/1000 \t Train Err: 2.3027\n",
"Epoch 555/1000 \t Train Err: 2.3026\n",
"Epoch 556/1000 \t Train Err: 2.3027\n",
"Epoch 557/1000 \t Train Err: 2.3027\n",
"Epoch 558/1000 \t Train Err: 2.3028\n",
"Epoch 559/1000 \t Train Err: 2.3026\n",
"Epoch 560/1000 \t Train Err: 2.3026\n",
"Epoch 561/1000 \t Train Err: 2.3026\n",
"Epoch 562/1000 \t Train Err: 2.3027\n",
"Epoch 563/1000 \t Train Err: 2.3027\n",
"Epoch 564/1000 \t Train Err: 2.3027\n",
"Epoch 565/1000 \t Train Err: 2.3025\n",
"Epoch 566/1000 \t Train Err: 2.3026\n",
"Epoch 567/1000 \t Train Err: 2.3026\n",
"Epoch 568/1000 \t Train Err: 2.3026\n",
"Epoch 569/1000 \t Train Err: 2.3026\n",
"Epoch 570/1000 \t Train Err: 2.3026\n",
"Epoch 571/1000 \t Train Err: 2.3026\n",
"Epoch 572/1000 \t Train Err: 2.3026\n",
"Epoch 573/1000 \t Train Err: 2.3026\n",
"Epoch 574/1000 \t Train Err: 2.3026\n",
"Epoch 575/1000 \t Train Err: 2.3028\n",
"Epoch 576/1000 \t Train Err: 2.3026\n",
"Epoch 577/1000 \t Train Err: 2.3026\n",
"Epoch 578/1000 \t Train Err: 2.3025\n",
"Epoch 579/1000 \t Train Err: 2.3026\n",
"Epoch 580/1000 \t Train Err: 2.3026\n",
"Epoch 581/1000 \t Train Err: 2.3027\n",
"Epoch 582/1000 \t Train Err: 2.3026\n",
"Epoch 583/1000 \t Train Err: 2.3027\n",
"Epoch 584/1000 \t Train Err: 2.3027\n",
"Epoch 585/1000 \t Train Err: 2.3026\n",
"Epoch 586/1000 \t Train Err: 2.3026\n",
"Epoch 587/1000 \t Train Err: 2.3026\n",
"Epoch 588/1000 \t Train Err: 2.3026\n",
"Epoch 589/1000 \t Train Err: 2.3027\n",
"Epoch 590/1000 \t Train Err: 2.3026\n",
"Epoch 591/1000 \t Train Err: 2.3026\n",
"Epoch 592/1000 \t Train Err: 2.3026\n",
"Epoch 593/1000 \t Train Err: 2.3026\n",
"Epoch 594/1000 \t Train Err: 2.3026\n",
"Epoch 595/1000 \t Train Err: 2.3026\n",
"Epoch 596/1000 \t Train Err: 2.3026\n",
"Epoch 597/1000 \t Train Err: 2.3027\n",
"Epoch 598/1000 \t Train Err: 2.3026\n",
"Epoch 599/1000 \t Train Err: 2.3027\n",
"Epoch 600/1000 \t Train Err: 2.3027\n",
"Epoch 601/1000 \t Train Err: 2.3026\n",
"Epoch 602/1000 \t Train Err: 2.3026\n",
"Epoch 603/1000 \t Train Err: 2.3026\n",
"Epoch 604/1000 \t Train Err: 2.3026\n",
"Epoch 605/1000 \t Train Err: 2.3026\n",
"Epoch 606/1000 \t Train Err: 2.3027\n",
"Epoch 607/1000 \t Train Err: 2.3026\n",
"Epoch 608/1000 \t Train Err: 2.3026\n",
"Epoch 609/1000 \t Train Err: 2.3026\n",
"Epoch 610/1000 \t Train Err: 2.3026\n",
"Epoch 611/1000 \t Train Err: 2.3026\n",
"Epoch 612/1000 \t Train Err: 2.3027\n",
"Epoch 613/1000 \t Train Err: 2.3025\n",
"Epoch 614/1000 \t Train Err: 2.3026\n",
"Epoch 615/1000 \t Train Err: 2.3026\n",
"Epoch 616/1000 \t Train Err: 2.3026\n",
"Epoch 617/1000 \t Train Err: 2.3026\n",
"Epoch 618/1000 \t Train Err: 2.3026\n",
"Epoch 619/1000 \t Train Err: 2.3026\n",
"Epoch 620/1000 \t Train Err: 2.3026\n",
"Epoch 621/1000 \t Train Err: 2.3026\n",
"Epoch 622/1000 \t Train Err: 2.3026\n",
"Epoch 623/1000 \t Train Err: 2.3026\n",
"Epoch 624/1000 \t Train Err: 2.3026\n",
"Epoch 625/1000 \t Train Err: 2.3026\n",
"Epoch 626/1000 \t Train Err: 2.3026\n",
"Epoch 627/1000 \t Train Err: 2.3026\n",
"Epoch 628/1000 \t Train Err: 2.3026\n",
"Epoch 629/1000 \t Train Err: 2.3026\n",
"Epoch 630/1000 \t Train Err: 2.3027\n",
"Epoch 631/1000 \t Train Err: 2.3026\n",
"Epoch 632/1000 \t Train Err: 2.3026\n",
"Epoch 633/1000 \t Train Err: 2.3025\n",
"Epoch 634/1000 \t Train Err: 2.3026\n",
"Epoch 635/1000 \t Train Err: 2.3026\n",
"Epoch 636/1000 \t Train Err: 2.3026\n",
"Epoch 637/1000 \t Train Err: 2.3026\n",
"Epoch 638/1000 \t Train Err: 2.3026\n",
"Epoch 639/1000 \t Train Err: 2.3027\n",
"Epoch 640/1000 \t Train Err: 2.3026\n",
"Epoch 641/1000 \t Train Err: 2.3026\n",
"Epoch 642/1000 \t Train Err: 2.3026\n",
"Epoch 643/1000 \t Train Err: 2.3026\n",
"Epoch 644/1000 \t Train Err: 2.3027\n",
"Epoch 645/1000 \t Train Err: 2.3026\n",
"Epoch 646/1000 \t Train Err: 2.3026\n",
"Epoch 647/1000 \t Train Err: 2.3025\n",
"Epoch 648/1000 \t Train Err: 2.3026\n",
"Epoch 649/1000 \t Train Err: 2.3026\n",
"Epoch 650/1000 \t Train Err: 2.3025\n",
"Epoch 651/1000 \t Train Err: 2.3026\n",
"Epoch 652/1000 \t Train Err: 2.3025\n",
"Epoch 653/1000 \t Train Err: 2.3026\n",
"Epoch 654/1000 \t Train Err: 2.3026\n",
"Epoch 655/1000 \t Train Err: 2.3026\n",
"Epoch 656/1000 \t Train Err: 2.3026\n",
"Epoch 657/1000 \t Train Err: 2.3025\n",
"Epoch 658/1000 \t Train Err: 2.3026\n",
"Epoch 659/1000 \t Train Err: 2.3025\n",
"Epoch 660/1000 \t Train Err: 2.3025\n",
"Epoch 661/1000 \t Train Err: 2.3025\n",
"Epoch 662/1000 \t Train Err: 2.3026\n",
"Epoch 663/1000 \t Train Err: 2.3026\n",
"Epoch 664/1000 \t Train Err: 2.3025\n",
"Epoch 665/1000 \t Train Err: 2.3026\n",
"Epoch 666/1000 \t Train Err: 2.3026\n",
"Epoch 667/1000 \t Train Err: 2.3025\n",
"Epoch 668/1000 \t Train Err: 2.3026\n",
"Epoch 669/1000 \t Train Err: 2.3026\n",
"Epoch 670/1000 \t Train Err: 2.3025\n",
"Epoch 671/1000 \t Train Err: 2.3026\n",
"Epoch 672/1000 \t Train Err: 2.3025\n",
"Epoch 673/1000 \t Train Err: 2.3024\n",
"Epoch 674/1000 \t Train Err: 2.3024\n",
"Epoch 675/1000 \t Train Err: 2.3024\n",
"Epoch 676/1000 \t Train Err: 2.3024\n",
"Epoch 677/1000 \t Train Err: 2.3023\n",
"Epoch 678/1000 \t Train Err: 2.3024\n",
"Epoch 679/1000 \t Train Err: 2.3022\n",
"Epoch 680/1000 \t Train Err: 2.3022\n",
"Epoch 681/1000 \t Train Err: 2.3022\n",
"Epoch 682/1000 \t Train Err: 2.3020\n",
"Epoch 683/1000 \t Train Err: 2.3018\n",
"Epoch 684/1000 \t Train Err: 2.3016\n",
"Epoch 685/1000 \t Train Err: 2.3014\n",
"Epoch 686/1000 \t Train Err: 2.3011\n",
"Epoch 687/1000 \t Train Err: 2.3007\n",
"Epoch 688/1000 \t Train Err: 2.3007\n",
"Epoch 689/1000 \t Train Err: 2.2999\n",
"Epoch 690/1000 \t Train Err: 2.2999\n",
"Epoch 691/1000 \t Train Err: 2.2993\n",
"Epoch 692/1000 \t Train Err: 2.2993\n",
"Epoch 693/1000 \t Train Err: 2.2988\n",
"Epoch 694/1000 \t Train Err: 2.2987\n",
"Epoch 695/1000 \t Train Err: 2.2983\n",
"Epoch 696/1000 \t Train Err: 2.2976\n",
"Epoch 697/1000 \t Train Err: 2.2974\n",
"Epoch 698/1000 \t Train Err: 2.2969\n",
"Epoch 699/1000 \t Train Err: 2.2975\n",
"Epoch 700/1000 \t Train Err: 2.2955\n",
"Epoch 701/1000 \t Train Err: 2.2967\n",
"Epoch 702/1000 \t Train Err: 2.2958\n",
"Epoch 703/1000 \t Train Err: 2.2933\n",
"Epoch 704/1000 \t Train Err: 2.2951\n",
"Epoch 705/1000 \t Train Err: 2.2939\n",
"Epoch 706/1000 \t Train Err: 2.2922\n",
"Epoch 707/1000 \t Train Err: 2.2919\n",
"Epoch 708/1000 \t Train Err: 2.2901\n",
"Epoch 709/1000 \t Train Err: 2.2897\n",
"Epoch 710/1000 \t Train Err: 2.2867\n",
"Epoch 711/1000 \t Train Err: 2.2855\n",
"Epoch 712/1000 \t Train Err: 2.2841\n",
"Epoch 713/1000 \t Train Err: 2.2844\n",
"Epoch 714/1000 \t Train Err: 2.2812\n",
"Epoch 715/1000 \t Train Err: 2.2801\n",
"Epoch 716/1000 \t Train Err: 2.2789\n",
"Epoch 717/1000 \t Train Err: 2.2761\n",
"Epoch 718/1000 \t Train Err: 2.2797\n",
"Epoch 719/1000 \t Train Err: 2.2796\n",
"Epoch 720/1000 \t Train Err: 2.2974\n",
"Epoch 721/1000 \t Train Err: 2.2786\n",
"Epoch 722/1000 \t Train Err: 2.2802\n",
"Epoch 723/1000 \t Train Err: 2.2805\n",
"Epoch 724/1000 \t Train Err: 2.2812\n",
"Epoch 725/1000 \t Train Err: 2.2812\n",
"Epoch 726/1000 \t Train Err: 2.2792\n",
"Epoch 727/1000 \t Train Err: 2.2780\n",
"Epoch 728/1000 \t Train Err: 2.2775\n",
"Epoch 729/1000 \t Train Err: 2.2750\n",
"Epoch 730/1000 \t Train Err: 2.2821\n",
"Epoch 731/1000 \t Train Err: 2.2815\n",
"Epoch 732/1000 \t Train Err: 2.2812\n",
"Epoch 733/1000 \t Train Err: 2.2779\n",
"Epoch 734/1000 \t Train Err: 2.2777\n",
"Epoch 735/1000 \t Train Err: 2.2799\n",
"Epoch 736/1000 \t Train Err: 2.2754\n",
"Epoch 737/1000 \t Train Err: 2.2742\n",
"Epoch 738/1000 \t Train Err: 2.2723\n",
"Epoch 739/1000 \t Train Err: 2.2719\n",
"Epoch 740/1000 \t Train Err: 2.2674\n",
"Epoch 741/1000 \t Train Err: 2.2694\n",
"Epoch 742/1000 \t Train Err: 2.2702\n",
"Epoch 743/1000 \t Train Err: 2.2693\n",
"Epoch 744/1000 \t Train Err: 2.2722\n",
"Epoch 745/1000 \t Train Err: 2.2704\n",
"Epoch 746/1000 \t Train Err: 2.2675\n",
"Epoch 747/1000 \t Train Err: 2.2644\n",
"Epoch 748/1000 \t Train Err: 2.2599\n",
"Epoch 749/1000 \t Train Err: 2.2583\n",
"Epoch 750/1000 \t Train Err: 2.2578\n",
"Epoch 751/1000 \t Train Err: 2.2507\n",
"Epoch 752/1000 \t Train Err: 2.2490\n",
"Epoch 753/1000 \t Train Err: 2.2501\n",
"Epoch 754/1000 \t Train Err: 2.2502\n",
"Epoch 755/1000 \t Train Err: 2.2520\n",
"Epoch 756/1000 \t Train Err: 2.2435\n",
"Epoch 757/1000 \t Train Err: 2.2432\n",
"Epoch 758/1000 \t Train Err: 2.2420\n",
"Epoch 759/1000 \t Train Err: 2.2393\n",
"Epoch 760/1000 \t Train Err: 2.2372\n",
"Epoch 761/1000 \t Train Err: 2.2302\n",
"Epoch 762/1000 \t Train Err: 2.2302\n",
"Epoch 763/1000 \t Train Err: 2.2294\n",
"Epoch 764/1000 \t Train Err: 2.2201\n",
"Epoch 765/1000 \t Train Err: 2.2195\n",
"Epoch 766/1000 \t Train Err: 2.2166\n",
"Epoch 767/1000 \t Train Err: 2.2139\n",
"Epoch 768/1000 \t Train Err: 2.2096\n",
"Epoch 769/1000 \t Train Err: 2.2100\n",
"Epoch 770/1000 \t Train Err: 2.2073\n",
"Epoch 771/1000 \t Train Err: 2.2058\n",
"Epoch 772/1000 \t Train Err: 2.2096\n",
"Epoch 773/1000 \t Train Err: 2.2055\n",
"Epoch 774/1000 \t Train Err: 2.2213\n",
"Epoch 775/1000 \t Train Err: 2.2435\n",
"Epoch 776/1000 \t Train Err: 2.2282\n",
"Epoch 777/1000 \t Train Err: 2.2328\n",
"Epoch 778/1000 \t Train Err: 2.2254\n",
"Epoch 779/1000 \t Train Err: 2.2246\n",
"Epoch 780/1000 \t Train Err: 2.2241\n",
"Epoch 781/1000 \t Train Err: 2.2217\n",
"Epoch 782/1000 \t Train Err: 2.2156\n",
"Epoch 783/1000 \t Train Err: 2.2219\n",
"Epoch 784/1000 \t Train Err: 2.2151\n",
"Epoch 785/1000 \t Train Err: 2.2259\n",
"Epoch 786/1000 \t Train Err: 2.2226\n",
"Epoch 787/1000 \t Train Err: 2.2176\n",
"Epoch 788/1000 \t Train Err: 2.2152\n",
"Epoch 789/1000 \t Train Err: 2.2099\n",
"Epoch 790/1000 \t Train Err: 2.2069\n",
"Epoch 791/1000 \t Train Err: 2.2034\n",
"Epoch 792/1000 \t Train Err: 2.2080\n",
"Epoch 793/1000 \t Train Err: 2.1999\n",
"Epoch 794/1000 \t Train Err: 2.1925\n",
"Epoch 795/1000 \t Train Err: 2.1840\n",
"Epoch 796/1000 \t Train Err: 2.1820\n",
"Epoch 797/1000 \t Train Err: 2.1907\n",
"Epoch 798/1000 \t Train Err: 2.1835\n",
"Epoch 799/1000 \t Train Err: 2.1886\n",
"Epoch 800/1000 \t Train Err: 2.1807\n",
"Epoch 801/1000 \t Train Err: 2.1841\n",
"Epoch 802/1000 \t Train Err: 2.1776\n",
"Epoch 803/1000 \t Train Err: 2.1800\n",
"Epoch 804/1000 \t Train Err: 2.1715\n",
"Epoch 805/1000 \t Train Err: 2.1717\n",
"Epoch 806/1000 \t Train Err: 2.1701\n",
"Epoch 807/1000 \t Train Err: 2.1635\n",
"Epoch 808/1000 \t Train Err: 2.1664\n",
"Epoch 809/1000 \t Train Err: 2.1603\n",
"Epoch 810/1000 \t Train Err: 2.1636\n",
"Epoch 811/1000 \t Train Err: 2.1575\n",
"Epoch 812/1000 \t Train Err: 2.1587\n",
"Epoch 813/1000 \t Train Err: 2.1559\n",
"Epoch 814/1000 \t Train Err: 2.1540\n",
"Epoch 815/1000 \t Train Err: 2.1537\n",
"Epoch 816/1000 \t Train Err: 2.1514\n",
"Epoch 817/1000 \t Train Err: 2.1500\n",
"Epoch 818/1000 \t Train Err: 2.1488\n",
"Epoch 819/1000 \t Train Err: 2.1475\n",
"Epoch 820/1000 \t Train Err: 2.1447\n",
"Epoch 821/1000 \t Train Err: 2.1434\n",
"Epoch 822/1000 \t Train Err: 2.1431\n",
"Epoch 823/1000 \t Train Err: 2.1441\n",
"Epoch 824/1000 \t Train Err: 2.1816\n",
"Epoch 825/1000 \t Train Err: 2.1863\n",
"Epoch 826/1000 \t Train Err: 2.1601\n",
"Epoch 827/1000 \t Train Err: 2.1623\n",
"Epoch 828/1000 \t Train Err: 2.1957\n",
"Epoch 829/1000 \t Train Err: 2.1775\n",
"Epoch 830/1000 \t Train Err: 2.1971\n",
"Epoch 831/1000 \t Train Err: 2.1851\n",
"Epoch 832/1000 \t Train Err: 2.1738\n",
"Epoch 833/1000 \t Train Err: 2.1654\n",
"Epoch 834/1000 \t Train Err: 2.1627\n",
"Epoch 835/1000 \t Train Err: 2.1606\n",
"Epoch 836/1000 \t Train Err: 2.1487\n",
"Epoch 837/1000 \t Train Err: 2.1494\n",
"Epoch 838/1000 \t Train Err: 2.1563\n",
"Epoch 839/1000 \t Train Err: 2.1521\n",
"Epoch 840/1000 \t Train Err: 2.1515\n",
"Epoch 841/1000 \t Train Err: 2.1484\n",
"Epoch 842/1000 \t Train Err: 2.1476\n",
"Epoch 843/1000 \t Train Err: 2.1406\n",
"Epoch 844/1000 \t Train Err: 2.1410\n",
"Epoch 845/1000 \t Train Err: 2.1359\n",
"Epoch 846/1000 \t Train Err: 2.1344\n",
"Epoch 847/1000 \t Train Err: 2.1323\n",
"Epoch 848/1000 \t Train Err: 2.1236\n",
"Epoch 849/1000 \t Train Err: 2.1241\n",
"Epoch 850/1000 \t Train Err: 2.1162\n",
"Epoch 851/1000 \t Train Err: 2.1179\n",
"Epoch 852/1000 \t Train Err: 2.1033\n",
"Epoch 853/1000 \t Train Err: 2.1022\n",
"Epoch 854/1000 \t Train Err: 2.1009\n",
"Epoch 855/1000 \t Train Err: 2.0978\n",
"Epoch 856/1000 \t Train Err: 2.0911\n",
"Epoch 857/1000 \t Train Err: 2.0932\n",
"Epoch 858/1000 \t Train Err: 2.0898\n",
"Epoch 859/1000 \t Train Err: 2.0844\n",
"Epoch 860/1000 \t Train Err: 2.0767\n",
"Epoch 861/1000 \t Train Err: 2.0732\n",
"Epoch 862/1000 \t Train Err: 2.0769\n",
"Epoch 863/1000 \t Train Err: 2.0725\n",
"Epoch 864/1000 \t Train Err: 2.0700\n",
"Epoch 865/1000 \t Train Err: 2.0612\n",
"Epoch 866/1000 \t Train Err: 2.0637\n",
"Epoch 867/1000 \t Train Err: 2.0580\n",
"Epoch 868/1000 \t Train Err: 2.0598\n",
"Epoch 869/1000 \t Train Err: 2.0535\n",
"Epoch 870/1000 \t Train Err: 2.0503\n",
"Epoch 871/1000 \t Train Err: 2.0492\n",
"Epoch 872/1000 \t Train Err: 2.0431\n",
"Epoch 873/1000 \t Train Err: 2.0423\n",
"Epoch 874/1000 \t Train Err: 2.0382\n",
"Epoch 875/1000 \t Train Err: 2.0328\n",
"Epoch 876/1000 \t Train Err: 2.0313\n",
"Epoch 877/1000 \t Train Err: 2.0280\n",
"Epoch 878/1000 \t Train Err: 2.0297\n",
"Epoch 879/1000 \t Train Err: 2.0243\n",
"Epoch 880/1000 \t Train Err: 2.0243\n",
"Epoch 881/1000 \t Train Err: 2.0222\n",
"Epoch 882/1000 \t Train Err: 2.0209\n",
"Epoch 883/1000 \t Train Err: 2.0161\n",
"Epoch 884/1000 \t Train Err: 2.0157\n",
"Epoch 885/1000 \t Train Err: 2.0253\n",
"Epoch 886/1000 \t Train Err: 2.0697\n",
"Epoch 887/1000 \t Train Err: 2.2021\n",
"Epoch 888/1000 \t Train Err: 2.2692\n",
"Epoch 889/1000 \t Train Err: 2.1106\n",
"Epoch 890/1000 \t Train Err: 2.1653\n",
"Epoch 891/1000 \t Train Err: 2.2021\n",
"Epoch 892/1000 \t Train Err: 2.1370\n",
"Epoch 893/1000 \t Train Err: 2.1576\n",
"Epoch 894/1000 \t Train Err: 2.1296\n",
"Epoch 895/1000 \t Train Err: 2.1303\n",
"Epoch 896/1000 \t Train Err: 2.1201\n",
"Epoch 897/1000 \t Train Err: 2.1001\n",
"Epoch 898/1000 \t Train Err: 2.1209\n",
"Epoch 899/1000 \t Train Err: 2.1034\n",
"Epoch 900/1000 \t Train Err: 2.1103\n",
"Epoch 901/1000 \t Train Err: 2.0983\n",
"Epoch 902/1000 \t Train Err: 2.0762\n",
"Epoch 903/1000 \t Train Err: 2.0929\n",
"Epoch 904/1000 \t Train Err: 2.0643\n",
"Epoch 905/1000 \t Train Err: 2.0555\n",
"Epoch 906/1000 \t Train Err: 2.0589\n",
"Epoch 907/1000 \t Train Err: 2.0454\n",
"Epoch 908/1000 \t Train Err: 2.0500\n",
"Epoch 909/1000 \t Train Err: 2.0418\n",
"Epoch 910/1000 \t Train Err: 2.0363\n",
"Epoch 911/1000 \t Train Err: 2.0357\n",
"Epoch 912/1000 \t Train Err: 2.0323\n",
"Epoch 913/1000 \t Train Err: 2.0282\n",
"Epoch 914/1000 \t Train Err: 2.0242\n",
"Epoch 915/1000 \t Train Err: 2.0120\n",
"Epoch 916/1000 \t Train Err: 2.0127\n",
"Epoch 917/1000 \t Train Err: 2.0133\n",
"Epoch 918/1000 \t Train Err: 2.0097\n",
"Epoch 919/1000 \t Train Err: 2.0087\n",
"Epoch 920/1000 \t Train Err: 2.0099\n",
"Epoch 921/1000 \t Train Err: 2.0076\n",
"Epoch 922/1000 \t Train Err: 2.0020\n",
"Epoch 923/1000 \t Train Err: 1.9990\n",
"Epoch 924/1000 \t Train Err: 1.9967\n",
"Epoch 925/1000 \t Train Err: 1.9966\n",
"Epoch 926/1000 \t Train Err: 1.9946\n",
"Epoch 927/1000 \t Train Err: 1.9904\n",
"Epoch 928/1000 \t Train Err: 1.9874\n",
"Epoch 929/1000 \t Train Err: 1.9974\n",
"Epoch 930/1000 \t Train Err: 1.9857\n",
"Epoch 931/1000 \t Train Err: 1.9892\n",
"Epoch 932/1000 \t Train Err: 1.9947\n",
"Epoch 933/1000 \t Train Err: 1.9974\n",
"Epoch 934/1000 \t Train Err: 2.0159\n",
"Epoch 935/1000 \t Train Err: 2.0433\n",
"Epoch 936/1000 \t Train Err: 2.0755\n",
"Epoch 937/1000 \t Train Err: 2.0014\n",
"Epoch 938/1000 \t Train Err: 2.0443\n",
"Epoch 939/1000 \t Train Err: 2.0184\n",
"Epoch 940/1000 \t Train Err: 2.0192\n",
"Epoch 941/1000 \t Train Err: 2.0248\n",
"Epoch 942/1000 \t Train Err: 2.0124\n",
"Epoch 943/1000 \t Train Err: 2.0101\n",
"Epoch 944/1000 \t Train Err: 2.0024\n",
"Epoch 945/1000 \t Train Err: 2.0011\n",
"Epoch 946/1000 \t Train Err: 1.9871\n",
"Epoch 947/1000 \t Train Err: 1.9816\n",
"Epoch 948/1000 \t Train Err: 1.9875\n",
"Epoch 949/1000 \t Train Err: 2.0660\n",
"Epoch 950/1000 \t Train Err: 2.0591\n",
"Epoch 951/1000 \t Train Err: 2.0214\n",
"Epoch 952/1000 \t Train Err: 2.0312\n",
"Epoch 953/1000 \t Train Err: 2.0470\n",
"Epoch 954/1000 \t Train Err: 2.0365\n",
"Epoch 955/1000 \t Train Err: 2.0143\n",
"Epoch 956/1000 \t Train Err: 2.0104\n",
"Epoch 957/1000 \t Train Err: 2.0289\n",
"Epoch 958/1000 \t Train Err: 2.0097\n",
"Epoch 959/1000 \t Train Err: 1.9998\n",
"Epoch 960/1000 \t Train Err: 2.0095\n",
"Epoch 961/1000 \t Train Err: 2.0110\n",
"Epoch 962/1000 \t Train Err: 2.0009\n",
"Epoch 963/1000 \t Train Err: 1.9930\n",
"Epoch 964/1000 \t Train Err: 2.0003\n",
"Epoch 965/1000 \t Train Err: 1.9912\n",
"Epoch 966/1000 \t Train Err: 1.9859\n",
"Epoch 967/1000 \t Train Err: 1.9843\n",
"Epoch 968/1000 \t Train Err: 1.9828\n",
"Epoch 969/1000 \t Train Err: 1.9776\n",
"Epoch 970/1000 \t Train Err: 1.9790\n",
"Epoch 971/1000 \t Train Err: 1.9697\n",
"Epoch 972/1000 \t Train Err: 1.9671\n",
"Epoch 973/1000 \t Train Err: 1.9673\n",
"Epoch 974/1000 \t Train Err: 1.9585\n",
"Epoch 975/1000 \t Train Err: 1.9605\n",
"Epoch 976/1000 \t Train Err: 1.9537\n",
"Epoch 977/1000 \t Train Err: 1.9529\n",
"Epoch 978/1000 \t Train Err: 1.9477\n",
"Epoch 979/1000 \t Train Err: 1.9485\n",
"Epoch 980/1000 \t Train Err: 1.9376\n",
"Epoch 981/1000 \t Train Err: 1.9426\n",
"Epoch 982/1000 \t Train Err: 1.9416\n",
"Epoch 983/1000 \t Train Err: 1.9334\n",
"Epoch 984/1000 \t Train Err: 1.9249\n",
"Epoch 985/1000 \t Train Err: 1.9216\n",
"Epoch 986/1000 \t Train Err: 1.9268\n",
"Epoch 987/1000 \t Train Err: 1.9630\n",
"Epoch 988/1000 \t Train Err: 2.0237\n",
"Epoch 989/1000 \t Train Err: 2.0037\n",
"Epoch 990/1000 \t Train Err: 1.9824\n",
"Epoch 991/1000 \t Train Err: 1.9718\n",
"Epoch 992/1000 \t Train Err: 1.9726\n",
"Epoch 993/1000 \t Train Err: 1.9536\n",
"Epoch 994/1000 \t Train Err: 1.9662\n",
"Epoch 995/1000 \t Train Err: 1.9492\n",
"Epoch 996/1000 \t Train Err: 1.9482\n",
"Epoch 997/1000 \t Train Err: 1.9375\n",
"Epoch 998/1000 \t Train Err: 1.9492\n",
"Epoch 999/1000 \t Train Err: 1.9351\n"
]
}
],
"source": [
"model.train()\n",
"for epoch in range(NEPOCHS):\n",
" optimizer.zero_grad()\n",
" data, labels = mkbatch_rtl(BSZ)\n",
" # shift labels to prevent cheating\n",
" shifted_labels = torch.roll(labels, 1, dims=1)\n",
" shifted_labels[:, 0] = VOCAB_SIZE # start token\n",
" outputs = model(data, shifted_labels).permute((0, 2, 1))\n",
" loss = criterion(outputs, labels)\n",
" train_loss = loss.item()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" train_err.append(train_loss)\n",
"\n",
" with open('loss', 'a') as f:\n",
" f.write(f\"{train_loss}\\n\")\n",
" print(f\"Epoch {epoch}/{NEPOCHS} \\t Train Err: {train_loss:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 96,
"execution_state": "idle",
"id": "a3c41150-4541-4722-83a7-e7ad937f6c4f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[3, 8, 4, 0]], device='cuda:0') tensor([[7, 8]], device='cuda:0')\n",
"tensor([-4.4248e+00, -1.0567e+00, 1.2971e+00, -2.0221e+00, -6.6597e-01,\n",
" -2.6027e+00, -1.5254e-02, 8.1894e+00, -1.6939e-03, -1.2252e+00],\n",
" device='cuda:0')\n",
"tensor([-3.7663, -1.7898, -1.4273, 1.9667, -2.3513, -4.7138, -2.2421, 3.6817,\n",
" 8.9049, 3.1622], device='cuda:0')\n",
"tensor([[7, 8]], device='cuda:0', dtype=torch.int32) tensor([[7, 8]], device='cuda:0')\n"
]
}
],
"source": [
"model.eval()\n",
"data, labels = mkbatch_rtl(1)\n",
"print(data, labels)\n",
"with torch.no_grad():\n",
" ans = torch.zeros((1, NUM_LEN), dtype=torch.int, device=device)\n",
" ans[0, 0] = VOCAB_SIZE\n",
" for i in range(NUM_LEN):\n",
" outputs = model(data, ans)\n",
" print(outputs[0, i])\n",
" # break\n",
" ans[0, (i + 1) % NUM_LEN] = torch.argmax(outputs[0, i])\n",
"ans = torch.roll(ans, -1, dims=1)\n",
"print(ans, labels)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"execution_state": "idle",
"id": "1843b944-bab5-40ee-b26e-5d3b87ea9454",
"metadata": {},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'add-ltr-loss'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[32], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmath\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43madd-ltr-loss\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 5\u001b[0m plt\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;28mrange\u001b[39m(NEPOCHS), \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mmap\u001b[39m(\u001b[38;5;28;01mlambda\u001b[39;00m x: math\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;28mfloat\u001b[39m(x)), f\u001b[38;5;241m.\u001b[39mreadlines())))\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124madd-rtl-loss\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n",
"File \u001b[0;32m~/.venv/lib64/python3.12/site-packages/IPython/core/interactiveshell.py:324\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m}:\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 319\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIPython won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m by default \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124myou can use builtins\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m open.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 322\u001b[0m )\n\u001b[0;32m--> 324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'add-ltr-loss'"
]
}
],
"source": [
"import math\n",
"import matplotlib.pyplot as plt\n",
"\n",
"with open(\"add-ltr-loss\") as f:\n",
" plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))\n",
"with open(\"add-rtl-loss\") as f:\n",
" plt.plot(range(NEPOCHS), list(map(lambda x: math.log(float(x)), f.readlines())))\n",
"plt.show()"
]
  }
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}