diff options
-rw-r--r-- | src/main.rs.best | 653 |
1 file changed, 653 insertions, 0 deletions
// Substitution-cipher solver: a character-level CNN language model (tch)
// combined with MCMC over letter permutations. This block: dependencies,
// hyperparameters, the bigram transition table, and the masked convolution.

use rand::Rng;
use rand::distr::Distribution;
use rand::distr::weighted::WeightedIndex;
use rand_xoshiro::Xoshiro256PlusPlus;
use rand_xoshiro::rand_core::SeedableRng;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::env;
use std::fs::File;
use std::io::{BufReader, prelude::*};
use std::thread;
use tch::nn::{Module, OptimizerConfig};
use tch::{Device, IndexOp, Kind, NewAxis, Reduction, Tensor, nn, no_grad};

// Training hyperparameters.
const LR: f64 = 5e-4; // Adam learning rate
const WD: f64 = 1e-4; // weight decay applied via the optimizer
const BS: usize = 4096; // training batch size
const EPOCHS: usize = 20;
const SEQ_LEN: usize = 256; // characters per training sample (before padding)
// Model dimensions.
const EMBED: i64 = 12; // character embedding width
const HIDDEN: i64 = 512; // hidden layer width
const KERNEL: i64 = 16; // one-sided context length of the masked convolution
// Decoding parameters.
const WINDOW: usize = 30; // sliding-window length used for breakpoint detection
const ITERS: usize = 15000; // Iters for length 100
const THREADS: usize = 4; // parallel MCMC chains in `decode`
const VOCAB: usize = 28; // labels: 'a'..='z' -> 0..=25, ' ' -> 26, other -> 27
// Sorted from lowest to highest freq — used to seed the initial permutation
// from observed character counts in `decode`.
const VOCAB_ORDER: [u8; VOCAB] = [
    25, 16, 9, 23, 10, 21, 27, 1, 15, 6, 24, 5, 2, 22, 12, 20, 11, 3, 17, 18, 7, 8, 13, 14, 0, 19,
    4, 26,
];

