Skip to content

Commit 9a51d1e

Browse files
authored
Merge pull request #32038 from bkirwi/flat-tree
[persist] Optimizations for the merge tree
2 parents 7898480 + 603a855 commit 9a51d1e

File tree

1 file changed

+118
-53
lines changed

1 file changed

+118
-53
lines changed

src/persist-client/src/internal/merge.rs

+118-53
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use mz_ore::task::{JoinHandle, JoinHandleExt};
1111
use std::fmt::{Debug, Formatter};
1212
use std::mem;
13+
use std::ops::{Deref, DerefMut};
1314

1415
/// A merge tree.
1516
///
@@ -23,21 +24,28 @@ use std::mem;
2324
/// - The "depth" of the merge tree - the number of merges any particular element may undergo -
2425
/// is `O(log N)`.
2526
pub struct MergeTree<T> {
26-
pub(crate) max_len: usize,
27-
pub(crate) levels: Vec<Vec<T>>,
27+
/// Configuration: the largest any level in the tree is allowed to grow.
28+
max_level_len: usize,
29+
/// The length of each level in the tree, stored in order from shallowest to deepest.
30+
level_lens: Vec<usize>,
31+
/// A flattened representation of the contents of the tree, stored in order from earliest /
32+
/// deepest to newest / shallowest.
33+
data: Vec<T>,
2834
merge_fn: Box<dyn Fn(Vec<T>) -> T + Sync + Send>,
2935
}
3036

