diff options
Diffstat (limited to 'sd.go')
-rw-r--r-- | sd.go | 52 |
1 files changed, 28 insertions, 24 deletions
@@ -14,45 +14,49 @@ import ( var verbose = flag.Bool("v", false, "debug output") var file = flag.String("f", "cards", "cards file") -var sum = 0 -var seg = []int{} +type segmentTree struct { + N int + seg []int +} -func build(a *sql.Rows, l, r, n int) { +// Build segment tree +func (s segmentTree) build(a *sql.Rows, l, r, n int) { if l == r { a.Next() - a.Scan(&seg[n]) - sum += seg[n] + a.Scan(&s.seg[n]) return } m := (l + r) >> 1 - build(a, l, m, n<<1) - build(a, m+1, r, n<<1|1) - seg[n] = seg[n<<1] + seg[n<<1|1] + s.build(a, l, m, n<<1) + s.build(a, m+1, r, n<<1|1) + s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1] } -func update(x, v, l, r, n int) { +// Update value at index x +func (s segmentTree) update(x, v, l, r, n int) { if l == r { - seg[n] = v + s.seg[n] = v return } m := (l + r) >> 1 if x <= m { - update(x, v, l, m, n<<1) + s.update(x, v, l, m, n<<1) } else { - update(x, v, m+1, r, n<<1|1) + s.update(x, v, m+1, r, n<<1|1) } - seg[n] = seg[n<<1] + seg[n<<1|1] + s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1] } -func query(v, l, r, n int) (int, int) { +// Find element with prefix sum v +func (s segmentTree) query(v, l, r, n int) (int, int) { if l == r { - return seg[n], l + return s.seg[n], l } m := (l + r) >> 1 - if seg[n<<1] >= v { - return query(v, l, m, n<<1) + if s.seg[n<<1] >= v { + return s.query(v, l, m, n<<1) } else { - return query(v-seg[n<<1], m+1, r, n<<1|1) + return s.query(v-s.seg[n<<1], m+1, r, n<<1|1) } } @@ -67,19 +71,19 @@ func main() { // Get number of cards var N int _ = db.QueryRow("SELECT COUNT(*) FROM cards").Scan(&N) - seg = make([]int, 4*N) + s := segmentTree{N, make([]int, 4*N)} // Build segment tree rows, err := db.Query("SELECT weight FROM cards") if err != nil { panic(err) } - build(rows, 0, N-1, 1) + s.build(rows, 0, N-1, 1) + sum := s.seg[1] rows.Close() if *verbose { - fmt.Println(N) - fmt.Println(seg) + fmt.Println(s) } // https://stackoverflow.com/questions/14094190/function-similar-to-getchar @@ -97,7 +101,7 @@ func main() { fmt.Println(sum) fmt.Println(x) } - w, i := query(x, 0, N-1, 1) + w, i := s.query(x, 0, N-1, 1) fmt.Println(w) fmt.Println(i) @@ -124,7 +128,7 @@ func main() { } // Update segment tree and database - update(i, w, 0, N-1, 1) + s.update(i, w, 0, N-1, 1) _, err = db.Exec("UPDATE cards SET weight=? WHERE idx=?", w, i) if err != nil { panic(err) |