aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Wang2022-03-28 12:04:13 -0500
committerAnthony Wang2022-03-28 12:04:13 -0500
commit4936d52dc5e2e0e76852a68758f772b12a543fad (patch)
tree4e522c81f46a224e0557afe12b391f9eb7655d98
parent24d27437208170463abb81819f4d0b2f45544d0b (diff)
Use struct for segmentTree
-rw-r--r--sd.go52
1 files changed, 28 insertions, 24 deletions
diff --git a/sd.go b/sd.go
index fbbd8b9..3f5f3ce 100644
--- a/sd.go
+++ b/sd.go
@@ -14,45 +14,49 @@ import (
var verbose = flag.Bool("v", false, "debug output")
var file = flag.String("f", "cards", "cards file")
-var sum = 0
-var seg = []int{}
+type segmentTree struct {
+ N int
+ seg []int
+}
-func build(a *sql.Rows, l, r, n int) {
+// Build segment tree
+func (s segmentTree) build(a *sql.Rows, l, r, n int) {
if l == r {
a.Next()
- a.Scan(&seg[n])
- sum += seg[n]
+ a.Scan(&s.seg[n])
return
}
m := (l + r) >> 1
- build(a, l, m, n<<1)
- build(a, m+1, r, n<<1|1)
- seg[n] = seg[n<<1] + seg[n<<1|1]
+ s.build(a, l, m, n<<1)
+ s.build(a, m+1, r, n<<1|1)
+ s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1]
}
-func update(x, v, l, r, n int) {
+// Update value at index x
+func (s segmentTree) update(x, v, l, r, n int) {
if l == r {
- seg[n] = v
+ s.seg[n] = v
return
}
m := (l + r) >> 1
if x <= m {
- update(x, v, l, m, n<<1)
+ s.update(x, v, l, m, n<<1)
} else {
- update(x, v, m+1, r, n<<1|1)
+ s.update(x, v, m+1, r, n<<1|1)
}
- seg[n] = seg[n<<1] + seg[n<<1|1]
+ s.seg[n] = s.seg[n<<1] + s.seg[n<<1|1]
}
-func query(v, l, r, n int) (int, int) {
+// Find element with prefix sum v
+func (s segmentTree) query(v, l, r, n int) (int, int) {
if l == r {
- return seg[n], l
+ return s.seg[n], l
}
m := (l + r) >> 1
- if seg[n<<1] >= v {
- return query(v, l, m, n<<1)
+ if s.seg[n<<1] >= v {
+ return s.query(v, l, m, n<<1)
} else {
- return query(v-seg[n<<1], m+1, r, n<<1|1)
+ return s.query(v-s.seg[n<<1], m+1, r, n<<1|1)
}
}
@@ -67,19 +71,19 @@ func main() {
// Get number of cards
var N int
_ = db.QueryRow("SELECT COUNT(*) FROM cards").Scan(&N)
- seg = make([]int, 4*N)
+ s := segmentTree{N, make([]int, 4*N)}
// Build segment tree
rows, err := db.Query("SELECT weight FROM cards")
if err != nil {
panic(err)
}
- build(rows, 0, N-1, 1)
+ s.build(rows, 0, N-1, 1)
+ sum := s.seg[1]
rows.Close()
if *verbose {
- fmt.Println(N)
- fmt.Println(seg)
+ fmt.Println(s)
}
// https://stackoverflow.com/questions/14094190/function-similar-to-getchar
@@ -97,7 +101,7 @@ func main() {
fmt.Println(sum)
fmt.Println(x)
}
- w, i := query(x, 0, N-1, 1)
+ w, i := s.query(x, 0, N-1, 1)
fmt.Println(w)
fmt.Println(i)
@@ -124,7 +128,7 @@ func main() {
}
// Update segment tree and database
- update(i, w, 0, N-1, 1)
+ s.update(i, w, 0, N-1, 1)
_, err = db.Exec("UPDATE cards SET weight=? WHERE idx=?", w, i)
if err != nil {
panic(err)