// Bigram transition log-probabilities, indexed [current][previous] (see the
// accesses in `logprob`: TRANS[p[s[i]]][p[s[i-1]]]).
// NOTE(review): -11.513_105 ≈ ln(1e-5) and -20.723_267 ≈ ln(1e-9) look like
// smoothing floors for unseen transitions — confirm against the table
// generator. TRANS[26][27] = 1e-09 (≈ log 1): a space after '.' is treated as
// near-certain.
#[rustfmt::skip]
const TRANS: [[f32; VOCAB]; VOCAB] = [
    [-7.591_938_5, -2.892_999_6, -2.175_499, -3.498_550_7, -2.996_972_6, -2.549_879_6, -2.822_673_3, -1.661_220_1, -4.226_744_7, -1.847_851_5, -5.350_03, -2.655_289_2, -2.019_139, -4.166_14, -5.359_595_3, -2.123_517_8, -11.513_105, -2.907_192, -3.031_055_5, -3.436_183, -3.792_863_6, -2.951_992_8, -1.479_117_6, -2.626_861_3, -5.394_068, -0.961_980_4, -2.162_345, -20.723_267],
    [-3.829_590_8, -5.279_662, -9.647_869, -7.151_901_2, -6.973_281, -8.775_774, -6.730_87, -6.306_915_3, -5.006_757_7, -11.513_105, -6.914_897_4, -7.158_250_3, -4.238_712, -7.696_574, -4.852_602, -6.468_646, -11.513_105, -6.261_517_5, -6.552_708_6, -8.077_111, -4.138_724_3, -11.513_105, -4.975_327_5, -8.290_131, -5.404_796_6, -11.513_105, -3.067_354, -20.723_267],
    [-3.567_029, -10.307_745, -3.879_610_8, -8.355_694, -3.945_243_6, -7.539_704, -8.209_732, -9.787_797, -3.231_729, -11.513_105, -7.435_984, -6.633_628_4, -8.991_505, -3.292_566_8, -4.787_024, -10.329_214, -11.513_105, -4.649_024, -4.578_408_2, -5.882_728_6, -2.919_008_5, -11.513_105, -7.183_579, -1.746_624_2, -7.366_245_7, -7.672_502_5, -3.292_724_4, -20.723_267],
    [-2.840_364_5, -7.370_669, -9.647_869, -4.581_870_6, -2.539_729, -7.668_403_6, -7.623_522, -8.273_909, -3.187_738, -11.513_105, -8.247_089, -2.392_907_1, -8.366_328, -1.767_054_9, -4.048_143, -9.203_718, -11.513_105, -3.572_408_7, -7.468_721, -9.571_307, -4.192_383_3, -11.513_105, -6.453_733, -8.944_214, -6.014_779, -11.513_105, -3.493_481, -20.723_267],
    [-8.789_968_5, -0.882_509_6, -1.761_314_5, -2.229_609, -3.460_188_2, -2.379_223_8, -2.100_699, -0.839_311_1, -3.267_797_2, -1.279_895, -1.137_539_4, -1.836_477_3, -1.454_008, -2.393_423_3, -5.954_031_5, -1.575_469_5, -11.513_105, -1.506_992_6, -2.176_778_3, -2.483_203, -3.771_802, -0.265_394_7, -1.942_219, -2.516_527_7, -3.196_823, -1.083_491_7, -3.787_814_6, -20.723_267],
    [-4.656_895_6, -10.307_745, -11.513_105, -6.434_101, -4.661_793, -3.189_781, -8.278_472, -7.476_047, -3.797_279_8, -11.513_105, -4.357_318, -3.576_648_7, -5.378_657, -5.115_966_3, -2.389_812_7, -8.831_921, -11.513_105, -4.785_914_4, -5.850_82, -6.378_231_5, -5.453_339, -11.513_105, -5.192_757, -5.090_478, -6.588_798_5, -11.513_105, -3.367_658, -20.723_267],
    [-3.935_778_6, -11.513_105, -11.513_105, -5.198_305, -5.182_173_3, -9.076_6, -4.308_548, -9.307_589, -3.593_207_8, -11.513_105, -7.141_876_7, -6.870_805_7, -10.692_42, -2.028_623, -5.771_454, -9.323_16, -11.513_105, -5.052_397_3, -6.821_178, -9.731_103, -2.909_238_3, -11.513_105, -9.689_901, -8.944_214, -8.006_65, -11.513_105, -4.130_123, -20.723_267],
    [-6.750_226, -8.066_504, -1.719_597, -6.110_409_7, -6.189_956_7, -7.899_103_6, -1.912_184_1, -9.160_122, -9.226_412, -11.513_105, -5.135_840_4, -8.518_253, -8.545_525, -6.787_862, -5.718_264_6, -4.609_116, -11.513_105, -5.813_02, -2.525_067_6, -1.224_304_8, -10.253_361, -11.513_105, -1.766_828_5, -4.545_65, -7.245_962_6, -7.672_502_5, -2.483_950_6, -20.723_267],
    [-3.126_791_7, -3.799_575, -3.186_328_2, -2.768_115_3, -4.389_353, -2.612_637, -3.096_265_6, -2.015_972_1, -8.032_84, -3.987_442, -1.910_051_2, -2.178_937, -2.289_748, -3.666_899_2, -4.623_344, -2.913_505_8, -11.513_105, -2.678_056_2, -2.861_862, -2.780_634, -3.503_753, -2.036_225_6, -1.760_814_5, -2.264_577_6, -4.232_542, -2.710_460_4, -2.579_534_8, -20.723_267],
    [-9.585_668, -4.257_89, -11.513_105, -8.256_346, -7.803_908_3, -11.513_105, -10.480_922, -11.057_719, -11.513_105, -11.513_105, -9.365_77, -9.140_743, -10.692_42, -6.505_830_3, -9.031_633, -10.755_326, -11.513_105, -9.805_901, -10.500_641, -10.708_67, -10.696_741, -11.513_105, -10.603_504, -11.513_105, -10.070_196, -11.513_105, -5.602_34, -20.723_267],
    [-4.254_719_7, -11.513_105, -3.200_003, -8.096_455, -6.754_756_5, -10.923_909, -9.805_095, -10.508_677_5, -4.890_956, -11.513_105, -8.732_795, -4.552_042, -11.020_834, -4.463_423_3, -4.283_471_6, -8.257_408, -11.513_105, -5.009_569_6, -5.384_791_4, -11.325_874, -8.812_76, -11.513_105, -7.243_991, -11.513_105, -10.551_27, -11.513_105, -5.107_608, -20.723_267],
    [-2.719_842, -2.171_247_2, -3.465_415_2, -4.541_519, -3.172_815_8, -4.154_808_5, -3.646_469_4, -6.766_412, -3.120_190_6, -11.513_105, -4.313_071_3, -1.951_761_5, -5.967_743, -4.602_898_6, -3.833_565, -2.322_928_4, -11.513_105, -4.435_82, -5.262_661, -4.106_762_4, -2.086_468_5, -11.513_105, -5.110_801_7, -6.713_080_4, -5.764_374_7, -3.582_983_7, -3.718_748_6, -20.723_267],
    [-3.592_866_7, -6.562_377_5, -10.936_696, -5.287_308, -3.877_893, -9.623_346, -6.248_005_4, -6.303_758, -3.124_795_2, -11.513_105, -7.080_122_5, -5.256_596_6, -3.885_822_5, -7.657_765_4, -2.880_621, -7.727_230_5, -11.513_105, -4.395_987, -5.172_2, -6.633_761, -4.039_342, -11.513_105, -8.661_436, -11.513_105, -5.230_367_7, -11.513_105, -2.871_839_8, -20.723_267],
    [-1.597_543_4, -9.295_766, -10.097_605, -5.292_990_7, -2.487_039_3, -9.313_064, -4.270_337_6, -6.744_131_6, -1.308_721_7, -11.513_105, -1.928_257_2, -5.963_909, -6.134_99, -4.110_112_7, -1.975_486_4, -9.323_16, -11.513_105, -3.969_218_3, -6.406_316_3, -6.633_761, -2.287_234_3, -10.490_972, -3.211_650_1, -11.513_105, -6.668_236, -11.513_105, -3.653_006, -20.723_267],
    [-7.741_305_4, -2.454_400_5, -1.525_927, -2.863_657_2, -6.311_463, -1.725_130_9, -2.603_309_2, -2.540_201_4, -3.049_867, -1.419_446_5, -6.538_593, -2.566_781_3, -2.362_048_6, -2.469_979_5, -3.197_990_2, -2.080_492, -11.513_105, -2.697_022_7, -2.865_752, -2.179_993, -6.900_775, -3.221_613_4, -2.312_064_2, -4.879_478_5, -1.818_898_6, -4.474_464, -2.843_442_7, -20.723_267],
    [-4.030_286_3, -11.513_105, -10.573_381, -9.024_024, -4.715_614_3, -7.836_251, -7.501_858, -9.787_797, -5.341_391, -11.513_105, -8.247_089, -5.940_762_5, -3.205_545_4, -7.367_222, -4.124_150_8, -2.470_337_9, -11.513_105, -5.421_519_3, -3.863_525_9, -8.861_415, -3.146_821_3, -11.513_105, -8.619_498, -1.339_567_1, -7.223_549_4, -11.513_105, -3.670_306_2, -20.723_267],
    [-10.869_481, -11.513_105, -5.161_777_5, -10.273_109, -6.273_897, -10.923_909, -10.869_361, -10.620_307, -8.873_802, -11.513_105, -11.513_105, -10.857_676, -11.513_105, -6.483_735, -9.112_832, -10.755_326, -11.513_105, -8.188_78, -7.532_092, -10.912_28, -10.088_636, -11.513_105, -11.513_105, -4.863_750_5, -10.551_27, -11.513_105, -6.010_124_7, -20.723_267],
    [-2.380_518_7, -2.971_973_2, -3.193_902_3, -4.207_732, -1.864_907_7, -2.681_471_8, -2.695_279_8, -4.959_060_7, -3.280_368_3, -11.513_105, -8.155_254, -6.200_581_6, -2.820_928_8, -6.244_217, -2.220_730_8, -1.956_054, -11.513_105, -3.837_347, -7.571_002_5, -3.770_41, -1.872_708_4, -6.901_649, -4.567_485, -11.513_105, -7.268_889_4, -11.513_105, -3.974_993, -20.723_267],
    [-2.199_485_3, -4.014_287, -7.049_959, -3.797_005_2, -2.864_42, -7.275_875_6, -3.802_058_2, -6.339_048_4, -2.140_158_4, -11.513_105, -3.389_352_3, -4.379_51, -3.850_523_7, -3.391_756_5, -3.693_605_7, -3.922_179, -11.513_105, -2.862_128_3, -2.919_803_9, -4.157_688, -2.148_169, -11.513_105, -4.797_732_4, -5.466_213, -3.208_151, -11.513_105, -2.598_197_7, -20.723_267],
    [-2.041_496_8, -4.551_41, -2.453_904_6, -7.047_818, -3.565_677, -3.377_406_8, -5.630_275, -3.469_57, -1.993_827_8, -11.513_105, -7.791_498_7, -4.063_527, -7.335_05, -2.424_767_5, -2.803_703_8, -3.327_58, -11.513_105, -3.117_431_9, -2.197_329_8, -3.673_175_6, -2.026_581_5, -11.513_105, -7.587_344, -1.843_796_1, -4.424_683_6, -6.990_154, -2.010_327_6, -20.723_267],
    [-4.493_191_7, -2.045_081_9, -3.441_515_4, -4.625_169_8, -7.377_274, -3.366_884_7, -4.019_038_7, -4.824_732_3, -7.603_225_7, -1.204_362, -8.348_22, -4.263_171, -3.008_42, -5.452_8, -1.893_440_6, -3.538_099, -0.001_714_452_5, -4.397_697, -3.250_448_5, -4.170_668_6, -10.450_676, -6.800_828_5, -8.467_393, -4.592_125, -10.070_196, -5.612_035_3, -4.575_990_7, -20.723_267],
    [-3.431_514_7, -8.532_299, -11.513_105, -5.830_799_6, -3.924_140_2, -10.555_642, -11.513_105, -10.889_751, -3.970_075_8, -11.513_105, -10.254_454, -5.386_309_6, -10.247_817, -5.416_301_3, -4.264_887, -11.513_105, -11.513_105, -4.975_447_7, -9.886_048, -10.912_28, -7.955_444_3, -11.513_105, -10.958_263, -5.653_776, -6.996_506, -7.672_502_5, -4.962_283_6, -20.723_267],
    [-4.383_549, -9.175_94, -11.513_105, -6.060_079_6, -5.145_225, -8.170_727, -7.974_089_6, -6.377_324_6, -10.387_997, -11.513_105, -5.921_405_3, -4.953_966_6, -9.043_949, -6.696_472, -3.019_770_4, -7.606_258_4, -11.513_105, -5.849_935, -5.396_963_6, -4.984_56, -9.293_391, -11.513_105, -8.849_356, -7.218_435, -6.520_682_3, -4.082_706_5, -2.607_651, -20.723_267],
    [-7.124_895, -11.513_105, -11.513_105, -11.513_105, -4.681_599, -11.513_105, -11.513_105, -11.513_105, -6.373_989, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -6.898_58, -7.228_384, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -7.596_529_5, -11.513_105, -11.513_105, -5.466_213, -10.920_747, -11.513_105, -9.556_946, -20.723_267],
    [-3.479_500_8, -2.599_211_7, -4.462_72, -4.215_739_3, -4.337_405_7, -6.045_270_4, -6.257_39, -5.687_123, -11.281_939, -11.513_105, -5.127_308, -2.319_987_3, -2.680_487_9, -4.002_873_4, -5.709_786, -4.407_800_7, -8.677_725, -3.129_120_3, -5.966_902, -4.185_307, -7.634_954, -4.926_794_5, -7.647_650_2, -6.452_617, -8.497_671, -2.881_854_5, -4.042_910_6, -20.723_267],
    [-7.170_823_6, -11.513_105, -11.513_105, -11.513_105, -9.303_842, -11.513_105, -10.201_862, -11.513_105, -5.639_653_7, -11.513_105, -11.513_105, -11.513_105, -11.513_105, -9.832_733, -8.523_063, -11.513_105, -11.513_105, -10.723_01, -11.513_105, -9.001_49, -7.790_009_5, -11.513_105, -10.958_263, -11.513_105, -10.920_747, -2.501_138_4, -10.541_421, -20.723_267],
    [-2.715_362_5, -5.991_672_5, -5.186_562_5, -0.495_729_1, -1.120_387_1, -0.994_527_76, -1.014_298_6, -2.282_205, -2.596_138_5, -11.513_105, -1.262_556_9, -2.028_100_3, -1.974_505_3, -1.473_895_5, -1.899_170_9, -2.886_320_6, -8.677_725, -1.367_780_4, -1.043_720_5, -1.285_834_2, -2.472_834_8, -7.485_031, -2.315_513_6, -2.507_613_4, -0.409_93, -4.361_229, -20.723_267, 1e-09],
    [-7.057_105, -7.864_855_3, -6.832_732_7, -3.767_181_4, -4.200_749, -4.790_685, -4.105_086, -5.590_689, -7.180_27, -5.557_085, -4.052_763, -4.921_688_6, -4.251_131, -4.469_908_7, -6.025_704, -5.421_922, -6.783_435, -3.429_898, -3.662_474, -4.345_318, -5.380_872_7, -8.071_246, -5.128_164_3, -4.732_333_7, -3.275_528_7, -11.513_105, -20.723_267, -20.723_267],
];

