aboutsummaryrefslogtreecommitdiff
path: root/fenwick.dfy
blob: aeee011500fe36e491c9746bf6e6520f22ad8035 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
// Proof of some functions from fenwick.c
// Proof walkthrough at:
// https://unnamed.website/posts/formally-verifying-fenwick-trees/
// https://unnamed.website/posts/actually-useful-formal-verification/

// TODO: Convert comments to docstrings
// Although Dafny has some pretty weird docstring syntax
// https://dafny.org/latest/DafnyRef/DafnyRef#sec-documentation-comments

// Dafny only supports bitwise ops on bitvectors, not ints
// But bitvectors are super gross
// Ex: https://github.com/dafny-lang/dafny/wiki/Bit-Vector-Cookbook
// So let's implement i & -i for ints
// Basically the largest power of 2 that divides i
// Defined recursively: odd i has lsb 1, even i has twice the lsb of i / 2
// opaque for perf reasons since we always use lsb_* lemmas to reason about it
// https://dafny.org/latest/VerificationOptimization/VerificationOptimization#opaque-definitions
// ghost since we use this function for proofs and lsb_fast for executed code
opaque ghost function lsb(i: int): int
    requires i > 0
    ensures 0 < lsb(i) <= i
{
    // Odd: the lowest set bit is bit 0; even: strip a factor of 2, recurse, put it back
    if i % 2 == 1 then 1 else 2 * lsb(i / 2)
}

// 32-bit unsigned integer
// Subset type of int restricted to the range [0, 0xFFFFFFFF]
// TODO: nativeType only works for newtype
// https://dafny.org/latest/DafnyRef/DafnyRef#sec-nativetype
// NOTE(review): per the TODO above, the attribute presumably has no effect on this
// subset type, so u32 values likely still compile to unbounded ints — confirm.
// Turning this into a newtype would presumably need explicit `as int` conversions
// wherever a u32 is passed to int-typed functions such as lsb.
type {:nativeType "uint"} u32 = x: int | 0 <= x <= 0xFFFFFFFF

// Faster implementation of lsb
// Uses bitwise ops, so it only accepts i that fit in 32 bits (u32)
// Computes i & -i on bv32: the two's-complement negation keeps only the lowest set bit
// This is the executable version; lsb_correct proves it agrees with the ghost lsb
opaque function lsb_fast(i: u32): u32
    requires i > 0
    ensures 0 < lsb_fast(i) <= i
{
    ((i as bv32) & -(i as bv32)) as u32
}

// Verify lsb(i) == i & -i
// Bridges the ghost spec lsb and the executable lsb_fast for every positive u32,
// so executed code can call lsb_fast while proofs keep reasoning about lsb
lemma lsb_correct(i: u32)
    requires i > 0
    ensures lsb(i) == lsb_fast(i)
{
    // Dafny can autoprove this using inlining
    // Literally just check every i
    // Both definitions are opaque, so both must be revealed for the solver to see their bodies
    reveal lsb;
    reveal lsb_fast;
}

// Crucial lemma used in several places
// i + lsb(i) is the parent of i
// The lsb of the parent is at least twice the lsb of i
// Definitely the most important property of lsb
lemma lsb_add(i: int)
    requires i > 0
    ensures 2 * lsb(i) <= lsb(i + lsb(i))
{
    reveal lsb;
    // Odd i: lsb(i) == 1 and i + 1 is even, so the claim follows from the definition
    // Even i: every quantity is twice the corresponding one for i / 2, so recurse
    if i % 2 == 0 {
        lsb_add(i / 2);
    }
}

// No intervals touch i between i and i + lsb(i)
// Formally: every j strictly between i and its parent i + lsb(i) has an interval
// A[j - lsb(j)..j] starting at or after index i, so none of them reach back to i or below
lemma lsb_next(i: int)
    requires i > 0
    ensures forall j :: i < j < i + lsb(i) ==> j - lsb(j) >= i
{
    reveal lsb;
}

// If the i and i + lsb(i) intervals start at the same place
// Then i is the first child for i + lsb(i)
// i.e. no smaller j shares the parent i + lsb(i)
lemma lsb_no_par(i: int, j: int)
    requires 0 < j < i
    requires i - lsb(i) == i + lsb(i) - lsb(i + lsb(i))
    ensures j + lsb(j) != i + lsb(i)
{
    reveal lsb;
    // lsb(j + lsb(j)) >= 2 * lsb(j) constrains where j's parent can land
    lsb_add(j);
}

