1
- /// This stucture implements a segmented tree that
2
- /// can efficiently answer range queries on arrays.
3
- pub struct SegmentTree < T : Default + Ord + Copy > {
4
- len : usize ,
5
- buf : Vec < T > ,
6
- op : Ops ,
7
- }
1
+ use std:: cmp:: min;
2
+ use std:: fmt:: Debug ;
3
+ use std:: ops:: Range ;
8
4
9
- pub enum Ops {
10
- Max ,
11
- Min ,
5
+ /// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays.
6
+ /// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...]
7
+ /// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together.
8
+ /// Note that this function should be commutative and associative
9
+ /// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b`
10
+ pub struct SegmentTree < T : Debug + Default + Ord + Copy > {
11
+ len : usize , // length of the represented
12
+ tree : Vec < T > , // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance)
13
+ merge : fn ( T , T ) -> T , // how we merge two values together
12
14
}
13
15
14
- impl < T : Default + Ord + Copy > SegmentTree < T > {
15
- /// function to build the tree
16
- pub fn from_vec ( arr : & [ T ] , op : Ops ) -> Self {
16
+ impl < T : Debug + Default + Ord + Copy > SegmentTree < T > {
17
+ /// Builds a SegmentTree from an array and a merge function
18
+ pub fn from_vec ( arr : & [ T ] , merge : fn ( T , T ) -> T ) -> Self {
17
19
let len = arr. len ( ) ;
18
20
let mut buf: Vec < T > = vec ! [ T :: default ( ) ; 2 * len] ;
19
- buf[ len..( len + len) ] . clone_from_slice ( & arr[ 0 ..len] ) ;
21
+ // Populate the tree bottom-up, from right to left
22
+ buf[ len..( 2 * len) ] . clone_from_slice ( & arr[ 0 ..len] ) ; // last len pos is the bottom of the tree -> every individual value
20
23
for i in ( 1 ..len) . rev ( ) {
21
- buf[ i] = match op {
22
- Ops :: Max => buf[ 2 * i] . max ( buf[ 2 * i + 1 ] ) ,
23
- Ops :: Min => buf[ 2 * i] . min ( buf[ 2 * i + 1 ] ) ,
24
- } ;
24
+ // a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2
25
+ buf[ i] = merge ( buf[ 2 * i] , buf[ 2 * i + 1 ] ) ;
26
+ }
27
+ SegmentTree {
28
+ len,
29
+ tree : buf,
30
+ merge,
25
31
}
26
- SegmentTree { len, buf, op }
27
32
}
28
33
29
- /// function to get sum on interval [l, r]
30
- pub fn query ( & self , mut l : usize , mut r : usize ) -> T {
31
- l += self . len ;
32
- r += self . len ;
33
- let mut res = self . buf [ l] ;
34
- while l <= r {
34
+ /// Query the range (exclusive)
35
+ /// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.)
36
+ /// return the aggregate of values over this range otherwise
37
+ pub fn query ( & self , range : Range < usize > ) -> Option < T > {
38
+ let mut l = range. start + self . len ;
39
+ let mut r = min ( self . len , range. end ) + self . len ;
40
+ let mut res = None ;
41
+ // Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations
42
+ while l < r {
35
43
if l % 2 == 1 {
36
- res = match self . op {
37
- Ops :: Max => res . max ( self . buf [ l] ) ,
38
- Ops :: Min => res . min ( self . buf [ l] ) ,
39
- } ;
44
+ res = Some ( match res {
45
+ None => self . tree [ l] ,
46
+ Some ( old ) => ( self . merge ) ( old , self . tree [ l] ) ,
47
+ } ) ;
40
48
l += 1 ;
41
49
}
42
- if r % 2 == 0 {
43
- res = match self . op {
44
- Ops :: Max => res. max ( self . buf [ r] ) ,
45
- Ops :: Min => res. min ( self . buf [ r] ) ,
46
- } ;
50
+ if r % 2 == 1 {
47
51
r -= 1 ;
52
+ res = Some ( match res {
53
+ None => self . tree [ r] ,
54
+ Some ( old) => ( self . merge ) ( old, self . tree [ r] ) ,
55
+ } ) ;
48
56
}
49
57
l /= 2 ;
50
58
r /= 2 ;
51
59
}
52
60
res
53
61
}
54
62
55
- /// function to update a tree node
56
- pub fn update ( & mut self , mut idx : usize , val : T ) {
57
- idx += self . len ;
58
- self . buf [ idx] = val;
59
- idx /= 2 ;
63
+ /// Updates the value at index `idx` in the original array with a new value `val`
64
+ pub fn update ( & mut self , idx : usize , val : T ) {
65
+ // change every value where `idx` plays a role, bottom -> up
66
+ // 1: change in the right-hand side of the tree (bottom row)
67
+ let mut idx = idx + self . len ;
68
+ self . tree [ idx] = val;
60
69
70
+ // 2: then bubble up
71
+ idx /= 2 ;
61
72
while idx != 0 {
62
- self . buf [ idx] = match self . op {
63
- Ops :: Max => self . buf [ 2 * idx] . max ( self . buf [ 2 * idx + 1 ] ) ,
64
- Ops :: Min => self . buf [ 2 * idx] . min ( self . buf [ 2 * idx + 1 ] ) ,
65
- } ;
73
+ self . tree [ idx] = ( self . merge ) ( self . tree [ 2 * idx] , self . tree [ 2 * idx + 1 ] ) ;
66
74
idx /= 2 ;
67
75
}
68
76
}
@@ -71,17 +79,107 @@ impl<T: Default + Ord + Copy> SegmentTree<T> {
71
79
#[ cfg( test) ]
72
80
mod tests {
73
81
use super :: * ;
82
+ use quickcheck:: TestResult ;
83
+ use quickcheck_macros:: quickcheck;
84
+ use std:: cmp:: { max, min} ;
74
85
75
86
#[ test]
76
- fn it_works ( ) {
77
- let vec = vec ! [ 1 , 2 , -4 , 7 , 3 , -5 , 6 , 11 , -20 , 9 , 14 , 15 , 5 , 2 , -8 ] ;
78
- let min_seg_tree = SegmentTree :: from_vec ( & vec, Ops :: Min ) ;
79
- assert_eq ! ( -5 , min_seg_tree. query( 4 , 6 ) ) ;
80
- assert_eq ! ( -20 , min_seg_tree. query( 0 , vec. len( ) - 1 ) ) ;
81
- let mut max_seg_tree = SegmentTree :: from_vec ( & vec, Ops :: Max ) ;
82
- assert_eq ! ( 6 , max_seg_tree. query( 4 , 6 ) ) ;
83
- assert_eq ! ( 15 , max_seg_tree. query( 0 , vec. len( ) - 1 ) ) ;
84
- max_seg_tree. update ( 6 , 8 ) ;
85
- assert_eq ! ( 8 , max_seg_tree. query( 4 , 6 ) ) ;
87
+ fn test_min_segments ( ) {
88
+ let vec = vec ! [ -30 , 2 , -4 , 7 , 3 , -5 , 6 , 11 , -20 , 9 , 14 , 15 , 5 , 2 , -8 ] ;
89
+ let min_seg_tree = SegmentTree :: from_vec ( & vec, min) ;
90
+ assert_eq ! ( Some ( -5 ) , min_seg_tree. query( 4 ..7 ) ) ;
91
+ assert_eq ! ( Some ( -30 ) , min_seg_tree. query( 0 ..vec. len( ) ) ) ;
92
+ assert_eq ! ( Some ( -30 ) , min_seg_tree. query( 0 ..2 ) ) ;
93
+ assert_eq ! ( Some ( -4 ) , min_seg_tree. query( 1 ..3 ) ) ;
94
+ assert_eq ! ( Some ( -5 ) , min_seg_tree. query( 1 ..7 ) ) ;
95
+ }
96
+
97
+ #[ test]
98
+ fn test_max_segments ( ) {
99
+ let val_at_6 = 6 ;
100
+ let vec = vec ! [ 1 , 2 , -4 , 7 , 3 , -5 , val_at_6, 11 , -20 , 9 , 14 , 15 , 5 , 2 , -8 ] ;
101
+ let mut max_seg_tree = SegmentTree :: from_vec ( & vec, max) ;
102
+ assert_eq ! ( Some ( 15 ) , max_seg_tree. query( 0 ..vec. len( ) ) ) ;
103
+ let max_4_to_6 = 6 ;
104
+ assert_eq ! ( Some ( max_4_to_6) , max_seg_tree. query( 4 ..7 ) ) ;
105
+ let delta = 2 ;
106
+ max_seg_tree. update ( 6 , val_at_6 + delta) ;
107
+ assert_eq ! ( Some ( val_at_6 + delta) , max_seg_tree. query( 4 ..7 ) ) ;
108
+ }
109
+
110
+ #[ test]
111
+ fn test_sum_segments ( ) {
112
+ let val_at_6 = 6 ;
113
+ let vec = vec ! [ 1 , 2 , -4 , 7 , 3 , -5 , val_at_6, 11 , -20 , 9 , 14 , 15 , 5 , 2 , -8 ] ;
114
+ let mut sum_seg_tree = SegmentTree :: from_vec ( & vec, |a, b| a + b) ;
115
+ for ( i, val) in vec. iter ( ) . enumerate ( ) {
116
+ assert_eq ! ( Some ( * val) , sum_seg_tree. query( i..( i + 1 ) ) ) ;
117
+ }
118
+ let sum_4_to_6 = sum_seg_tree. query ( 4 ..7 ) ;
119
+ assert_eq ! ( Some ( 4 ) , sum_4_to_6) ;
120
+ let delta = 3 ;
121
+ sum_seg_tree. update ( 6 , val_at_6 + delta) ;
122
+ assert_eq ! (
123
+ sum_4_to_6. unwrap( ) + delta,
124
+ sum_seg_tree. query( 4 ..7 ) . unwrap( )
125
+ ) ;
126
+ }
127
+
128
+ // Some properties over segment trees:
129
+ // When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc.
130
+ // When asking for an interval containing a single value, return this value, no matter the merge function
131
+
132
+ #[ quickcheck]
133
+ fn check_overall_interval_min ( array : Vec < i32 > ) -> TestResult {
134
+ let seg_tree = SegmentTree :: from_vec ( & array, min) ;
135
+ TestResult :: from_bool ( array. iter ( ) . min ( ) . copied ( ) == seg_tree. query ( 0 ..array. len ( ) ) )
136
+ }
137
+
138
+ #[ quickcheck]
139
+ fn check_overall_interval_max ( array : Vec < i32 > ) -> TestResult {
140
+ let seg_tree = SegmentTree :: from_vec ( & array, max) ;
141
+ TestResult :: from_bool ( array. iter ( ) . max ( ) . copied ( ) == seg_tree. query ( 0 ..array. len ( ) ) )
142
+ }
143
+
144
+ #[ quickcheck]
145
+ fn check_overall_interval_sum ( array : Vec < i32 > ) -> TestResult {
146
+ let seg_tree = SegmentTree :: from_vec ( & array, max) ;
147
+ TestResult :: from_bool ( array. iter ( ) . max ( ) . copied ( ) == seg_tree. query ( 0 ..array. len ( ) ) )
148
+ }
149
+
150
+ #[ quickcheck]
151
+ fn check_single_interval_min ( array : Vec < i32 > ) -> TestResult {
152
+ let seg_tree = SegmentTree :: from_vec ( & array, min) ;
153
+ for ( i, value) in array. into_iter ( ) . enumerate ( ) {
154
+ let res = seg_tree. query ( i..( i + 1 ) ) ;
155
+ if res != Some ( value) {
156
+ return TestResult :: error ( format ! ( "Expected {:?}, got {:?}" , Some ( value) , res) ) ;
157
+ }
158
+ }
159
+ TestResult :: passed ( )
160
+ }
161
+
162
+ #[ quickcheck]
163
+ fn check_single_interval_max ( array : Vec < i32 > ) -> TestResult {
164
+ let seg_tree = SegmentTree :: from_vec ( & array, max) ;
165
+ for ( i, value) in array. into_iter ( ) . enumerate ( ) {
166
+ let res = seg_tree. query ( i..( i + 1 ) ) ;
167
+ if res != Some ( value) {
168
+ return TestResult :: error ( format ! ( "Expected {:?}, got {:?}" , Some ( value) , res) ) ;
169
+ }
170
+ }
171
+ TestResult :: passed ( )
172
+ }
173
+
174
+ #[ quickcheck]
175
+ fn check_single_interval_sum ( array : Vec < i32 > ) -> TestResult {
176
+ let seg_tree = SegmentTree :: from_vec ( & array, max) ;
177
+ for ( i, value) in array. into_iter ( ) . enumerate ( ) {
178
+ let res = seg_tree. query ( i..( i + 1 ) ) ;
179
+ if res != Some ( value) {
180
+ return TestResult :: error ( format ! ( "Expected {:?}, got {:?}" , Some ( value) , res) ) ;
181
+ }
182
+ }
183
+ TestResult :: passed ( )
86
184
}
87
185
}
0 commit comments