Skip to content

Commit 4c854dd

Browse files
authored
Rework Segment Tree to Support More Operations (TheAlgorithms#486)
1 parent 70a685e commit 4c854dd

File tree

1 file changed

+151
-53
lines changed

1 file changed

+151
-53
lines changed
Lines changed: 151 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,76 @@
1-
/// This stucture implements a segmented tree that
2-
/// can efficiently answer range queries on arrays.
3-
pub struct SegmentTree<T: Default + Ord + Copy> {
4-
len: usize,
5-
buf: Vec<T>,
6-
op: Ops,
7-
}
1+
use std::cmp::min;
2+
use std::fmt::Debug;
3+
use std::ops::Range;
84

9-
pub enum Ops {
10-
Max,
11-
Min,
5+
/// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays.
6+
/// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...]
7+
/// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together.
8+
/// Note that this function should be commutative and associative
9+
/// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b`
10+
pub struct SegmentTree<T: Debug + Default + Ord + Copy> {
11+
len: usize, // length of the represented
12+
tree: Vec<T>, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance)
13+
merge: fn(T, T) -> T, // how we merge two values together
1214
}
1315

14-
impl<T: Default + Ord + Copy> SegmentTree<T> {
15-
/// function to build the tree
16-
pub fn from_vec(arr: &[T], op: Ops) -> Self {
16+
impl<T: Debug + Default + Ord + Copy> SegmentTree<T> {
17+
/// Builds a SegmentTree from an array and a merge function
18+
pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self {
1719
let len = arr.len();
1820
let mut buf: Vec<T> = vec![T::default(); 2 * len];
19-
buf[len..(len + len)].clone_from_slice(&arr[0..len]);
21+
// Populate the tree bottom-up, from right to left
22+
buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value
2023
for i in (1..len).rev() {
21-
buf[i] = match op {
22-
Ops::Max => buf[2 * i].max(buf[2 * i + 1]),
23-
Ops::Min => buf[2 * i].min(buf[2 * i + 1]),
24-
};
24+
// a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2
25+
buf[i] = merge(buf[2 * i], buf[2 * i + 1]);
26+
}
27+
SegmentTree {
28+
len,
29+
tree: buf,
30+
merge,
2531
}
26-
SegmentTree { len, buf, op }
2732
}
2833

29-
/// function to get sum on interval [l, r]
30-
pub fn query(&self, mut l: usize, mut r: usize) -> T {
31-
l += self.len;
32-
r += self.len;
33-
let mut res = self.buf[l];
34-
while l <= r {
34+
/// Query the range (exclusive)
35+
/// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.)
36+
/// return the aggregate of values over this range otherwise
37+
pub fn query(&self, range: Range<usize>) -> Option<T> {
38+
let mut l = range.start + self.len;
39+
let mut r = min(self.len, range.end) + self.len;
40+
let mut res = None;
41+
// Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations
42+
while l < r {
3543
if l % 2 == 1 {
36-
res = match self.op {
37-
Ops::Max => res.max(self.buf[l]),
38-
Ops::Min => res.min(self.buf[l]),
39-
};
44+
res = Some(match res {
45+
None => self.tree[l],
46+
Some(old) => (self.merge)(old, self.tree[l]),
47+
});
4048
l += 1;
4149
}
42-
if r % 2 == 0 {
43-
res = match self.op {
44-
Ops::Max => res.max(self.buf[r]),
45-
Ops::Min => res.min(self.buf[r]),
46-
};
50+
if r % 2 == 1 {
4751
r -= 1;
52+
res = Some(match res {
53+
None => self.tree[r],
54+
Some(old) => (self.merge)(old, self.tree[r]),
55+
});
4856
}
4957
l /= 2;
5058
r /= 2;
5159
}
5260
res
5361
}
5462

55-
/// function to update a tree node
56-
pub fn update(&mut self, mut idx: usize, val: T) {
57-
idx += self.len;
58-
self.buf[idx] = val;
59-
idx /= 2;
63+
/// Updates the value at index `idx` in the original array with a new value `val`
64+
pub fn update(&mut self, idx: usize, val: T) {
65+
// change every value where `idx` plays a role, bottom -> up
66+
// 1: change in the right-hand side of the tree (bottom row)
67+
let mut idx = idx + self.len;
68+
self.tree[idx] = val;
6069

70+
// 2: then bubble up
71+
idx /= 2;
6172
while idx != 0 {
62-
self.buf[idx] = match self.op {
63-
Ops::Max => self.buf[2 * idx].max(self.buf[2 * idx + 1]),
64-
Ops::Min => self.buf[2 * idx].min(self.buf[2 * idx + 1]),
65-
};
73+
self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]);
6674
idx /= 2;
6775
}
6876
}
@@ -71,17 +79,107 @@ impl<T: Default + Ord + Copy> SegmentTree<T> {
7179
#[cfg(test)]
7280
mod tests {
7381
use super::*;
82+
use quickcheck::TestResult;
83+
use quickcheck_macros::quickcheck;
84+
use std::cmp::{max, min};
7485

7586
#[test]
76-
fn it_works() {
77-
let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
78-
let min_seg_tree = SegmentTree::from_vec(&vec, Ops::Min);
79-
assert_eq!(-5, min_seg_tree.query(4, 6));
80-
assert_eq!(-20, min_seg_tree.query(0, vec.len() - 1));
81-
let mut max_seg_tree = SegmentTree::from_vec(&vec, Ops::Max);
82-
assert_eq!(6, max_seg_tree.query(4, 6));
83-
assert_eq!(15, max_seg_tree.query(0, vec.len() - 1));
84-
max_seg_tree.update(6, 8);
85-
assert_eq!(8, max_seg_tree.query(4, 6));
87+
fn test_min_segments() {
88+
let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
89+
let min_seg_tree = SegmentTree::from_vec(&vec, min);
90+
assert_eq!(Some(-5), min_seg_tree.query(4..7));
91+
assert_eq!(Some(-30), min_seg_tree.query(0..vec.len()));
92+
assert_eq!(Some(-30), min_seg_tree.query(0..2));
93+
assert_eq!(Some(-4), min_seg_tree.query(1..3));
94+
assert_eq!(Some(-5), min_seg_tree.query(1..7));
95+
}
96+
97+
#[test]
98+
fn test_max_segments() {
99+
let val_at_6 = 6;
100+
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
101+
let mut max_seg_tree = SegmentTree::from_vec(&vec, max);
102+
assert_eq!(Some(15), max_seg_tree.query(0..vec.len()));
103+
let max_4_to_6 = 6;
104+
assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7));
105+
let delta = 2;
106+
max_seg_tree.update(6, val_at_6 + delta);
107+
assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7));
108+
}
109+
110+
#[test]
111+
fn test_sum_segments() {
112+
let val_at_6 = 6;
113+
let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8];
114+
let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b);
115+
for (i, val) in vec.iter().enumerate() {
116+
assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1)));
117+
}
118+
let sum_4_to_6 = sum_seg_tree.query(4..7);
119+
assert_eq!(Some(4), sum_4_to_6);
120+
let delta = 3;
121+
sum_seg_tree.update(6, val_at_6 + delta);
122+
assert_eq!(
123+
sum_4_to_6.unwrap() + delta,
124+
sum_seg_tree.query(4..7).unwrap()
125+
);
126+
}
127+
128+
// Some properties over segment trees:
129+
// When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc.
130+
// When asking for an interval containing a single value, return this value, no matter the merge function
131+
132+
#[quickcheck]
133+
fn check_overall_interval_min(array: Vec<i32>) -> TestResult {
134+
let seg_tree = SegmentTree::from_vec(&array, min);
135+
TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len()))
136+
}
137+
138+
#[quickcheck]
139+
fn check_overall_interval_max(array: Vec<i32>) -> TestResult {
140+
let seg_tree = SegmentTree::from_vec(&array, max);
141+
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len()))
142+
}
143+
144+
#[quickcheck]
145+
fn check_overall_interval_sum(array: Vec<i32>) -> TestResult {
146+
let seg_tree = SegmentTree::from_vec(&array, max);
147+
TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len()))
148+
}
149+
150+
#[quickcheck]
151+
fn check_single_interval_min(array: Vec<i32>) -> TestResult {
152+
let seg_tree = SegmentTree::from_vec(&array, min);
153+
for (i, value) in array.into_iter().enumerate() {
154+
let res = seg_tree.query(i..(i + 1));
155+
if res != Some(value) {
156+
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
157+
}
158+
}
159+
TestResult::passed()
160+
}
161+
162+
#[quickcheck]
163+
fn check_single_interval_max(array: Vec<i32>) -> TestResult {
164+
let seg_tree = SegmentTree::from_vec(&array, max);
165+
for (i, value) in array.into_iter().enumerate() {
166+
let res = seg_tree.query(i..(i + 1));
167+
if res != Some(value) {
168+
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
169+
}
170+
}
171+
TestResult::passed()
172+
}
173+
174+
#[quickcheck]
175+
fn check_single_interval_sum(array: Vec<i32>) -> TestResult {
176+
let seg_tree = SegmentTree::from_vec(&array, max);
177+
for (i, value) in array.into_iter().enumerate() {
178+
let res = seg_tree.query(i..(i + 1));
179+
if res != Some(value) {
180+
return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res));
181+
}
182+
}
183+
TestResult::passed()
86184
}
87185
}

0 commit comments

Comments
 (0)