// If the interval for i starts after the interval for i + lsb(i)
// Then i - lsb(i) also has parent i + lsb(i)
// In other words, i - lsb(i) is the previous sibling of i
// (the requires also forces i - lsb(i) > 0, so lsb(i - lsb(i)) is well-defined)
lemma lsb_shared_par(i: int)
    requires i > 0
    requires i - lsb(i) > i + lsb(i) - lsb(i + lsb(i))
    ensures i - lsb(i) + lsb(i - lsb(i)) == i + lsb(i)
{
    reveal lsb;
    // Even i: the whole configuration is a doubled copy of the one for i / 2
    if i % 2 == 0 {
        lsb_shared_par(i / 2);
    }
}

// If j is pretty close to its parent
// Then its parent cannot be i + lsb(i)
// "Close" here means j < i < j + (lsb(j) + 1) / 2:
// i sits within the first (rounded-up) half of the gap between j and j + lsb(j)
lemma lsb_bad_par(i: int, j: int)
    requires 0 < j < i
    requires j + (lsb(j) + 1) / 2 > i
    ensures j + lsb(j) != i + lsb(i)
{
    reveal lsb;
}

// Similar to lsb_add
// Clearing the lowest set bit (i - lsb(i)) leaves a number whose lsb is at least twice as large
// i > lsb(i) means i is not a power of 2 (in the lsb(i) == i sense used below),
// so i - lsb(i) stays positive
lemma lsb_sub(i: int)
    requires i > 0
    requires i > lsb(i)
    ensures 2 * lsb(i) <= lsb(i - lsb(i))
{
    reveal lsb;
    // Even i: recurse on the halved instance
    if i % 2 == 0 {
        lsb_sub(i / 2);
    }
}

// Some properties of powers of 2
// We define a power of 2 as a number i such that lsb(i) == i
// Halving a power of 2 greater than 1 gives another power of 2, and lsb halves along with it
lemma lsb_pow2(i: int)
    requires i > 1
    requires lsb(i) == i
    ensures lsb(i / 2) == i / 2
    ensures lsb(i) == 2 * lsb(i / 2)
{
    reveal lsb;
}

// lsb of sum is lsb of smaller
// If lsb(j) < lsb(i), then i has no set bits at or below lsb(j)'s position,
// so adding i cannot disturb the lowest set bit of j
lemma lsb_sum(i: int, j: int)
    requires i > 0 && j > 0
    requires lsb(i) > lsb(j)
    ensures lsb(i + j) == lsb(j)
{
    reveal lsb;
    // Odd j: lsb(j) == 1 < lsb(i) makes i even, so i + j is odd
    // Even j: lsb(i) > lsb(j) >= 2 makes i even too, so recurse on the halves
    if j % 2 == 0 {
        lsb_sum(i / 2, j / 2);
    }
}

// msb(i): the largest power of 2 that does not exceed i
// Halve i down to 1, then double back up once per halving step
// Compiled (not ghost) because search uses it to pick its first step size
function msb(i: int): int
    requires i > 0
    ensures 0 < msb(i) <= i
    ensures 2 * msb(i) > i
{
    if i > 1 then 2 * msb(i / 2) else 1
}

// Verify that msb is indeed a power of 2
// (power of 2 in the sense used above: lsb(x) == x)
lemma msb_is_pow2(i: int)
    requires i > 0
    ensures lsb(msb(i)) == msb(i)
{
    // NOTE(review): there is no explicit recursive call here; this appears to rely on
    // Dafny's automatic induction on i together with the revealed definition of lsb
    reveal lsb;
}

// Recursive sum of a sequence: first element plus the sum of the rest, or 0 when empty
// opaque for perf reasons: lemmas such as split and split3 reveal it and expose the
// needed facts, so other proofs never have to unfold it
// ghost since it only appears in specs and proofs and never needs compiling
opaque ghost function sum(s: seq<int>): int
{
    if |s| > 0 then s[0] + sum(s[1..]) else 0
}

