aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Wang2025-05-11 23:28:57 -0400
committerAnthony Wang2025-05-11 23:28:57 -0400
commitb5c372bfd4d236d6249fba0cd4fc55956b4c345e (patch)
tree7a2133b7a654a6feefc808def5d00bb498bb429f
parent90b95fa71a56ff732fbe2c01b5f5a43a7c031a0c (diff)
Add top-scoring submissionHEADmaster
-rw-r--r--src/main.rs.best653
1 files changed, 653 insertions, 0 deletions
diff --git a/src/main.rs.best b/src/main.rs.best
new file mode 100644
index 0000000..e491fdd
--- /dev/null
+++ b/src/main.rs.best
@@ -0,0 +1,653 @@
+use rand::Rng;
+use rand::distr::Distribution;
+use rand::distr::weighted::WeightedIndex;
+use rand_xoshiro::Xoshiro256PlusPlus;
+use rand_xoshiro::rand_core::SeedableRng;
+use std::cmp::Ordering;
+use std::collections::{HashSet, VecDeque};
+use std::env;
+use std::fs::File;
+use std::io::{BufReader, prelude::*};
+use std::thread;
+use tch::nn::{Module, OptimizerConfig};
+use tch::{Device, IndexOp, Kind, NewAxis, Reduction, Tensor, nn, no_grad};
+
// Hyperparameters for the character-level CNN and the MCMC search.
const LR: f64 = 5e-4; // Adam learning rate (training mode)
const WD: f64 = 1e-4; // Adam weight decay
const BS: usize = 4096; // training batch size
const EPOCHS: usize = 20; // training epochs over the wikitext dump
const SEQ_LEN: usize = 256; // training sequence length, before padding
const EMBED: i64 = 12; // character embedding width
const HIDDEN: i64 = 512; // hidden width of the conv/MLP layers
const KERNEL: i64 = 16; // conv kernel height; also the per-side space padding
const WINDOW: usize = 30; // sliding-window length for breakpoint detection
const ITERS: usize = 15000; // Iters for length 100
const THREADS: usize = 4; // worker threads for the hard-case MCMC search
const VOCAB: usize = 28; // 26 letters + space (label 26) + '.' (label 27)
// Sorted from lowest to highest freq
const VOCAB_ORDER: [u8; VOCAB] = [
    25, 16, 9, 23, 10, 21, 27, 1, 15, 6, 24, 5, 2, 22, 12, 20, 11, 3, 17, 18, 7, 8, 13, 14, 0, 19,
    4, 26,
];
+
+#[rustfmt::skip]
+const TRANS: [[f32; VOCAB]; VOCAB] = [
+ [-7.591_938_5, -2.892_999_6, -2.175_499, -3.498_550_7, -2.996_972_6, -2.549_879_6, -2.822_673_3, -1.661_220_1, -4.226_744_7, -1.847_851_5, -5.350_03, -2.655_289_2, -2.019_139, -4.166_14, -5.359_595_3, -2.123_517_8, -11.513_105, -2.907_192, -3.031_055_5, -3.436_183, -3.792_863_6, -2.951_992_8, -1.479_117_6, -2.626_861_3, -5.394_068, -0.961_980_4, -2.162_345, -20.723_267],
+ [-3.829_590_8, -5.279_662, -9.647_869, -7.151_901_2, -6.973_281, -8.775_774, -6.730_87, -6.306_915_3, -5.006_757_7, -11.513_105, -6.914_897_4, -7.158_250_3, -4.238_712, -7.696_574, -4.852_602, -6.468_646, -11.513_105, -6.261_517_5, -6.552_708_6, -8.077_111, -4.138_724_3, -11.513_105, -4.975_327_5, -8.290_131, -5.404_796_6, -11.513_105, -3.067_354, -20.723_267],
+ [-3.567_029, -10.307_745, -3.879_610_8, -8.355_694, -3.945_243_6, -7.539_704, -8.209_732, -9.787_797, -3.231_729, -11.513_105, -7.435_984, -6.633_628_4, -8.991_505, -3.292_566_8, -4.787_024, -10.329_214, -11.513_105, -4.649_024, -4.578_408_2, -5.882_728_6, -2.919_008_5, -11.513_105, -7.183_579, -1.746_624_2, -7.366_245_7, -7.672_502_5, -3.292_724_4, -20.723_267],
+ [-2.840_364_5, -7.370_669, -9.647_869, -4.581_870_6, -2.539_729, -7.668_403_6, -7.623_522, -8.273_909, -3.187_738, -11.513_105, -8.247_089, -2.392_907_1, -8.366_328, -1.767_054_9, -4.048_143, -9.203_718, -11.513_105, -3.572_408_7, -7.468_721, -9.571_307, -4.192_383_3, -11.513_105, -6.453_733, -8.944_214, -6.014_779, -11.513_105, -3.493_481, -20.723_267],
+ [-8.789_968_5, -0.882_509_6, -1.761_314_5, -2.229_609, -3.460_188_2, -2.379_223_8, -2.100_699, -0.839_311_1, -3.267_797_2, -1.279_895, -1.137_539_4, -1.836_477_3, -1.454_008, -2.393_423_3, -5.954_031_5, -1.575_469_5, -11.513_105, -1.506_992_6, -2.176_778_3, -2.483_203, -3.771_802, -0.265_394_7, -1.942_219, -2.516_527_7, -3.196_823, -1.083_491_7, -3.787_814_6, -20.723_267],
+ [-4.656_895_6, -10.307_745, -11.513_105, -6.434_101, -4.661_793, -3.189_781, -8.278_472, -7.476_047, -3.797_279_8, -11.513_105, -4.357_318, -3.576_648_7, -5.378_657, -5.115_966_3, -2.389_812_7, -8.831_921, -11.513_105, -4.785_914_4, -5.850_82, -6.378_231_5, -5.453_339, -11.513_105, -5.192_757, -5.090_478, -6.588_798_5, -11.513_105, -3.367_658, -20.723_267],
+ [-3.935_778_6, -11.513_105, -11.513_105, -5.198_305, -5.182_173_3, -9.076_6, -4.308_548, -9.307_589, -3.593_207_8, -11.513_105, -7.141_876_7, -6.870_805_7, -10.692_42, -2.028_623, -5.771_454, -9.323_16, -11.513_105, -5.052_397_3, -6.821_178, -9.731_103, -2.909_238_3, -11.513_105, -9.689_901, -8.944_214, -8.006_65, -11.513_105, -4.130_123, -20.723_267],
+ [-6.750_226, -8.066_504, -1.719_597, -6.110_409_7, -6.189_956_7, -7.899_103_6, -1.912_184_1, -9.160_122, -9.226_412, -11.513_105, -5.135_840_4, -8.518_253, -8.545_525, -6.787_862, -5.718_264_6, -4.609_116, -11.513_105, -5.813_02, -2.525_067_6, -1.224_304_8, -10.253_361, -11.513_105, -1.766_828_5, -4.545_65, -7.245_962_6, -7.672_502_5, -2.483_950_6, -20.723_267],
+ [-3.126_791_7, -3.799_575, -3.186_328_2, -2.768_115_3, -4.389_353, -2.612_637, -3.096_265_6, -2.015_972_1, -8.032_84, -3.987_442, -1.910_051_2, -2.178_937, -2.289_748, -3.666_899_2, -4.623_344, -2.913_505_8, -11.513_105, -2.678_056_2, -2.861_862, -2.780_634, -3.503_753, -2.036_225_6, -1.760_814_5, -2.264_577_6, -4.232_542, -2.710_460_4, -2.579_534_8, -20.723_267],
+ [-9.585_668, -4.257_89, -11.513_105, -8.256_346, -7.803_908_3, -11.513_105, -10.480_922, -11.057_719, -11.513_105, -11.513_105, -9.365_77, -9.140_743, -10.692_42, -6.505_830_3, -9.031_633, -10.755_326, -11.513_105, -9.805_901, -10.500_641, -10.708_67, -10.696_741, -11.513_105, -10.603_504, -11.513_105, -10.070_196, -11.513_105, -5.602_34, -20.723_267],
+ [-4.254_719_7, -11.513_105, -3.200_003, -8.096_455, -6.754_756_5, -10.923_909, -9.805_095, -10.508_677_5, -4.890_956, -11.513_105, -8.732_795, -4.552_042, -11.020_834, -4.463_423_3, -4.283_471_6, -8.257_408, -11.513_105, -5.009_569_6, -5.384_791_4, -11.325_874, -8.812_76, -11.513_105, -7.243_991, -11.513_105, -10.551_27, -11.513_105, -5.107_608, -20.723_267],
+ [-2.719_842, -2.171_247_2, -3.465_415_2, -4.541_519, -3.172_815_8, -4.154_808_5, -3.646_469_4, -6.766_412, -3.120_190_6, -11.513_105, -4.313_071_3, -1.951_761_5, -5.967_743, -4.602_898_6, -3.833_565, -2.322_928_4, -11.513_105, -4.435_82, -5.262_661, -4.106_762_4, -2.086_468_5, -11.513_105, -5.110_801_7, -6.713_080_4, -5.764_374_7, -3.582_983_7, -3.718_748_6, -20.723_267],
+ [-3.592_866_7, -6.562_377_5, -10.936_696, -5.287_308, -3.877_893, -9.623_346, -6.248_005_4, -6.303_758, -3.124_795_2, -11.513_105, -7.080_122_5, -5.256_596_6, -3.885_822_5, -7.657_765_4, -2.880_621, -7.727_230_5, -11.513_105, -4.395_987, -5.172_2, -6.633_761, -4.039_342, -11.513_105, -8.661_436, -11.513_105, -5.230_367_7, -11.513_105, -2.871_839_8, -20.723_267],
+ [-1.597_543_4, -9.295_766, -10.097_605, -5.292_990_7, -2.487_039_3, -9.313_064, -4.270_337_6, -6.744_131_6, -1.308_721_7, -11.513_105, -1.928_257_2, -5.963_909, -6.134_99, -4.110_112_7, -1.975_486_4, -9.323_16, -11.513_105, -3.969_218_3, -6.406_316_3, -6.633_761, -2.287_234_3, -10.490_972, -3.211_650_1, -11.513_105, -6.668_236, -11.513_105, -3.653_006, -20.723_267],
+ [-7.741_305_4, -2.454_400_5, -1.525_927, -2.863_657_2, -6.311_463, -1.725_130_9, -2.603_309_2, -2.540_201_4, -3.049_867, -1.419_446_5, -6.538_593, -2.566_781_3, -2.362_048_6, -2.469_979_5, -3.197_990_2, -2.080_492, -11.513_105, -2.697_022_7, -2.865_752, -2.179_993, -6.900_775, -3.221_613_4, -2.312_064_2, -4.879_478_5, -1.818_898_6, -4.474_464, -2.843_442_7, -20.723_267],
+ [-4.030_286_3, -11.513_105, -10.573_381, -9.024_024, -4.715_614_3, -7.836_251, -7.501_858, -9.787_797, -5.341_391, -11.513_105, -8.247_089, -5.940_762_5, -3.205_545_4, -7.367_222, -4.124_150_8, -2.470_337_9, -11.513_105, -5.421_519_3, -3.863_525_9, -8.861_415, -3.146_821_3, -11.513_105, -8.619_498, -1.339_567_1, -7.223_549_4, -11.513_105, -3.670_306_2, -20.723_267],
+ [-10.869_481, -11.513_105, -5.161_777_5, -10.273_109, -6.273_897, -10.923_909, -10.869_361, -10.620_307, -8.873_802, -11.513_105, -11.513_105, -10.857_676, -11.513_105, -6.483_735, -9.112_832, -10.755_326, -11.513_105, -8.188_78, -7.532_092, -10.912_28, -10.088_636, -11.513_105, -11.513_105, -4.863_750_5, -10.551_27, -11.513_105, -6.010_124_7, -20.723_267],
+ [-2.380_518_7, -2.971_973_2, -3.193_902_3, -4.207_732, -1.864_907_7, -2.681_471_8, -2.695_279_8, -4.959_060_7, -3.280_368_3, -11.513_105, -8.155_254, -6.200_581_6, -2.820_928_8, -6.244_217, -2.220_730_8, -1.956_054, -11.513_105, -3.837_347, -7.571_002_5, -3.770_41, -1.872_708_4, -6.901_649, -4.567_485, -11.513_105, -7.268_889_4, -11.513_105, -3.974_993, -20.723_267],
+ [-2.199_485_3, -4.014_287, -7.049_959, -3.797_005_2, -2.864_42, -7.275_875_6, -3.802_058_2, -6.339_048_4, -2.140_158_4, -11.513_105, -3.389_352_3, -4.379_51, -3.850_523_7, -3.391_756_5, -3.693_605_7, -3.922_179, -11.513_105, -2.862_128_3, -2.919_803_9, -4.157_688, -2.148_169, -11.513_105, -4.797_732_4, -5.466_213, -3.208_151, -11.513_105, -2.598_197_7, -20.723_267],
+ [-2.041_496_8, -4.551_41, -2.453_904_6, -7.047_818, -3.565_677, -3.377_406_8, -5.630_275, -3.469_57, -1.993_827_8, -11.513_105, -7.791_498_7, -4.063_527, -7.335_05, -2.424_767_5, -2.803_703_8, -3.327_58, -11.513_105, -3.117_431_9, -2.197_329_8, -3.673_175_6, -2.026_581_5, -11.513_105, -7.587_344, -1.843_796_1, -4.424_683_6, -6.990_154, -2.010_327_6, -20.723_267],
+ [-4.493_191_7, -2.045_081_9, -3.441_515_4, -4.625_169_8, -7.377_274, -3.366_884_7, -4.019_038_7, -4.824_732_3, -7.603_225_7, -1.204_362, -8.348_22, -4.263_171, -3.008_42, -5.452_8, -1.893_440_6, -3.538_099, -0.001_714_452_5, -4.397_697, -3.250_448_5, -4.170_668_6, -10.450_676, -6.800_828_5, -8.467_393, -4.592_125, -10.070_196, -5.612_035_3, -4.575_990_7, -20.723_267],
+ [-3.431_514_7, -8.532_299, -11.513_105, -5.830_799_6, -3.924_140_2, -10.555_642, -11.513_105, -10.889_751, -3.970_075_8, -11.513_105, -10.254_454, -5.386_309_6, -10.247_817, -5.416_301_3, -4.264_887, -11.513_105, -11.513_105, -4.975_447_7, -9.886_048, -10.912_28, -7.955_444_3, -11.513_105, -10.958_263, -5.653_776, -6.996_506, -7.672_502_5, -4.962_283_6, -20.723_267],
+ [-4.383_549, -9.175_94, -11.513_105, -6.060_079_6, -5.145_225, -8.170_727, -7.974_089_6, -6.377_324_6, -10.387_997, -11.513_105, -5.921_405_3, -4.953_966_6, -9.043_949, -6.696_472, -3.019_770_4, -7.606_258_4, -11.513_105, -5.849_935, -5.396_963_6, -4.984_56, -9.293_391, -11.513_105, -8.849_356, -7.218_435, -6.520_682_3, -4.082_706_5, -2.607_651, -20.723_267],
+ [-7.124_895, -11.513_105, -11.513_105, -11.513_105, -4.681_599, -11.513_105, -11.513_105, -11.513_105, -6.373_989, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -6.898_58, -7.228_384, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -7.596_529_5, -11.513_105, -11.513_105, -5.466_213, -10.920_747, -11.513_105, -9.556_946, -20.723_267],
+ [-3.479_500_8, -2.599_211_7, -4.462_72, -4.215_739_3, -4.337_405_7, -6.045_270_4, -6.257_39, -5.687_123, -11.281_939, -11.513_105, -5.127_308, -2.319_987_3, -2.680_487_9, -4.002_873_4, -5.709_786, -4.407_800_7, -8.677_725, -3.129_120_3, -5.966_902, -4.185_307, -7.634_954, -4.926_794_5, -7.647_650_2, -6.452_617, -8.497_671, -2.881_854_5, -4.042_910_6, -20.723_267],
+ [-7.170_823_6, -11.513_105, -11.513_105, -11.513_105, -9.303_842, -11.513_105, -10.201_862, -11.513_105, -5.639_653_7, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -9.832_733, -8.523_063, -11.513_105, -11.513_105, -10.723_01, -11.513_105, -9.001_49, -7.790_009_5, -11.513_105, -10.958_263, -11.513_105, -10.920_747, -2.501_138_4, -10.541_421, -20.723_267],
+ [-2.715_362_5, -5.991_672_5, -5.186_562_5, -0.495_729_1, -1.120_387_1, -0.994_527_76, -1.014_298_6, -2.282_205, -2.596_138_5, -11.513_105, -1.262_556_9, -2.028_100_3, -1.974_505_3, -1.473_895_5, -1.899_170_9, -2.886_320_6, -8.677_725, -1.367_780_4, -1.043_720_5, -1.285_834_2, -2.472_834_8, -7.485_031, -2.315_513_6, -2.507_613_4, -0.409_93, -4.361_229, -20.723_267, 1e-09],
+ [-7.057_105, -7.864_855_3, -6.832_732_7, -3.767_181_4, -4.200_749, -4.790_685, -4.105_086, -5.590_689, -7.180_27, -5.557_085, -4.052_763, -4.921_688_6, -4.251_131, -4.469_908_7, -6.025_704, -5.421_922, -6.783_435, -3.429_898, -3.662_474, -4.345_318, -5.380_872_7, -8.071_246, -5.128_164_3, -4.732_333_7, -3.275_528_7, -11.513_105, -20.723_267, -20.723_267],
+];
+
/// Conv2D with middle element masked
///
/// Implemented as two separate convolutions applied to offset slices of the
/// input (see `MaskedConv::forward`), so the output at a position does not
/// see that position's own character — presumably so the network predicts
/// each character from its context only (TODO confirm against training).
#[derive(Debug)]
struct MaskedConv {
    pub convl: nn::Conv2D, // convolution over one side of the masked position
    pub convr: nn::Conv2D, // convolution over the other side
}
+
+impl MaskedConv {
+ fn new(
+ vs: &nn::Path,
+ in_dim: i64,
+ out_dim: i64,
+ kernel_dim: [i64; 2],
+ config: nn::ConvConfigND<[i64; 2]>,
+ ) -> MaskedConv {
+ MaskedConv {
+ convl: nn::conv(vs, in_dim, out_dim, kernel_dim, config),
+ convr: nn::conv(vs, in_dim, out_dim, kernel_dim, config),
+ }
+ }
+}
+
impl Module for MaskedConv {
    /// Sum the two one-sided convolutions and reshape to (…, seq, features).
    ///
    /// `convl` runs on the input with the first KERNEL + 1 positions sliced
    /// off and `convr` on the input with the last KERNEL + 1 sliced off, so
    /// their aligned outputs together cover the context around each position
    /// while excluding the centre element itself.
    /// NOTE(review): assumes `xs` is the embedding output, shape
    /// (batch, seq, EMBED) — confirm against `net`.
    fn forward(&self, xs: &Tensor) -> Tensor {
        // tch-rs doesn't like negative indices in slices
        (self.convl.forward(&xs.i((.., NewAxis, KERNEL + 1.., ..)))
            + self
                .convr
                .forward(&xs.i((.., NewAxis, ..xs.size()[1] - KERNEL - 1, ..))))
        .squeeze()
        .transpose(-2, -1)
    }
}
+
/// Character-level CNN: embedding → masked convolution → 3 linear layers,
/// ending in a log-softmax over the VOCAB labels per position.
///
/// `use<>` keeps the opaque return type from capturing `vs`'s lifetime.
fn net(vs: &nn::Path) -> impl Module + use<> {
    nn::seq()
        .add(nn::embedding(vs, VOCAB as i64, EMBED, Default::default()))
        // Masked conv: each position is predicted from context only.
        .add(MaskedConv::new(
            vs,
            1,
            HIDDEN,
            [KERNEL, EMBED],
            Default::default(),
        ))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN, HIDDEN, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN, HIDDEN, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN, VOCAB as i64, Default::default()))
        .add_fn(|xs| xs.log_softmax(-1, Kind::Float))
}
+
/// Run the network on a label sequence.
///
/// Returns the squeezed per-position output tensor (log-probabilities over
/// VOCAB) and the NLL of the true labels, computed only over positions
/// KERNEL..len-KERNEL — the region where the masked conv has full context.
/// Callers that want the whole string scored pad it first (`probs_padded`).
fn probs(net: &impl Module, s: &[u8]) -> (Tensor, f64) {
    let data: Vec<_> = s.iter().map(|x| *x as i64).collect();
    // We only do inference on CPU so no need to move to device
    let xs = Tensor::from_slice2(&[data]);
    // Remove batch dim
    let ys = net.forward(&xs).squeeze();
    let loss = ys
        .nll_loss(&xs.i((0, KERNEL..s.len() as i64 - KERNEL)))
        .double_value(&[]);
    (ys, loss)
}
+
+/// Pad both sides with spaces
+fn probs_padded(net: &impl Module, s: &[u8]) -> (Tensor, f64) {
+ probs(
+ net,
+ &[&[26_u8; KERNEL as usize], s, &[26_u8; KERNEL as usize]].concat(),
+ )
+}
+
/// Map a character to its label: 'a'..'z' → 0..25, ' ' → 26, anything
/// else → 27 (the catch-all '.' class).
fn char_to_label(c: char) -> u8 {
    match c {
        'a'..='z' => c as u8 - b'a',
        ' ' => 26,
        _ => 27,
    }
}
+
/// Inverse of `char_to_label`: 0..25 → 'a'..'z', 26 → ' ', anything above → '.'.
fn label_to_char(l: u8) -> char {
    if l < 26 {
        (b'a' + l) as char
    } else if l == 26 {
        ' '
    } else {
        '.'
    }
}
+
/// Render a label sequence as readable text (0..25 → letters, 26 → space,
/// everything else → '.').
fn to_string(s: &[u8]) -> String {
    let mut out = String::with_capacity(s.len());
    for &l in s {
        out.push(match l {
            0..=25 => (b'a' + l) as char,
            26 => ' ',
            _ => '.',
        });
    }
    out
}
+
+fn permute(s: &[u8], p: &[u8; VOCAB]) -> Vec<u8> {
+ s.iter().map(|l| p[*l as usize]).collect()
+}
+
+fn logprob(
+ s: &[u8],
+ p: &[u8; VOCAB],
+ cnts: &[[i32; VOCAB]; VOCAB],
+ grams: &[[[f32; VOCAB]; VOCAB]; VOCAB],
+) -> f32 {
+ let mut lp = 0.;
+ if s.len() < 300 {
+ // Use trigrams for short and hard sequences
+ for i in 1..s.len() {
+ lp += TRANS[p[s[i] as usize] as usize][p[s[i - 1] as usize] as usize];
+ }
+ lp += TRANS[p[s[1] as usize] as usize][p[s[0] as usize] as usize];
+ for i in 2..s.len() {
+ lp += grams[p[s[i - 2] as usize] as usize][p[s[i - 1] as usize] as usize]
+ [p[s[i] as usize] as usize]
+ / 2.;
+ }
+ } else {
+ for i in 0..VOCAB {
+ for j in 0..VOCAB {
+ lp += cnts[i][j] as f32 * TRANS[p[j] as usize][p[i] as usize];
+ }
+ }
+ }
+ // unsafe {
+ // if s.len() < 300 {
+ // // Use trigrams for short and hard sequences
+ // for i in 1..s.len() {
+ // lp += TRANS
+ // .get_unchecked(*p.get_unchecked(s[i] as usize) as usize)
+ // .get_unchecked(*p.get_unchecked(s[i - 1] as usize) as usize);
+ // }
+ // lp += TRANS[p[s[1] as usize] as usize][p[s[0] as usize] as usize];
+ // for i in 2..s.len() {
+ // lp += grams
+ // .get_unchecked(*p.get_unchecked(s[i - 2] as usize) as usize)
+ // .get_unchecked(*p.get_unchecked(s[i - 1] as usize) as usize)
+ // .get_unchecked(*p.get_unchecked(s[i] as usize) as usize)
+ // / 2.;
+ // }
+ // } else {
+ // for i in 0..VOCAB {
+ // for j in 0..VOCAB {
+ // lp += cnts[i][j] as f32
+ // * TRANS
+ // .get_unchecked(p[j] as usize)
+ // .get_unchecked(p[i] as usize);
+ // }
+ // }
+ // }
+ // }
+ lp
+}
+
/// Build sampling weights over candidate letter-pair swaps for `refiner`.
///
/// `ys` is the CNN output on the padded, `p`-decoded text. For each pair of
/// plaintext labels (i, j), the average log-probability the network assigns
/// to `j` at positions currently decoded as `i` (and vice versa) is
/// accumulated; pairs the network "wishes" were swapped get exponentially
/// larger weight. Returns a flat VOCAB×VOCAB vector with only the j < i
/// lower triangle populated; `refiner` decodes `sample / VOCAB` and
/// `sample % VOCAB` back into the pair.
fn weights(s: &[u8], p: &[u8; VOCAB], ys: &Tensor) -> Vec<f64> {
    // dbg!(to_string(&permute(s, p)));
    // swaps[i][j]: total log-prob of label j at positions decoded as i;
    // cnts[i][j]: how many such positions were summed.
    let mut swaps = [[0.; VOCAB]; VOCAB];
    let mut cnts = [[0; VOCAB]; VOCAB];
    for i in 0..s.len() {
        for j in 0..VOCAB {
            // NOTE: ys is padded
            let v = ys.i((i as i64, j as i64)).double_value(&[]);
            swaps[p[s[i] as usize] as usize][j] += v;
            cnts[p[s[i] as usize] as usize][j] += 1;
        }
    }
    let mut w = vec![0.; VOCAB * VOCAB];
    // print!("   ");
    // for i in 0..VOCAB {
    //     print!("{:5} ", p
    //         .iter()
    //         .position(|x| *x as usize == i)
    //         .unwrap());
    // }
    // println!();
    for i in 0..VOCAB {
        // print!("{:2} ", p
        //     .iter()
        //     .position(|x| *x as usize == i)
        //     .unwrap());
        for j in 0..VOCAB {
            if j < i {
                // Average both directions of the potential swap.
                if cnts[i][j] > 0 {
                    w[i * VOCAB + j] += swaps[i][j] / cnts[i][j] as f64;
                }
                if cnts[j][i] > 0 {
                    w[i * VOCAB + j] += swaps[j][i] / cnts[j][i] as f64;
                }
                // print!("{:>5} ", swaps[i][j] / cnts[i][j] as f64);
                // Oh no floating point comparison spooky spooky
                if w[i * VOCAB + j] != 0. {
                    // TODO: adjust this factor
                    // Temperature-style scaling before exponentiation.
                    w[i * VOCAB + j] = (w[i * VOCAB + j] / 2.5).exp();
                }
                // print!("{:5.2} ", w[i * VOCAB + j]);
            }
        }
        // println!();
    }
    w
}
+
/// Refine a permutation with Metropolis–Hastings guided by the CNN.
///
/// Starting from `p2`, repeatedly proposes swapping two letters — the pair
/// sampled from the `weights` distribution — and accepts with probability
/// min(1, exp(20 · (loss − loss2))). Returns the best permutation seen and
/// its loss. Bails out immediately when the starting loss exceeds 3, and
/// scales the iteration budget as 300 · iters / s.len().
fn refiner(net: &impl Module, s: &[u8], p2: &[u8; VOCAB], iters: usize) -> ([u8; VOCAB], f64) {
    let mut p = *p2;
    // Cleanup using CNN
    let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
    // Use padding to properly handle the ends
    let (mut ys, mut loss) = probs_padded(net, &permute(s, &p));
    // dbg!(loss, to_string(&permute(s, &p)));
    if loss > 3. {
        // Not worth trying
        return (p, loss);
    }
    let mut pbest = *p2;
    let mut lossbest = loss;
    let mut w = weights(s, &p, &ys);
    for _ in 0..300 * iters / s.len() {
        let dist = WeightedIndex::new(&w).unwrap();
        let sample = dist.sample(&mut rng);
        // Decode the flat pair index into the two positions of `p` that
        // currently hold those plaintext labels.
        let a = p
            .iter()
            .position(|x| *x as usize == sample / VOCAB)
            .unwrap();
        let b = p
            .iter()
            .position(|x| *x as usize == sample % VOCAB)
            .unwrap();
        let mut q = p;
        (q[a], q[b]) = (q[b], q[a]);
        let (ys2, loss2) = probs_padded(net, &permute(s, &q));
        let acc = f64::min(0., (loss - loss2) * 20.);
        if rng.random::<f64>() < acc.exp() {
            ys = ys2;
            p = q;
            loss = loss2;
            w = weights(s, &p, &ys);
        }
        // Track the best candidate even when the proposal is rejected.
        if loss2 < lossbest {
            lossbest = loss2;
            pbest = q;
        }
    }
    // dbg!(lossbest, to_string(&permute(s, &pbest)));
    (pbest, lossbest)
}
+
/// Recover the substitution permutation for ciphertext `s`.
///
/// Strategy: seed a permutation by matching ciphertext label frequencies to
/// `VOCAB_ORDER`, run a few cheap n-gram MCMC chains and try to finish with
/// the CNN `refiner`; if no chain reaches a low loss, fall back to a large
/// multi-threaded MCMC search and refine the best distinct candidates.
/// Returns the permutation and its CNN loss.
fn decode(
    net: &impl Module,
    s: &[u8],
    grams: &[[[f32; VOCAB]; VOCAB]; VOCAB],
) -> ([u8; VOCAB], f64) {
    // Initialize using naive freqs
    let mut porig = [0; VOCAB];
    let mut cnt = [0; VOCAB];
    for c in s {
        cnt[*c as usize] += 1;
    }
    // Rank ciphertext labels by frequency and pair them with the known
    // plaintext frequency order.
    let mut indices: Vec<_> = (0..VOCAB).collect();
    indices.sort_unstable_by_key(|&a| cnt[a]);
    for i in 0..VOCAB {
        porig[indices[i]] = VOCAB_ORDER[i];
    }
    // MCMC for some iters
    // Bigram counts of the raw ciphertext, reused by every logprob call.
    let mut cnts = [[0; VOCAB]; VOCAB];
    for i in 1..s.len() {
        cnts[s[i - 1] as usize][s[i] as usize] += 1;
    }
    // Is this easy to decode?
    let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
    for _ in 0..20 {
        let mut p = porig;
        let mut lp = logprob(s, &p, &cnts, grams);
        for _ in 0..10000 {
            let a = rng.random_range(0..VOCAB);
            let b = rng.random_range(0..VOCAB);
            if a != b {
                let mut q = p;
                (q[a], q[b]) = (q[b], q[a]);
                let lp2 = logprob(s, &q, &cnts, grams);
                // Metropolis acceptance on the n-gram score.
                let acc = f32::min(0., lp2 - lp);
                if rng.random::<f32>() < acc.exp() {
                    p = q;
                    lp = lp2;
                }
            }
        }
        let (p2, loss) = refiner(net, s, &p, 100);
        // dbg!(to_string(&permute(s, &p2)), loss);
        if loss < 0.7 {
            // Probably done
            return refiner(net, s, &p2, 1000);
        }
    }
    // Nope we're gonna have to try hard ugh
    // FEARLESS CONCURRENCY
    // Round the budget down to a multiple of THREADS so the split is even.
    let real_iters = ITERS * 100 / s.len() / THREADS * THREADS;
    let mut res: Vec<([u8; VOCAB], f32)> = vec![];
    thread::scope(|sc| {
        let mut handles = vec![];
        for _ in 0..THREADS {
            let handle = sc.spawn(move || {
                let mut tmp = vec![];
                let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng());
                for _k in 0..real_iters / THREADS {
                    let mut p = porig;
                    let mut lp = logprob(s, &p, &cnts, grams);
                    let mut pbest = p;
                    let mut lpbest = lp;
                    // TODO: investigate if we need more/less iters here
                    for _ in 0..10000 {
                        let a = rng.random_range(0..VOCAB);
                        let b = rng.random_range(0..VOCAB);
                        if a != b {
                            let mut q = p;
                            (q[a], q[b]) = (q[b], q[a]);
                            // We *could* reuse some of the computation from lp
                            // But logprob is already a tight nested loop so it's not much faster
                            // And floating point garbage builds up
                            let lp2 = logprob(s, &q, &cnts, grams);
                            let acc = f32::min(0., lp2 - lp);
                            if rng.random::<f32>() < acc.exp() {
                                p = q;
                                lp = lp2;
                                if lp2 > lpbest {
                                    lpbest = lp2;
                                    pbest = q;
                                }
                            }
                        }
                    }
                    // dbg!(_k, to_string(&permute(s, &pbest)), lpbest);
                    tmp.push((pbest, lpbest));
                }
                tmp
            });
            handles.push(handle);
        }
        for handle in handles {
            res.extend(&handle.join().unwrap());
        }
    });
    // Rank all candidates by n-gram score, best first.
    let mut ord = (0..real_iters).collect::<Vec<_>>();
    ord.sort_unstable_by(|&i, &j| res[j].1.partial_cmp(&res[i].1).unwrap());
    // Try 50 best perms
    // NOTE(review): the comment says 50 but the loop bound below is 40.
    let mut pbest = porig;
    let mut lossbest = 100.;
    // Don't retry perms that decode to the same thing
    let mut tried = HashSet::new();
    let mut i = 0;
    // NOTE(review): `i` is not bounded by res.len(); if fewer than 40
    // distinct decodings exist this indexes out of range — confirm intended.
    while tried.len() < 40 {
        let p2 = res[ord[i]].0;
        let s2 = permute(s, &p2);
        if !tried.contains(&s2) {
            tried.insert(s2);
            let (p, loss) = refiner(net, s, &p2, 50);
            if loss < lossbest {
                pbest = p;
                lossbest = loss;
            }
            if lossbest < 0.7 {
                // Probably done
                break;
            }
        }
        i += 1;
    }
    refiner(net, s, &pbest, 1000)
}
+
+fn finish(
+ net: &impl Module,
+ text: &[u8],
+ pl: &[u8; 28],
+ pr: &[u8; 28],
+ jbest: usize,
+) -> (Vec<u8>, f64) {
+ let pl2 = &refiner(net, &text[..jbest], pl, 1000).0;
+ let pr2 = &refiner(net, &text[jbest..], pr, 1000).0;
+ let ans = [permute(&text[..jbest], pl2), permute(&text[jbest..], pr2)].concat();
+ (ans.clone(), probs(net, &ans).1)
+}
+
/// Entry point with two modes:
/// - exactly two CLI args (`ciphertext`, `has_breakpoint`): load `trigrams`
///   and `model.safetensors` from the working directory, decode, and print
///   the plaintext;
/// - any other arity: train the CNN on `wikitext` and save
///   `model.safetensors`.
pub fn main() {
    let device = Device::cuda_if_available();
    // dbg!(device);
    let mut vs = nn::VarStore::new(device);
    let net = net(&vs.root());
    let args: Vec<String> = env::args().collect();
    // dbg!(args.len());
    if args.len() == 3 {
        no_grad(|| {
            // Load trigrams
            // One float per line, in nested (i, j, k) order.
            let file = File::open("trigrams").unwrap();
            let mut reader = BufReader::new(file);
            let mut grams = [[[0.; VOCAB]; VOCAB]; VOCAB];
            for i in 0..VOCAB {
                for j in 0..VOCAB {
                    for k in 0..VOCAB {
                        let mut buf = String::new();
                        reader.read_line(&mut buf).unwrap();
                        buf.pop(); // Remove newline
                        grams[i][j][k] = buf.parse().unwrap();
                    }
                }
            }
            // Decode
            vs.load("model.safetensors").unwrap();
            let text: Vec<_> = args[1].chars().map(char_to_label).collect();
            if args[2].to_lowercase() != "true" {
                // No breakpoint
                // dbg!(probs(&net, &text).1);
                // return;
                let p = decode(&net, &text, &grams).0;
                println!("{}", to_string(&permute(&text, &p)));
            } else {
                // Decode left and right halves
                let m = text.len() / 2;
                let (pl, lossl) = decode(&net, &text[..m], &grams);
                let (pr, lossr) = decode(&net, &text[m..], &grams);
                // dbg!(to_string(&permute(&text[..m], &pl)));
                // dbg!(to_string(&permute(&text[m..], &pr)));
                // dbg!(lossl, lossr);
                let mut reallossbest = 100.;
                let mut ans = vec![];
                // let mut should_loop = true;
                // This is spaghetti code yeah I know
                if lossl < 1.5 {
                    // First half correct
                    // Slide a WINDOW-long window of per-character NLL (under
                    // the left permutation) rightward; a sustained spike
                    // marks where that permutation stops fitting, i.e. the
                    // breakpoint region.
                    let mut deque = VecDeque::new();
                    let mut sum = 0.;
                    let ys = probs(&net, &permute(&text[m - WINDOW..], &pl)).0;
                    for i in m - WINDOW + KERNEL as usize..text.len() - KERNEL as usize {
                        deque.push_back(
                            -ys.i((
                                (i + WINDOW - m) as i64 - KERNEL,
                                pl[text[i] as usize] as i64,
                            ))
                            .double_value(&[]),
                        );
                        sum += deque.back().unwrap();
                        if deque.len() > WINDOW {
                            sum -= deque.front().unwrap();
                            deque.pop_front();
                        }
                        // dbg!(i, sum / WINDOW as f64);
                        if deque.len() == WINDOW && sum / WINDOW as f64 > 4. {
                            // Breakpoint probably in i - WINDOW to i - 10
                            let pr2 = decode(&net, &text[i - 10..], &grams).0;
                            let mut jbest = 0;
                            let mut lossbest = 100.;
                            // Sweep the candidate breakpoint j across the
                            // suspect region, decoding left of j with pl and
                            // right with pr2, and keep the lowest-loss split.
                            let mut pt = permute(
                                &text[i - WINDOW - KERNEL as usize..i + KERNEL as usize],
                                &pr2,
                            );
                            for j in i - WINDOW..i {
                                pt[j + WINDOW - i + KERNEL as usize] = pl[text[j] as usize];
                                let loss = probs(&net, &pt).1;
                                // dbg!(j, to_string(&pt), loss);
                                if loss < lossbest {
                                    jbest = j + 1; // Off-by-1 error
                                    lossbest = loss;
                                }
                            }
                            // dbg!(jbest, i);
                            (ans, reallossbest) = finish(&net, &text, &pl, &pr2, jbest);
                            // should_loop = false;
                            break;
                        }
                    }
                    if reallossbest == 100. {
                        // No breakpoint found???
                        ans = permute(&text, &refiner(&net, &text, &pl, 1000).0);
                        reallossbest = probs(&net, &ans).1;
                    }
                }
                if lossl >= 1.5 || lossr < 1.5 {
                    // Second half correct
                    // Same alg, just reversed
                    let mut deque = VecDeque::new();
                    let mut sum = 0.;
                    let ys = probs(&net, &permute(&text[..m + WINDOW], &pr)).0;
                    // TODO: Are these indices correct???
                    for i in (KERNEL as usize..m + WINDOW - KERNEL as usize).rev() {
                        deque.push_back(
                            -ys.i((i as i64 - KERNEL, pr[text[i] as usize] as i64))
                                .double_value(&[]),
                        );
                        sum += deque.back().unwrap();
                        if deque.len() > WINDOW {
                            sum -= deque.front().unwrap();
                            deque.pop_front();
                        }
                        if deque.len() == WINDOW && sum / WINDOW as f64 > 4. {
                            // Breakpoint probably in i + 10 to i + WINDOW
                            let pl2 = decode(&net, &text[..i + 10], &grams).0;
                            let mut jbest = 0;
                            let mut lossbest = 100.;
                            let mut pt = permute(
                                &text[i - KERNEL as usize..i + WINDOW + KERNEL as usize],
                                &pr,
                            );
                            for j in i..i + WINDOW {
                                pt[j + KERNEL as usize - i] = pl2[text[j] as usize];
                                let loss = probs(&net, &pt).1;
                                // dbg!(j, to_string(&pt), loss);
                                if loss < lossbest {
                                    jbest = j + 1; // Off-by-1 error
                                    lossbest = loss;
                                }
                            }
                            // dbg!(jbest, i);
                            let (ans2, loss2) = finish(&net, &text, &pl2, &pr, jbest);
                            if loss2 < reallossbest {
                                ans = ans2;
                                reallossbest = loss2;
                            }
                            // should_loop = false;
                            break;
                        }
                    }
                    if reallossbest == 100. {
                        // No breakpoint found???
                        ans = permute(&text, &refiner(&net, &text, &pr, 1000).0);
                        // reallossbest = probs(&net, &ans).1;
                    }
                }
                // if should_loop {
                //     // failed to find breakpoint
                //     sleep(time::Duration::from_secs(1000000));
                // }
                println!("{}", to_string(&ans));
            }
        })
    } else {
        // vs.load("model.safetensors").unwrap();
        // Train model
        let mut opt = nn::Adam::default().build(&vs, LR).unwrap();
        opt.set_weight_decay(WD);
        let file = File::open("wikitext").unwrap();
        let reader = BufReader::new(file);
        let mut data = vec![];
        for line in reader.lines() {
            // net doesn't accept u8 as input, must be i64
            let linedata: Vec<_> = line
                .unwrap()
                .chars()
                .map(|c| char_to_label(c) as i64)
                .collect();
            for i in 0..linedata.len() / SEQ_LEN {
                // Pad with spaces
                // I should probably use a dedicated padding token but I don't want to retrain
                data.push(
                    [
                        &[26; KERNEL as usize],
                        &linedata[i * SEQ_LEN..(i + 1) * SEQ_LEN],
                        &[26; KERNEL as usize],
                    ]
                    .concat(),
                );
            }
        }
        for i in 0..EPOCHS {
            for j in 0..data.len() / BS {
                let xs = Tensor::from_slice2(&data[j * BS..(j + 1) * BS]).to(device);
                let ys = net.forward(&xs).transpose(-2, -1);
                // NLL against the unpadded middle of each sequence.
                let loss = ys.nll_loss_nd::<Tensor>(
                    &xs.i((.., KERNEL..SEQ_LEN as i64 + KERNEL)),
                    None,
                    Reduction::Mean,
                    -100,
                );
                println!("{i} {j} {loss}");
                opt.backward_step(&loss);
            }
        }
        vs.save("model.safetensors").unwrap();
        // vs.freeze();
        // let mut closure = |input: &[Tensor]| vec![net.forward(&input[0])];
        // let model = CModule::create_by_tracing(
        //     "MyModule",
        //     "forward",
        //     &[Tensor::zeros([1, 784], (tch::Kind::Int64, device))],
        //     &mut closure,
        // )
        // .unwrap();
        // // I think this has the input size hardcoded though sad
        // model.save("model.pt").unwrap();
    }
}