3137
impl<T: Debug> Debug for MergeTree<T> {
3238
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
3339
let Self {
34-
max_len,
35-
levels,
40+
max_level_len,
41+
level_lens,
42+
data,
3643
merge_fn: _,
3744
} = self;
3845
f.debug_struct("MergeTree")
39-
.field("max_len", max_len)
40-
.field("levels", levels)
46+
.field("max_level_len", max_level_len)
47+
.field("level_lens", level_lens)
48+
.field("data", data)
4149
.finish_non_exhaustive()
4250
}
4351
}
@@ -48,90 +56,109 @@ impl<T> MergeTree<T> {
4856
/// limit, the provided `merge_fn` is used to combine adjacent elements together.
4957
pub fn new(max_len: usize, merge_fn: impl Fn(Vec<T>) -> T + Send + Sync + 'static) -> Self {
5058
let new = Self {
51-
max_len,
52-
levels: vec![vec![]],
59+
max_level_len: max_len,
60+
level_lens: vec![0],
61+
data: vec![],
5362
merge_fn: Box::new(merge_fn),
5463
};
5564
new.assert_invariants();
5665
new
5766
}
5867

59-
/// Iterate over (references to) the parts in this tree in first-to-latest order.
60-
#[allow(unused)]
61-
pub fn iter(&self) -> impl Iterator<Item = &T> + DoubleEndedIterator {
62-
self.levels.iter().rev().flat_map(|l| l.iter())
63-
}
64-
65-
/// Iterate over (mutable references to) the parts in this tree in first-to-latest order.
66-
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> + DoubleEndedIterator {
67-
self.levels.iter_mut().rev().flat_map(|l| l.iter_mut())
68+
fn merge_last(&mut self, level_len: usize) {
69+
let offset = self.data.len() - level_len;
70+
let split = self.data.split_off(offset);
71+
let merged = (self.merge_fn)(split);
72+
self.data.push(merged);
6873
}
6974

7075
/// Push a new part onto the end of this tree, possibly triggering a merge.
71-
pub fn push(&mut self, mut part: T) {
76+
pub fn push(&mut self, part: T) {
7277
// Normally, all levels have strictly less than max_len elements.
7378
// However, the _deepest_ level is allowed to have exactly max_len elements,
7479
// since that can save us an unnecessary merge in some cases.
7580
// (For example, when precisely max_len elements are added.)
76-
if let Some(last) = self.levels.last_mut() {
77-
if last.len() == self.max_len {
78-
let merged = (self.merge_fn)(mem::take(last));
79-
self.levels.push(vec![merged]);
81+
if let Some(last_len) = self.level_lens.last_mut() {
82+
if *last_len == self.max_level_len {
83+
let len = mem::take(last_len);
84+
self.merge_last(len);
85+
self.level_lens.push(1);
8086
}
8187
}
8288

8389
// At this point, all levels have room. Add our new part, then continue
8490
// merging up the tree until either there's still room in the current level
8591
// or we've reached the top.
86-
let max_level = self.levels.len() - 1;
92+
self.data.push(part);
93+
94+
let max_level = self.level_lens.len() - 1;
8795
for depth in 0..=max_level {
88-
let level = &mut self.levels[depth];
89-
level.push(part);
96+
let level_len = &mut self.level_lens[depth];
97+
*level_len += 1;
9098

91-
if level.len() < self.max_len || depth == max_level {
99+
if *level_len < self.max_level_len || depth == max_level {
92100
break;
93101
}
94102

95-
part = (self.merge_fn)(mem::take(level));
103+
let len = mem::take(level_len);
104+
self.merge_last(len);
96105
}
97106
}
98107

99108
/// Return the contents of this merge tree, flattened into at most `max_len` parts.
100-
pub fn finish(self) -> Vec<T> {
101-
self.levels
102-
.into_iter()
103-
.reduce(|mut shallower, mut deeper| {
104-
if shallower.len() + deeper.len() <= self.max_len {
105-
// Optimization: if there's enough room in the next level for everything at the
106-
// current level, add it directly.
107-
deeper.append(&mut shallower);
108-
} else {
109-
// Otherwise, merge this up as if it were a full level.
110-
let merged = (self.merge_fn)(shallower);
111-
deeper.push(merged);
112-
}
113-
deeper
114-
})
115-
.expect("non-empty level array")
109+
pub fn finish(mut self) -> Vec<T> {
110+
let mut tail_len = 0;
111+
for level_len in mem::take(&mut self.level_lens) {
112+
if tail_len + level_len <= self.max_level_len {
113+
// Optimization: we can combine the current level with the last level without
114+
// going over our limit.
115+
tail_len += level_len;
116+
} else {
117+
// Otherwise, perform the merge and start a new tail.
118+
self.merge_last(tail_len);
119+
tail_len = level_len + 1
120+
}
121+
}
122+
assert!(self.data.len() <= self.max_level_len);
123+
self.data
116124
}
117125

118126
pub(crate) fn assert_invariants(&self) {
119-
assert!(self.max_len >= 2, "max_len must be at least 2");
127+
assert!(self.max_level_len >= 2, "max_len must be at least 2");
120128

121-
let (deepest, shallow) = self.levels.split_last().expect("non-empty level array");
122-
for (depth, level) in shallow.iter().enumerate() {
129+
assert_eq!(
130+
self.data.len(),
131+
self.level_lens.iter().copied().sum::<usize>(),
132+
"level sizes should sum to overall len"
133+
);
134+
let (deepest_len, shallow) = self.level_lens.split_last().expect("non-empty level array");
135+
for (depth, level_len) in shallow.iter().enumerate() {
123136
assert!(
124-
level.len() < self.max_len,
137+
*level_len < self.max_level_len,
125138
"strictly less than max elements at level {depth}"
126139
);
127140
}
128141
assert!(
129-
deepest.len() <= self.max_len,
142+
*deepest_len <= self.max_level_len,
130143
"at most max elements at deepest level"
131144
);
132145
}
133146
}
134147

148+
impl<T> Deref for MergeTree<T> {
149+
type Target = [T];
150+
151+
fn deref(&self) -> &Self::Target {
152+
&*self.data
153+
}
154+
}
155+
156+
impl<T> DerefMut for MergeTree<T> {
157+
fn deref_mut(&mut self) -> &mut Self::Target {
158+
&mut *self.data
159+
}
160+
}
161+
135162
/// Either a handle to a task that returns a value or the value itself.
136163
#[derive(Debug)]
137164
pub enum Pending<T> {
@@ -167,21 +194,37 @@ impl<T: Send + 'static> Pending<T> {
167194
#[cfg(test)]
168195
mod tests {
169196
use super::*;
197+
use mz_ore::cast::CastLossy;
170198

171199
#[mz_ore::test]
172200
#[cfg_attr(miri, ignore)] // too slow
173201
fn test_merge_tree() {
174202
// Exhaustively test the merge tree for small sizes.
203+
struct Value {
204+
merge_depth: usize,
205+
elements: Vec<i64>,
206+
}
207+
175208
for max_len in 2..8 {
176209
for items in 0..100 {
177-
let mut merge_tree = MergeTree::new(max_len, |vals: Vec<Vec<usize>>| {
210+
let mut merge_tree = MergeTree::new(max_len, |vals: Vec<Value>| {
178211
// Merge sequences by concatenation.
179-
vals.into_iter().flatten().collect()
212+
Value {
213+
merge_depth: vals.iter().map(|v| v.merge_depth).max().unwrap_or(0) + 1,
214+
elements: vals.into_iter().flat_map(|e| e.elements).collect(),
215+
}
180216
});
181217
for i in 0..items {
182-
merge_tree.push(vec![i]);
218+
merge_tree.push(Value {
219+
merge_depth: 0,
220+
elements: vec![i],
221+
});
183222
assert!(
184-
merge_tree.iter().flatten().copied().eq(0..=i),
223+
merge_tree
224+
.iter()
225+
.flat_map(|v| v.elements.iter())
226+
.copied()
227+
.eq(0..=i),
185228
"no parts should be lost"
186229
);
187230
merge_tree.assert_invariants();
@@ -191,7 +234,29 @@ mod tests {
191234
parts.len() <= max_len,
192235
"no more than {max_len} finished parts"
193236
);
194-
assert!(parts.into_iter().flatten().eq(0..items), "no parts lost");
237+
238+
// We want our merged tree to be "balanced".
239+
// If we have 2^N elements in a binary tree, we want the depth to be N;
240+
// and more generally, we want a depth of N for a K-ary tree with K^N elements...
241+
// which is to say, a depth of log_K N for a tree with N elements.
242+
let expected_merge_depth =
243+
usize::cast_lossy(f64::cast_lossy(items).log(f64::cast_lossy(max_len)).floor());
244+
for part in &parts {
245+
assert!(
246+
part.merge_depth <= expected_merge_depth,
247+
"expected at most {expected_merge_depth} merges for a tree \
248+
with max len {max_len} and {items} elements, but got {}",
249+
part.merge_depth
250+
);
251+
}
252+
assert!(
253+
parts
254+
.iter()
255+
.flat_map(|v| v.elements.iter())
256+
.copied()
257+
.eq(0..items),
258+
"no parts lost"
259+
);
195260
}
196261
}
197262
}

0 commit comments

Comments
 (0)