// Sum behaves nicely when splitting a slice
// For any split point i <= j <= k: sum(A[i..k]) == sum(A[i..j]) + sum(A[j..k])
// Proof: induction on i, peeling A[i] off the front until i == j (then A[i..j] is empty)
lemma split(A: seq<int>, i: int, j: int, k: int)
    decreases |A| - i
    requires 0 <= i <= j <= k <= |A|
    ensures sum(A[i..j]) + sum(A[j..k]) == sum(A[i..k])
{
    reveal sum;
    if i < j {
        split(A, i + 1, j, k);
    }
}

// Similar to split, but with an element in the middle
// Isolates the single element A[j - 1] between the slices A[i..j - 1] and A[j..k]
lemma split3(A: seq<int>, i: int, j: int, k: int)
    requires 0 <= i < j <= k <= |A|
    ensures sum(A[i..j - 1]) + A[j - 1] + sum(A[j..k]) == sum(A[i..k])
{
    reveal sum;
    // Split off A[i..j - 1], then split the remainder at j (A[j - 1..j] is a singleton)
    split(A, i, j - 1, k);
    split(A, j - 1, j, k);
}

// Nonnegative arrays have nondecreasing prefix sums
// Proof: induction on j; extending the prefix from j - 1 to j adds A[j - 1] >= 0
lemma sum_nondecreasing(A: seq<int>, i: int, j: int)
    requires 0 <= i <= j <= |A|
    requires forall k :: 0 <= k < |A| ==> A[k] >= 0
    ensures sum(A[..i]) <= sum(A[..j])
{
    reveal sum;
    if j > i {
        sum_nondecreasing(A, i, j - 1);
        split(A, 0, j - 1, j);
    }
}

// Alright time for the fun stuff
// Buckle up, it's gonna get super crazy
// Fenwick tree over a ghost "ground truth" seq A
// Executable state is just N and the array F; A exists only for specs and proofs
class fenwick {
    // Size
    var N: int
    // "Ground truth" seq
    // It's a ghost spooky spooky
    // It's easier to use a seq here because unlike arrays, they're immutable
    // So it's easier for Dafny to reason about them
    // Anyways ghost arrays are weird in Dafny
    // And I couldn't figure out how to modify an element in a ghost array
    // So maybe I'm just dumb
    // 0-indexed and size N
    ghost var A: seq<int>
    // FT array
    // 1-indexed and size N + 1 for Fenwick reasons
    // (F[0] is never constrained by valid(); only F[1..N] matter)
    // Sorry...
    // const means we can modify F, but not reassign a completely new array to F
    const F: array<int>

    // Check if F is a valid FT over A
    // Each F[i] (0 < i <= N) must hold the sum of the lsb(i) elements A[i - lsb(i)], ..., A[i - 1]
    // N < 0x80000000 (2^31) matters because lsb(i) <= i <= N gives i + lsb(i) < 2^32,
    // so the parent index i + lsb_fast(i) computed in update always fits in a u32
    ghost predicate valid()
        reads this
        reads F
    {
        N < 0x80000000 && |A| == N && F.Length == N + 1 && forall i :: 0 < i <= N ==> F[i] == sum(A[i - lsb(i)..i])
    }

