diff options
-rw-r--r-- | sd.go | 46 | ||||
-rw-r--r-- | segmenttree.go | 51 |
2 files changed, 51 insertions, 46 deletions
@@ -14,52 +14,6 @@ import ( var verbose = flag.Bool("v", false, "debug output") var file = flag.String("f", "cards", "cards file") -type segmentTree struct { - N int - seg []int -} - -// Build segment tree -func (s segmentTree) build(a *sql.Rows, l, r, n int) { - if l == r { - a.Next() - a.Scan(&s.seg[n]) - return - } - m := (l + r) >> 1 - s.build(a, l, m, n<<1) - s.build(a, m+1, r, n<<1|1) - s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1] -} - -// Update value at index x -func (s segmentTree) update(x, v, l, r, n int) { - if l == r { - s.seg[n] = v - return - } - m := (l + r) >> 1 - if x <= m { - s.update(x, v, l, m, n<<1) - } else { - s.update(x, v, m+1, r, n<<1|1) - } - s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1] -} - -// Find element with prefix sum v -func (s segmentTree) query(v, l, r, n int) (int, int) { - if l == r { - return s.seg[n], l - } - m := (l + r) >> 1 - if s.seg[n<<1] >= v { - return s.query(v, l, m, n<<1) - } else { - return s.query(v-s.seg[n<<1], m+1, r, n<<1|1) - } -} - func main() { flag.Parse() diff --git a/segmenttree.go b/segmenttree.go new file mode 100644 index 0000000..e2be46a --- /dev/null +++ b/segmenttree.go @@ -0,0 +1,51 @@ +package main + +import ( + "database/sql" +) + +type segmentTree struct { + N int + seg []int +} + +// Build segment tree +func (s segmentTree) build(a *sql.Rows, l, r, n int) { + if l == r { + a.Next() + a.Scan(&s.seg[n]) + return + } + m := (l + r) >> 1 + s.build(a, l, m, n<<1) + s.build(a, m+1, r, n<<1|1) + s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1] +} + +// Update value at index x +func (s segmentTree) update(x, v, l, r, n int) { + if l == r { + s.seg[n] = v + return + } + m := (l + r) >> 1 + if x <= m { + s.update(x, v, l, m, n<<1) + } else { + s.update(x, v, m+1, r, n<<1|1) + } + s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1] +} + +// Find element with prefix sum v +func (s segmentTree) query(v, l, r, n int) (int, int) { + if l == r { + return s.seg[n], l + } + m := (l + r) >> 1 + if s.seg[n<<1] >= v { + return s.query(v, l, m, n<<1) + } else { + return s.query(v-s.seg[n<<1], m+1, r, n<<1|1) + } +} |