/// Conv2D with middle element masked: each output position sees `KERNEL`
/// characters of context on either side, but — because the two convolutions
/// are applied to oppositely shifted slices of the input — never the
/// character being predicted itself.
#[derive(Debug)]
struct MaskedConv {
    pub convl: nn::Conv2D,
    pub convr: nn::Conv2D,
}

impl MaskedConv {
    /// Create the two context convolutions; both share the same shape and
    /// config, their outputs are summed in `forward`.
    fn new(
        vs: &nn::Path,
        in_dim: i64,
        out_dim: i64,
        kernel_dim: [i64; 2],
        config: nn::ConvConfigND<[i64; 2]>,
    ) -> MaskedConv {
        MaskedConv {
            convl: nn::conv(vs, in_dim, out_dim, kernel_dim, config),
            convr: nn::conv(vs, in_dim, out_dim, kernel_dim, config),
        }
    }
}

impl Module for MaskedConv {
    /// Sum of the two convolutions over shifted slices (one drops the first
    /// KERNEL + 1 positions, the other the last KERNEL + 1), then squeeze and
    /// swap the feature/position axes for the following linear layers.
    fn forward(&self, xs: &Tensor) -> Tensor {
        // tch-rs doesn't like negative indices in slices
        (self.convl.forward(&xs.i((.., NewAxis, KERNEL + 1.., ..)))
            + self
                .convr
                .forward(&xs.i((.., NewAxis, ..xs.size()[1] - KERNEL - 1, ..))))
        .squeeze()
        .transpose(-2, -1)
    }
}
forward(&self, xs: &Tensor) -> Tensor { + // tch-rs doesn't like negative indices in slices + (self.convl.forward(&xs.i((.., NewAxis, KERNEL + 1.., ..))) + + self + .convr + .forward(&xs.i((.., NewAxis, ..xs.size()[1] - KERNEL - 1, ..)))) + .squeeze() + .transpose(-2, -1) + } +} + +fn net(vs: &nn::Path) -> impl Module + use<> { + nn::seq() + .add(nn::embedding(vs, VOCAB as i64, EMBED, Default::default())) + .add(MaskedConv::new( + vs, + 1, + HIDDEN, + [KERNEL, EMBED], + Default::default(), + )) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, HIDDEN, HIDDEN, Default::default())) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, HIDDEN, HIDDEN, Default::default())) + .add_fn(|xs| xs.relu()) + .add(nn::linear(vs, HIDDEN, VOCAB as i64, Default::default())) + .add_fn(|xs| xs.log_softmax(-1, Kind::Float)) +} + +fn probs(net: &impl Module, s: &[u8]) -> (Tensor, f64) { + let data: Vec<_> = s.iter().map(|x| *x as i64).collect(); + // We only do inference on CPU so no need to move to device + let xs = Tensor::from_slice2(&[data]); + // Remove batch dim + let ys = net.forward(&xs).squeeze(); + let loss = ys + .nll_loss(&xs.i((0, KERNEL..s.len() as i64 - KERNEL))) + .double_value(&[]); + (ys, loss) +} + +/// Pad both sides with spaces +fn probs_padded(net: &impl Module, s: &[u8]) -> (Tensor, f64) { + probs( + net, + &[&[26_u8; KERNEL as usize], s, &[26_u8; KERNEL as usize]].concat(), + ) +} + +fn char_to_label(c: char) -> u8 { + if c.is_ascii_lowercase() { + c as u8 - b'a' + } else if c == ' ' { + 26 + } else { + 27 + } +} + +fn label_to_char(l: u8) -> char { + match l.cmp(&26) { + Ordering::Less => (b'a' + l) as char, + Ordering::Equal => ' ', + Ordering::Greater => '.', + } +} + +fn to_string(s: &[u8]) -> String { + s.iter().map(|l| label_to_char(*l)).collect() +} + +fn permute(s: &[u8], p: &[u8; VOCAB]) -> Vec<u8> { + s.iter().map(|l| p[*l as usize]).collect() +} + +fn logprob( + s: &[u8], + p: &[u8; VOCAB], + cnts: &[[i32; VOCAB]; VOCAB], + grams: &[[[f32; VOCAB]; VOCAB]; 
VOCAB], +) -> f32 { + let mut lp = 0.; + if s.len() < 300 { + // Use trigrams for short and hard sequences + for i in 1..s.len() { + lp += TRANS[p[s[i] as usize] as usize][p[s[i - 1] as usize] as usize]; + } + lp += TRANS[p[s[1] as usize] as usize][p[s[0] as usize] as usize]; + for i in 2..s.len() { + lp += grams[p[s[i - 2] as usize] as usize][p[s[i - 1] as usize] as usize] + [p[s[i] as usize] as usize] + / 2.; + } + } else { + for i in 0..VOCAB { + for j in 0..VOCAB { + lp += cnts[i][j] as f32 * TRANS[p[j] as usize][p[i] as usize]; + } + } + } + // unsafe { + // if s.len() < 300 { + // // Use trigrams for short and hard sequences + // for i in 1..s.len() { + // lp += TRANS + // .get_unchecked(*p.get_unchecked(s[i] as usize) as usize) + // .get_unchecked(*p.get_unchecked(s[i - 1] as usize) as usize); + // } + // lp += TRANS[p[s[1] as usize] as usize][p[s[0] as usize] as usize]; + // for i in 2..s.len() { + // lp += grams + // .get_unchecked(*p.get_unchecked(s[i - 2] as usize) as usize) + // .get_unchecked(*p.get_unchecked(s[i - 1] as usize) as usize) + // .get_unchecked(*p.get_unchecked(s[i] as usize) as usize) + // / 2.; + // } + // } else { + // for i in 0..VOCAB { + // for j in 0..VOCAB { + // lp += cnts[i][j] as f32 + // * TRANS + // .get_unchecked(p[j] as usize) + // .get_unchecked(p[i] as usize); + // } + // } + // } + // } + lp +} + +fn weights(s: &[u8], p: &[u8; VOCAB], ys: &Tensor) -> Vec<f64> { + // dbg!(to_string(&permute(s, p))); + let mut swaps = [[0.; VOCAB]; VOCAB]; + let mut cnts = [[0; VOCAB]; VOCAB]; + for i in 0..s.len() { + for j in 0..VOCAB { + // NOTE: ys is padded + let v = ys.i((i as i64, j as i64)).double_value(&[]); + swaps[p[s[i] as usize] as usize][j] += v; + cnts[p[s[i] as usize] as usize][j] += 1; + } + } + let mut w = vec![0.; VOCAB * VOCAB]; + // print!(" "); + // for i in 0..VOCAB { + // print!("{:5} ", p + // .iter() + // .position(|x| *x as usize == i) + // .unwrap()); + // } + // println!(); + for i in 0..VOCAB { + // 
print!("{:2} ", p + // .iter() + // .position(|x| *x as usize == i) + // .unwrap()); + for j in 0..VOCAB { + if j < i { + if cnts[i][j] > 0 { + w[i * VOCAB + j] += swaps[i][j] / cnts[i][j] as f64; + } + if cnts[j][i] > 0 { + w[i * VOCAB + j] += swaps[j][i] / cnts[j][i] as f64; + } + // print!("{:>5} ", swaps[i][j] / cnts[i][j] as f64); + // Oh no floating point comparison spooky spooky + if w[i * VOCAB + j] != 0. { + // TODO: adjust this factor + w[i * VOCAB + j] = (w[i * VOCAB + j] / 2.5).exp(); + } + // print!("{:5.2} ", w[i * VOCAB + j]); + } + } + // println!(); + } + w +} + +fn refiner(net: &impl Module, s: &[u8], p2: &[u8; VOCAB], iters: usize) -> ([u8; VOCAB], f64) { + let mut p = *p2; + // Cleanup using CNN + let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng()); + // Use padding to properly handle the ends + let (mut ys, mut loss) = probs_padded(net, &permute(s, &p)); + // dbg!(loss, to_string(&permute(s, &p))); + if loss > 3. { + // Not worth trying + return (p, loss); + } + let mut pbest = *p2; + let mut lossbest = loss; + let mut w = weights(s, &p, &ys); + for _ in 0..300 * iters / s.len() { + let dist = WeightedIndex::new(&w).unwrap(); + let sample = dist.sample(&mut rng); + let a = p + .iter() + .position(|x| *x as usize == sample / VOCAB) + .unwrap(); + let b = p + .iter() + .position(|x| *x as usize == sample % VOCAB) + .unwrap(); + let mut q = p; + (q[a], q[b]) = (q[b], q[a]); + let (ys2, loss2) = probs_padded(net, &permute(s, &q)); + let acc = f64::min(0., (loss - loss2) * 20.); + if rng.random::<f64>() < acc.exp() { + ys = ys2; + p = q; + loss = loss2; + w = weights(s, &p, &ys); + } + if loss2 < lossbest { + lossbest = loss2; + pbest = q; + } + } + // dbg!(lossbest, to_string(&permute(s, &pbest))); + (pbest, lossbest) +} + +fn decode( + net: &impl Module, + s: &[u8], + grams: &[[[f32; VOCAB]; VOCAB]; VOCAB], +) -> ([u8; VOCAB], f64) { + // Initialize using naive freqs + let mut porig = [0; VOCAB]; + let mut cnt = [0; VOCAB]; + for c in s { 
+ cnt[*c as usize] += 1; + } + let mut indices: Vec<_> = (0..VOCAB).collect(); + indices.sort_unstable_by_key(|&a| cnt[a]); + for i in 0..VOCAB { + porig[indices[i]] = VOCAB_ORDER[i]; + } + // MCMC for some iters + let mut cnts = [[0; VOCAB]; VOCAB]; + for i in 1..s.len() { + cnts[s[i - 1] as usize][s[i] as usize] += 1; + } + // Is this easy to decode? + let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng()); + for _ in 0..20 { + let mut p = porig; + let mut lp = logprob(s, &p, &cnts, grams); + for _ in 0..10000 { + let a = rng.random_range(0..VOCAB); + let b = rng.random_range(0..VOCAB); + if a != b { + let mut q = p; + (q[a], q[b]) = (q[b], q[a]); + let lp2 = logprob(s, &q, &cnts, grams); + let acc = f32::min(0., lp2 - lp); + if rng.random::<f32>() < acc.exp() { + p = q; + lp = lp2; + } + } + } + let (p2, loss) = refiner(net, s, &p, 100); + // dbg!(to_string(&permute(s, &p2)), loss); + if loss < 0.7 { + // Probably done + return refiner(net, s, &p2, 1000); + } + } + // Nope we're gonna have to try hard ugh + // FEARLESS CONCURRENCY + let real_iters = ITERS * 100 / s.len() / THREADS * THREADS; + let mut res: Vec<([u8; VOCAB], f32)> = vec![]; + thread::scope(|sc| { + let mut handles = vec![]; + for _ in 0..THREADS { + let handle = sc.spawn(move || { + let mut tmp = vec![]; + let mut rng = Xoshiro256PlusPlus::from_rng(&mut rand::rng()); + for _k in 0..real_iters / THREADS { + let mut p = porig; + let mut lp = logprob(s, &p, &cnts, grams); + let mut pbest = p; + let mut lpbest = lp; + // TODO: investigate if we need more/less iters here + for _ in 0..10000 { + let a = rng.random_range(0..VOCAB); + let b = rng.random_range(0..VOCAB); + if a != b { + let mut q = p; + (q[a], q[b]) = (q[b], q[a]); + // We *could* reuse some of the computation from lp + // But logprob is already a tight nested loop so it's not much faster + // And floating point garbage builds up + let lp2 = logprob(s, &q, &cnts, grams); + let acc = f32::min(0., lp2 - lp); + if rng.random::<f32>() < 
/// Split `text` at `jbest`, refine each half's permutation independently,
/// glue the decoded halves together, and return the joined plaintext with
/// its CNN loss over the whole sequence.
/// NOTE(review): the `[u8; 28]` parameter types hard-code VOCAB — consider
/// using `[u8; VOCAB]` for consistency with the rest of the file.
fn finish(
    net: &impl Module,
    text: &[u8],
    pl: &[u8; 28],
    pr: &[u8; 28],
    jbest: usize,
) -> (Vec<u8>, f64) {
    let pl2 = &refiner(net, &text[..jbest], pl, 1000).0;
    let pr2 = &refiner(net, &text[jbest..], pr, 1000).0;
    let ans = [permute(&text[..jbest], pl2), permute(&text[jbest..], pr2)].concat();
    (ans.clone(), probs(net, &ans).1)
}

/// Entry point. Two CLI arguments (`<ciphertext> <has_breakpoint>`) run
/// decoding (loading "trigrams" and "model.safetensors" from the working
/// directory); any other arity trains the model on "wikitext" and saves
/// "model.safetensors".
pub fn main() {
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let net = net(&vs.root());
    let args: Vec<String> = env::args().collect();
    if args.len() == 3 {
        no_grad(|| {
            // Load trigrams: VOCAB^3 plain-text floats, one per line, in
            // [i][j][k] order.
            let file = File::open("trigrams").unwrap();
            let mut reader = BufReader::new(file);
            let mut grams = [[[0.; VOCAB]; VOCAB]; VOCAB];
            for i in 0..VOCAB {
                for j in 0..VOCAB {
                    for k in 0..VOCAB {
                        let mut buf = String::new();
                        reader.read_line(&mut buf).unwrap();
                        buf.pop(); // Remove newline
                        grams[i][j][k] = buf.parse().unwrap();
                    }
                }
            }
            // Decode
            vs.load("model.safetensors").unwrap();
            let text: Vec<_> = args[1].chars().map(char_to_label).collect();
            if args[2].to_lowercase() != "true" {
                // No breakpoint: a single permutation covers the whole text.
                let p = decode(&net, &text, &grams).0;
                println!("{}", to_string(&permute(&text, &p)));
            } else {
                // Breakpoint mode: the cipher switches permutations somewhere.
                // Decode left and right halves independently first.
                let m = text.len() / 2;
                let (pl, lossl) = decode(&net, &text[..m], &grams);
                let (pr, lossr) = decode(&net, &text[m..], &grams);
                let mut reallossbest = 100.;
                let mut ans = vec![];
                // This is spaghetti code yeah I know
                if lossl < 1.5 {
                    // First half decoded correctly: slide a WINDOW-long sum of
                    // per-character NLL rightward under pl; a spike (> 4 avg)
                    // marks where pl stops fitting, i.e. the breakpoint region.
                    let mut deque = VecDeque::new();
                    let mut sum = 0.;
                    let ys = probs(&net, &permute(&text[m - WINDOW..], &pl)).0;
                    for i in m - WINDOW + KERNEL as usize..text.len() - KERNEL as usize {
                        deque.push_back(
                            -ys.i((
                                (i + WINDOW - m) as i64 - KERNEL,
                                pl[text[i] as usize] as i64,
                            ))
                            .double_value(&[]),
                        );
                        sum += deque.back().unwrap();
                        if deque.len() > WINDOW {
                            sum -= deque.front().unwrap();
                            deque.pop_front();
                        }
                        if deque.len() == WINDOW && sum / WINDOW as f64 > 4. {
                            // Breakpoint probably in i - WINDOW to i - 10
                            let pr2 = decode(&net, &text[i - 10..], &grams).0;
                            // Scan candidate split points: extend pl one
                            // character at a time over the window and keep
                            // the split with the lowest CNN loss.
                            let mut jbest = 0;
                            let mut lossbest = 100.;
                            let mut pt = permute(
                                &text[i - WINDOW - KERNEL as usize..i + KERNEL as usize],
                                &pr2,
                            );
                            for j in i - WINDOW..i {
                                pt[j + WINDOW - i + KERNEL as usize] = pl[text[j] as usize];
                                let loss = probs(&net, &pt).1;
                                if loss < lossbest {
                                    jbest = j + 1; // Off-by-1 error
                                    lossbest = loss;
                                }
                            }
                            (ans, reallossbest) = finish(&net, &text, &pl, &pr2, jbest);
                            break;
                        }
                    }
                    if reallossbest == 100. {
                        // No breakpoint found??? Fall back to a single
                        // permutation over the whole text.
                        ans = permute(&text, &refiner(&net, &text, &pl, 1000).0);
                        reallossbest = probs(&net, &ans).1;
                    }
                }
                if lossl >= 1.5 || lossr < 1.5 {
                    // Second half correct
                    // Same alg, just reversed: slide the window leftward under pr.
                    let mut deque = VecDeque::new();
                    let mut sum = 0.;
                    let ys = probs(&net, &permute(&text[..m + WINDOW], &pr)).0;
                    // TODO: Are these indices correct???
                    for i in (KERNEL as usize..m + WINDOW - KERNEL as usize).rev() {
                        deque.push_back(
                            -ys.i((i as i64 - KERNEL, pr[text[i] as usize] as i64))
                                .double_value(&[]),
                        );
                        sum += deque.back().unwrap();
                        if deque.len() > WINDOW {
                            sum -= deque.front().unwrap();
                            deque.pop_front();
                        }
                        if deque.len() == WINDOW && sum / WINDOW as f64 > 4. {
                            // Breakpoint probably in i + 10 to i + WINDOW
                            let pl2 = decode(&net, &text[..i + 10], &grams).0;
                            let mut jbest = 0;
                            let mut lossbest = 100.;
                            let mut pt = permute(
                                &text[i - KERNEL as usize..i + WINDOW + KERNEL as usize],
                                &pr,
                            );
                            for j in i..i + WINDOW {
                                pt[j + KERNEL as usize - i] = pl2[text[j] as usize];
                                let loss = probs(&net, &pt).1;
                                if loss < lossbest {
                                    jbest = j + 1; // Off-by-1 error
                                    lossbest = loss;
                                }
                            }
                            // Keep whichever direction's split scored better.
                            let (ans2, loss2) = finish(&net, &text, &pl2, &pr, jbest);
                            if loss2 < reallossbest {
                                ans = ans2;
                                reallossbest = loss2;
                            }
                            break;
                        }
                    }
                    if reallossbest == 100. {
                        // No breakpoint found??? Same whole-text fallback,
                        // this time seeded from the right-half permutation.
                        ans = permute(&text, &refiner(&net, &text, &pr, 1000).0);
                        // reallossbest = probs(&net, &ans).1;
                    }
                }
                println!("{}", to_string(&ans));
            }
        })
    } else {
        // Train model
        let mut opt = nn::Adam::default().build(&vs, LR).unwrap();
        opt.set_weight_decay(WD);
        let file = File::open("wikitext").unwrap();
        let reader = BufReader::new(file);
        let mut data = vec![];
        for line in reader.lines() {
            // net doesn't accept u8 as input, must be i64
            let linedata: Vec<_> = line
                .unwrap()
                .chars()
                .map(|c| char_to_label(c) as i64)
                .collect();
            for i in 0..linedata.len() / SEQ_LEN {
                // Pad with spaces
                // I should probably use a dedicated padding token but I don't want to retrain
                data.push(
                    [
                        &[26; KERNEL as usize],
                        &linedata[i * SEQ_LEN..(i + 1) * SEQ_LEN],
                        &[26; KERNEL as usize],
                    ]
                    .concat(),
                );
            }
        }
        for i in 0..EPOCHS {
            for j in 0..data.len() / BS {
                let xs = Tensor::from_slice2(&data[j * BS..(j + 1) * BS]).to(device);
                // Targets are the unpadded positions of the input itself: the
                // masked conv guarantees a position never sees its own label.
                let ys = net.forward(&xs).transpose(-2, -1);
                let loss = ys.nll_loss_nd::<Tensor>(
                    &xs.i((.., KERNEL..SEQ_LEN as i64 + KERNEL)),
                    None,
                    Reduction::Mean,
                    -100,
                );
                println!("{i} {j} {loss}");
                opt.backward_step(&loss);
            }
        }
        vs.save("model.safetensors").unwrap();
    }
}