    // Create FT from seq A'
    // Linear-time build: visit i = 1..N in order, finish F[i] by adding A'[i - 1],
    // then push the finished F[i] into its parent F[i + lsb(i)] when that parent exists
    // Use {:isolate_assertions} for perf reasons
    constructor {:isolate_assertions} (A': seq<int>)
        requires |A'| < 0x80000000
        // https://dafny.org/latest/HowToFAQ/FAQUpdateArrayField
        ensures fresh(F)
        ensures A == A'
        ensures valid()
    {
        N := |A'|;
        A := A';
        F := new int[|A'| + 1](i => 0);
        // Ends the field-initialization phase; the loop below then mutates F's elements
        new;

        for i: u32 := 1 to F.Length
            modifies F
            // Don't change A
            invariant A == A'
            // These three invariants are all interlocking
            // Yeah it's super complicated, sorry
            // Everything up to i is good
            invariant forall j :: 0 < j < i ==> F[j] == sum(A[j - lsb(j)..j])
            // Future stuff is blank
            // (entries at or after i that no earlier child has pushed into yet are still 0)
            invariant forall j :: i <= j <= N && j - lsb(j) / 2 >= i ==> F[j] == 0
            // Uhhhhhhh
            // The intervals covering i are in some pretty complicated state
            // Not sure why it's j + (lsb(j) + 1) / 2 >= i other than that it works
            // Informally: for each finished j close enough to i, its parent F[j + lsb(j)]
            // already holds the sum from the parent's interval start up to j
            invariant forall j :: 0 < j < i && j + (lsb(j) + 1) / 2 >= i && j + lsb(j) <= N ==> lsb(j) < lsb(j + lsb(j)) && F[j + lsb(j)] == sum(A[j + lsb(j) - lsb(j + lsb(j))..j])
        {
            // Get initial state of F[i]
            // lsb(i) == 1: F[i] is still 0 and A[i - lsb(i)..i - 1] is empty
            // Otherwise: i - 1 is odd and is the last child of i, so it has already pushed into F[i]
            if lsb(i) == 1 {
                assert F[i] == sum(A[i - lsb(i)..i - 1]) by { reveal sum; }
            } else {
                assert lsb(i - 1) == 1 by { reveal lsb; }
            }
            assert F[i] == sum(A[i - lsb(i)..i - 1]);
            assert A'[i - 1] == sum(A[i - 1..i]) by { reveal sum; }
            split(A, i - lsb(i), i - 1, i);
            F[i] := F[i] + A'[i - 1];
            // Yay F[i] is good now
            assert F[i] == sum(A[i - lsb(i)..i]);

            // Now comes the insanity
            // Took me 10 hours to write this
            // Push the finished F[i] into its parent j = i + lsb(i), if j <= N
            // lsb_correct justifies computing the parent with lsb_fast
            lsb_correct(i);
            var j: u32 := i + lsb_fast(i);
            if j < F.Length {
                F[j] := F[j] + F[i];
                lsb_add(i);
                if i - lsb(i) == j - lsb(j) {
                    // No prev child
                    // i is the first child of j, so F[j] was still 0 before the addition above
                    assert F[j] == F[i];
                    assert j - lsb(j) / 2 <= i;
                    assert F[j] == sum(A[j - lsb(j)..i]);
                    assert i + lsb(i) / 2 >= i && i + lsb(i) <= N;
                    assert i + lsb(i) / 2 >= i && i + lsb(i) <= N ==> F[i] == sum(A[j - lsb(j)..i]);
                    // Didn't break anything, hopefully
                    assert forall j :: 0 < j <= i && j + (lsb(j) + 1) / 2 > i && j + lsb(j) <= N ==> lsb(j) < lsb(j + lsb(j)) && F[j + lsb(j)] == sum(A[j + lsb(j) - lsb(j + lsb(j))..j]) by {
                        assert i + (lsb(i) + 1) / 2 >= i && i + lsb(i) <= N ==> lsb(i) < lsb(i + lsb(i)) && F[i + lsb(i)] == sum(A[i + lsb(i) - lsb(i + lsb(i))..i]);
                        // No earlier j shares the parent i + lsb(i), so their parents are unaffected
                        forall j | 0 < j < i && j + (lsb(j) + 1) / 2 > i && j + lsb(j) <= N
                            ensures lsb(j) < lsb(j + lsb(j)) && F[j + lsb(j)] == sum(A[j + lsb(j) - lsb(j + lsb(j))..j])
                        {
                            lsb_no_par(i, j);
                        }
                    }
                } else {
                    assert i - lsb(i) > j - lsb(j);
                    // i - lsb(i) is prev child
                    // so F[j] already held sum(A[j - lsb(j)..i - lsb(i)]) before the addition above
                    lsb_shared_par(i);
                    assert lsb(i - lsb(i)) >= 2 * lsb(i);
                    // Extend from i - lsb(i) to i
                    split(A, j - lsb(j), i - lsb(i), i);
                    assert F[j] == sum(A[j - lsb(j)..i]);
                    // Didn't break anything, hopefully
                    assert forall j :: 0 < j <= i && j + (lsb(j) + 1) / 2 > i && j + lsb(j) <= N ==> lsb(j) < lsb(j + lsb(j)) && F[j + lsb(j)] == sum(A[j + lsb(j) - lsb(j + lsb(j))..j]) by {
                        assert i + (lsb(i) + 1) / 2 >= i && i + lsb(i) <= N ==> lsb(i) < lsb(i + lsb(i)) && F[i + lsb(i)] == sum(A[i + lsb(i) - lsb(i + lsb(i))..i]);
                        // Any such j is too close to its own parent to share i's parent
                        forall j | 0 < j < i && j + (lsb(j) + 1) / 2 > i && j + lsb(j) <= N
                            ensures lsb(j) < lsb(j + lsb(j)) && F[j + lsb(j)] == sum(A[j + lsb(j) - lsb(j + lsb(j))..j])
                        {
                            lsb_bad_par(i, j);
                        }
                    }
                }
            }
        }
    }

    // Add v to the i'-th element (1-indexed), i.e. to A[i' - 1]
    // Walks up the tree via i := i + lsb(i), adding v to every F[i] whose interval covers i'
    // modifies this because the ghost field A is reassigned
    // {:isolate_assertions} seems to be needed to avoid timeouts in dafny build
    method {:isolate_assertions} update(i': u32, v: int)
        modifies this
        modifies F
        requires 0 < i' <= N
        requires valid()
        // Check that A is updated correctly
        ensures A == old(A)[0..i' - 1] + [old(A)[i' - 1] + v] + old(A)[i'..]
        // Make sure F is still a valid FT over A
        ensures valid()
    {
        // Update "ground truth" ghost seq
        A := A[0..i' - 1] + [A[i' - 1] + v] + A[i'..];

        var i: u32 := i';
        // t doubles every iteration and stays <= lsb(i), bounding the iteration count
        ghost var t := 1;
        while i <= N
            modifies F
            decreases N - i
            // Interval i always covers i'
            invariant i - lsb(i) < i'
            // F before i' and from i onward hasn't been modified
            invariant forall j :: 0 < j < i' || i <= j <= N ==> F[j] == old(F[j])
            // Only intervals covering i' have been modified so far
            invariant forall j :: i' <= j < i && j <= N ==> if j - lsb(j) < i' then F[j] == old(F[j]) + v else F[j] == old(F[j])
            // For proving time complexity
            invariant t <= lsb(i)
            invariant t <= 2 * N
        {
            F[i] := F[i] + v;
            // Ensure lsb keeps getting bigger
            lsb_add(i);
            // Indices strictly between i and i + lsb(i) don't cover i', so skipping them is safe
            lsb_next(i);
            // Justify using lsb_fast(i) in place of lsb(i)
            lsb_correct(i);
            i := i + lsb_fast(i);
            t := 2 * t;
        }
        // Loop ran at most log(N) + 1 iterations
        assert t <= 2 * N;

        // All intervals covering i' have been updated
        assert forall i :: 0 < i <= N ==> if i - lsb(i) < i' <= i then F[i] == old(F[i]) + v else F[i] == old(F[i]);
        
        // Check intervals before i'
        forall i | 0 < i < i'
            ensures F[i] == sum(A[i - lsb(i)..i])
        {
            assert A[i - lsb(i)..i] == old(A[i - lsb(i)..i]);
        }
        // Check intervals after i'
        forall i | 0 < i <= N && i - lsb(i) >= i'
            ensures F[i] == sum(A[i - lsb(i)..i])
        {
            assert A[i - lsb(i)..i] == old(A[i - lsb(i)..i]);
        }
        // Check intervals covering i'
        forall i | 0 < i <= N && i - lsb(i) < i' <= i
            ensures F[i] == sum(A[i - lsb(i)..i])
        {
            // This is actually pretty simple but Dafny needs a lot of hand-holding
            // Since we have to manually reason about sum
            // Basically we use split3 to convince Dafny that the parts left and right of i' are unchanged
            // And i' gets bumped up by v
            calc == {
                F[i];
                old(F[i]) + v;
                sum(old(A[i - lsb(i)..i])) + v;
                {
                    assert sum(old(A)[i - lsb(i)..i' - 1]) + old(A)[i' - 1] + sum(old(A)[i'..i]) == sum(old(A[i - lsb(i)..i])) by {
                        split3(old(A), i - lsb(i), i', i);
                    }
                }
                sum(old(A)[i - lsb(i)..i' - 1]) + old(A)[i' - 1] + sum(old(A)[i'..i]) + v;
                sum(old(A)[i - lsb(i)..i' - 1]) + A[i' - 1] + sum(old(A)[i'..i]);
                {
                    assert A[i - lsb(i)..i' - 1] == old(A[i - lsb(i)..i' - 1]);
                }
                sum(A[i - lsb(i)..i' - 1]) + A[i' - 1] + sum(old(A)[i'..i]);
                {
                    assert A[i'..i] == old(A[i'..i]);
                }
                sum(A[i - lsb(i)..i' - 1]) + A[i' - 1] + sum(A[i'..i]);
                {
                    assert sum(A[i - lsb(i)..i' - 1]) + A[i' - 1] + sum(A[i'..i]) == sum(A[i - lsb(i)..i]) by {
                        split3(A, i - lsb(i), i', i);
                    }
                }
                sum(A[i - lsb(i)..i]);
            }
        }
    }

    // Query for prefix sum up to i' inclusive (1-indexed), i.e. sum(A[0..i'])
    // Walks down via i := i - lsb(i), adding the disjoint intervals F[i] until i reaches 0
    method query(i': u32) returns (ret: int)
        requires 0 < i' <= N
        requires valid()
        ensures ret == sum(A[0..i'])
    {
        ret := 0;
        var i: u32 := i';
        assert ret == sum(A[i'..i']) by { reveal sum; }
        // t doubles every iteration and stays <= lsb(i), bounding the iteration count
        ghost var t := 1;
        while i > 0
            invariant i >= 0
            invariant ret == sum(A[i..i'])
            // For proving time complexity
            invariant i == 0 || t <= lsb(i)
            invariant t <= 2 * N
        {
            ret := ret + F[i];
            // Extend sum from i to i - lsb(i)
            split(A, i - lsb(i), i, i');
            lsb_correct(i);
            if i > lsb(i) {
                // The next i has lsb at least twice as large, keeping t <= lsb(i)
                lsb_sub(i);
            }
            i := i - lsb_fast(i);
            t := 2 * t;
        }
        // Loop ran at most log(N) + 1 iterations
        assert t <= 2 * N;
    }

    // Query for sum between indices i' and j' inclusive (1-indexed), i.e. sum(A[i' - 1..j'])
    // Naive approach:
    // return fenwick_query(N, F, j) - fenwick_query(N, F, i - 1);
    // Slightly faster approach:
    // Walk j down from j' until j < i', then walk i down from i' - 1 until it meets j,
    // so the prefix shared by both queries (everything before j) is never summed
    method range_query(i': u32, j': u32) returns (ret: int)
        requires 0 < i' <= j' <= N
        requires valid()
        ensures ret == sum(A[i' - 1..j'])
    {
        ret := 0;
        var j: u32 := j';
        ghost var k := j';
        assert ret == sum(A[j'..j']) by { reveal sum; }
        ghost var t := 1;
        // Phase 1: add intervals from j' downward while j >= i'
        while j >= i'
            invariant j >= 0
            invariant ret == sum(A[j..j'])
            // k is the previous value of j
            invariant j == j' || j == k - lsb(k)
            invariant k >= i'
            // For proving time complexity
            invariant j == 0 || t <= lsb(j)
            invariant t <= 2 * N
        {
            ret := ret + F[j];
            // Extend sum from j to j - lsb(j)
            split(A, j - lsb(j), j, j');
            // Save j
            k := j;
            lsb_correct(j);
            if j > lsb(j) {
                lsb_sub(j);
            }
            j := j - lsb_fast(j);
            t := 2 * t;
        }
        // Loop ran at most log(N) + 1 iterations
        assert t <= 2 * N;
        if k > lsb(k) {
            // This ensures i' <= j + lsb(j);
            lsb_sub(k);
        }
        // Phase 2: now start from i' - 1 and remove intervals until we reach j
        var i: u32 := i' - 1;
        assert ret == sum(A[j..j']) - sum(A[i..i' - 1]) by { reveal sum; }
        t := 1;
        while i > j
            invariant i >= j
            invariant ret == sum(A[j..j']) - sum(A[i..i' - 1])
            // For proving time complexity
            invariant i == 0 || t <= lsb(i)
            invariant t <= 2 * N
        {
            ret := ret - F[i];
            split(A, i - lsb(i), i, i' - 1);
            if j > 0 {
                // This prevents us from overshooting
                lsb_next(j);
            }
            lsb_correct(i);
            if i > lsb(i) {
                lsb_sub(i);
            }
            i := i - lsb_fast(i);
            t := 2 * t;
        }
        // Loop ran at most log(N) + 1 iterations
        assert t <= 2 * N;
        // Finally, remove the overlapping interval
        // (the loop exits with i == j, so ret == sum(A[i..j']) - sum(A[i..i' - 1]))
        split(A, i, i' - 1, j');
    }

    // Search for largest index with prefix sum less than or equal to s'
    // Returns ret in [0, N] with sum(A[..ret]) <= s' and (ret == N || sum(A[..ret + 1]) > s')
    // Descends with power-of-2 step sizes msb(N), msb(N) / 2, ..., 1, greedily extending ret
    method search(s': int) returns (ret: u32)
        requires N > 0
        requires s' >= 0
        requires valid()
        // Requires nonnegative elements
        // (so prefix sums are nondecreasing and the answer splits [0, N] cleanly)
        requires forall i :: 0 <= i < N ==> A[i] >= 0
        // Everything before ret is less than or equal to s'
        // Everything after is greater
        ensures forall i :: 0 <= i <= N ==> (sum(A[..i]) <= s' <==> i <= ret)
    {
        // This is kinda like a binary search
        // Lower bound on the answer
        ret := 0;
        // ret + 2 * i is the upper bound on the answer
        // s is the remaining budget: s' minus the prefix sum up to ret
        var s := s';
        msb_is_pow2(N);
        var i: u32 := msb(N);
        assert sum(A[..ret]) + s == s' by { reveal sum; }
        while i > 0
            // Sorry for all the invariants
            // i is always a power of 2
            invariant i == 0 || (i > 0 && lsb(i) == i)
            invariant ret == 0 || lsb(ret) > i
            // A[..ret] is still s away from s'
            invariant s >= 0
            invariant ret <= N && sum(A[..ret]) == s' - s
            // A[..ret + 2 * i] is larger
            invariant i > 0 ==> ret + 2 * i > N || sum(A[..ret + 2 * i]) > s'
            // At the end of the loop our upper bound is ret + 1
            invariant i == 0 ==> ret == N || sum(A[..ret + 1]) > s'
        {
            if ret + i <= N {
                if ret > 0 {
                    // Ensure that lsb(ret + i) == lsb(i)
                    // so F[ret + i] covers exactly A[ret..ret + i]
                    lsb_sum(ret, i);
                }
                split(A, 0, ret, ret + i);
            }
            if ret + i <= N && F[ret + i] <= s {
                s := s - F[ret + i];
                // Move lower bound
                ret := ret + i;
            }
            // Otherwise, move upper bound
            if i > 1 {
                // i / 2 is still a power of 2
                lsb_pow2(i);
            }
            i := i / 2;
        }
        // Now we proved sum(A[..ret]) <= s' && (ret == N || sum(A[..ret + 1]) > s')
        // So everything to the left is smaller
        // And everything to the right is larger
        forall j | 0 <= j <= N
            ensures sum(A[..j]) <= s' <==> j <= ret
        {
            if j <= ret {
                sum_nondecreasing(A, j, ret);
            } else {
                sum_nondecreasing(A, ret + 1, j);
            }
        }
    }
}

// Example usage
/*method Main() {
    var A := [1, 5, 3, 14, -2, 4, -3];
    var FT := new fenwick(A);
    
    var a := FT.query(5);
    assert a == sum(A[0..5]);
    print a, "\n";
    ghost var B := FT.A[0..5];
    assert B[0..5] == FT.A[0..5];
    
    FT.update(3, 7);

    var b := FT.query(5);
    print b, "\n";

    split3(B, 0, 3, 5);
    split3(FT.A, 0, 3, 5);
    assert B[0..2] == FT.A[0..2];
    assert B[3..5] == FT.A[3..5];
    assert a + 7 == b;

    var N := 1000000;
    A := seq(N, i => i * i % N);
    FT := new fenwick(A);
    assert N == FT.N;

    var s := 0;
    for i := 1 to N
        invariant N == FT.N
        invariant FT.valid()
    {
        var q := FT.query(i);
        s := s + q;
    }
    print s, "\n";
}*/