10
10
use mz_ore:: task:: { JoinHandle , JoinHandleExt } ;
11
11
use std:: fmt:: { Debug , Formatter } ;
12
12
use std:: mem;
13
+ use std:: ops:: { Deref , DerefMut } ;
13
14
14
15
/// A merge tree.
15
16
///
@@ -23,21 +24,28 @@ use std::mem;
23
24
/// - The "depth" of the merge tree - the number of merges any particular element may undergo -
24
25
/// is `O(log N)`.
25
26
pub struct MergeTree < T > {
26
- pub ( crate ) max_len : usize ,
27
- pub ( crate ) levels : Vec < Vec < T > > ,
27
+ /// Configuration: the largest any level in the tree is allowed to grow.
28
+ max_level_len : usize ,
29
+ /// The length of each level in the tree, stored in order from shallowest to deepest.
30
+ level_lens : Vec < usize > ,
31
+ /// A flattened representation of the contents of the tree, stored in order from earliest /
32
+ /// deepest to newest / shallowest.
33
+ data : Vec < T > ,
28
34
merge_fn : Box < dyn Fn ( Vec < T > ) -> T + Sync + Send > ,
29
35
}
30
36
31
37
impl < T : Debug > Debug for MergeTree < T > {
32
38
fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
33
39
let Self {
34
- max_len,
35
- levels,
40
+ max_level_len,
41
+ level_lens,
42
+ data,
36
43
merge_fn : _,
37
44
} = self ;
38
45
f. debug_struct ( "MergeTree" )
39
- . field ( "max_len" , max_len)
40
- . field ( "levels" , levels)
46
+ . field ( "max_level_len" , max_level_len)
47
+ . field ( "level_lens" , level_lens)
48
+ . field ( "data" , data)
41
49
. finish_non_exhaustive ( )
42
50
}
43
51
}
@@ -48,90 +56,109 @@ impl<T> MergeTree<T> {
48
56
/// limit, the provided `merge_fn` is used to combine adjacent elements together.
49
57
pub fn new ( max_len : usize , merge_fn : impl Fn ( Vec < T > ) -> T + Send + Sync + ' static ) -> Self {
50
58
let new = Self {
51
- max_len,
52
- levels : vec ! [ vec![ ] ] ,
59
+ max_level_len : max_len,
60
+ level_lens : vec ! [ 0 ] ,
61
+ data : vec ! [ ] ,
53
62
merge_fn : Box :: new ( merge_fn) ,
54
63
} ;
55
64
new. assert_invariants ( ) ;
56
65
new
57
66
}
58
67
59
- /// Iterate over (references to) the parts in this tree in first-to-latest order.
60
- #[ allow( unused) ]
61
- pub fn iter ( & self ) -> impl Iterator < Item = & T > + DoubleEndedIterator {
62
- self . levels . iter ( ) . rev ( ) . flat_map ( |l| l. iter ( ) )
63
- }
64
-
65
- /// Iterate over (mutable references to) the parts in this tree in first-to-latest order.
66
- pub fn iter_mut ( & mut self ) -> impl Iterator < Item = & mut T > + DoubleEndedIterator {
67
- self . levels . iter_mut ( ) . rev ( ) . flat_map ( |l| l. iter_mut ( ) )
68
+ fn merge_last ( & mut self , level_len : usize ) {
69
+ let offset = self . data . len ( ) - level_len;
70
+ let split = self . data . split_off ( offset) ;
71
+ let merged = ( self . merge_fn ) ( split) ;
72
+ self . data . push ( merged) ;
68
73
}
69
74
70
75
/// Push a new part onto the end of this tree, possibly triggering a merge.
71
- pub fn push ( & mut self , mut part : T ) {
76
+ pub fn push ( & mut self , part : T ) {
72
77
// Normally, all levels have strictly less than max_len elements.
73
78
// However, the _deepest_ level is allowed to have exactly max_len elements,
74
79
// since that can save us an unnecessary merge in some cases.
75
80
// (For example, when precisely max_len elements are added.)
76
- if let Some ( last) = self . levels . last_mut ( ) {
77
- if last. len ( ) == self . max_len {
78
- let merged = ( self . merge_fn ) ( mem:: take ( last) ) ;
79
- self . levels . push ( vec ! [ merged] ) ;
81
+ if let Some ( last_len) = self . level_lens . last_mut ( ) {
82
+ if * last_len == self . max_level_len {
83
+ let len = mem:: take ( last_len) ;
84
+ self . merge_last ( len) ;
85
+ self . level_lens . push ( 1 ) ;
80
86
}
81
87
}
82
88
83
89
// At this point, all levels have room. Add our new part, then continue
84
90
// merging up the tree until either there's still room in the current level
85
91
// or we've reached the top.
86
- let max_level = self . levels . len ( ) - 1 ;
92
+ self . data . push ( part) ;
93
+
94
+ let max_level = self . level_lens . len ( ) - 1 ;
87
95
for depth in 0 ..=max_level {
88
- let level = & mut self . levels [ depth] ;
89
- level . push ( part ) ;
96
+ let level_len = & mut self . level_lens [ depth] ;
97
+ * level_len += 1 ;
90
98
91
- if level . len ( ) < self . max_len || depth == max_level {
99
+ if * level_len < self . max_level_len || depth == max_level {
92
100
break ;
93
101
}
94
102
95
- part = ( self . merge_fn ) ( mem:: take ( level) ) ;
103
+ let len = mem:: take ( level_len) ;
104
+ self . merge_last ( len) ;
96
105
}
97
106
}
98
107
99
108
/// Return the contents of this merge tree, flattened into at most `max_len` parts.
100
- pub fn finish ( self ) -> Vec < T > {
101
- self . levels
102
- . into_iter ( )
103
- . reduce ( |mut shallower, mut deeper| {
104
- if shallower. len ( ) + deeper. len ( ) <= self . max_len {
105
- // Optimization: if there's enough room in the next level for everything at the
106
- // current level, add it directly.
107
- deeper. append ( & mut shallower) ;
108
- } else {
109
- // Otherwise, merge this up as if it were a full level.
110
- let merged = ( self . merge_fn ) ( shallower) ;
111
- deeper. push ( merged) ;
112
- }
113
- deeper
114
- } )
115
- . expect ( "non-empty level array" )
109
+ pub fn finish ( mut self ) -> Vec < T > {
110
+ let mut tail_len = 0 ;
111
+ for level_len in mem:: take ( & mut self . level_lens ) {
112
+ if tail_len + level_len <= self . max_level_len {
113
+ // Optimization: we can combine the current level with the last level without
114
+ // going over our limit.
115
+ tail_len += level_len;
116
+ } else {
117
+ // Otherwise, perform the merge and start a new tail.
118
+ self . merge_last ( tail_len) ;
119
+ tail_len = level_len + 1
120
+ }
121
+ }
122
+ assert ! ( self . data. len( ) <= self . max_level_len) ;
123
+ self . data
116
124
}
117
125
118
126
pub ( crate ) fn assert_invariants ( & self ) {
119
- assert ! ( self . max_len >= 2 , "max_len must be at least 2" ) ;
127
+ assert ! ( self . max_level_len >= 2 , "max_len must be at least 2" ) ;
120
128
121
- let ( deepest, shallow) = self . levels . split_last ( ) . expect ( "non-empty level array" ) ;
122
- for ( depth, level) in shallow. iter ( ) . enumerate ( ) {
129
+ assert_eq ! (
130
+ self . data. len( ) ,
131
+ self . level_lens. iter( ) . copied( ) . sum:: <usize >( ) ,
132
+ "level sizes should sum to overall len"
133
+ ) ;
134
+ let ( deepest_len, shallow) = self . level_lens . split_last ( ) . expect ( "non-empty level array" ) ;
135
+ for ( depth, level_len) in shallow. iter ( ) . enumerate ( ) {
123
136
assert ! (
124
- level . len ( ) < self . max_len ,
137
+ * level_len < self . max_level_len ,
125
138
"strictly less than max elements at level {depth}"
126
139
) ;
127
140
}
128
141
assert ! (
129
- deepest . len ( ) <= self . max_len ,
142
+ * deepest_len <= self . max_level_len ,
130
143
"at most max elements at deepest level"
131
144
) ;
132
145
}
133
146
}
134
147
148
+ impl < T > Deref for MergeTree < T > {
149
+ type Target = [ T ] ;
150
+
151
+ fn deref ( & self ) -> & Self :: Target {
152
+ & * self . data
153
+ }
154
+ }
155
+
156
+ impl < T > DerefMut for MergeTree < T > {
157
+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
158
+ & mut * self . data
159
+ }
160
+ }
161
+
135
162
/// Either a handle to a task that returns a value or the value itself.
136
163
#[ derive( Debug ) ]
137
164
pub enum Pending < T > {
@@ -167,21 +194,37 @@ impl<T: Send + 'static> Pending<T> {
167
194
#[ cfg( test) ]
168
195
mod tests {
169
196
use super :: * ;
197
+ use mz_ore:: cast:: CastLossy ;
170
198
171
199
#[ mz_ore:: test]
172
200
#[ cfg_attr( miri, ignore) ] // too slow
173
201
fn test_merge_tree ( ) {
174
202
// Exhaustively test the merge tree for small sizes.
203
+ struct Value {
204
+ merge_depth : usize ,
205
+ elements : Vec < i64 > ,
206
+ }
207
+
175
208
for max_len in 2 ..8 {
176
209
for items in 0 ..100 {
177
- let mut merge_tree = MergeTree :: new ( max_len, |vals : Vec < Vec < usize > > | {
210
+ let mut merge_tree = MergeTree :: new ( max_len, |vals : Vec < Value > | {
178
211
// Merge sequences by concatenation.
179
- vals. into_iter ( ) . flatten ( ) . collect ( )
212
+ Value {
213
+ merge_depth : vals. iter ( ) . map ( |v| v. merge_depth ) . max ( ) . unwrap_or ( 0 ) + 1 ,
214
+ elements : vals. into_iter ( ) . flat_map ( |e| e. elements ) . collect ( ) ,
215
+ }
180
216
} ) ;
181
217
for i in 0 ..items {
182
- merge_tree. push ( vec ! [ i] ) ;
218
+ merge_tree. push ( Value {
219
+ merge_depth : 0 ,
220
+ elements : vec ! [ i] ,
221
+ } ) ;
183
222
assert ! (
184
- merge_tree. iter( ) . flatten( ) . copied( ) . eq( 0 ..=i) ,
223
+ merge_tree
224
+ . iter( )
225
+ . flat_map( |v| v. elements. iter( ) )
226
+ . copied( )
227
+ . eq( 0 ..=i) ,
185
228
"no parts should be lost"
186
229
) ;
187
230
merge_tree. assert_invariants ( ) ;
@@ -191,7 +234,29 @@ mod tests {
191
234
parts. len( ) <= max_len,
192
235
"no more than {max_len} finished parts"
193
236
) ;
194
- assert ! ( parts. into_iter( ) . flatten( ) . eq( 0 ..items) , "no parts lost" ) ;
237
+
238
+ // We want our merged tree to be "balanced".
239
+ // If we have 2^N elements in a binary tree, we want the depth to be N;
240
+ // and more generally, we want a depth of N for a K-ary tree with K^N elements...
241
+ // which is to say, a depth of log_K N for a tree with N elements.
242
+ let expected_merge_depth =
243
+ usize:: cast_lossy ( f64:: cast_lossy ( items) . log ( f64:: cast_lossy ( max_len) ) . floor ( ) ) ;
244
+ for part in & parts {
245
+ assert ! (
246
+ part. merge_depth <= expected_merge_depth,
247
+ "expected at most {expected_merge_depth} merges for a tree \
248
+ with max len {max_len} and {items} elements, but got {}",
249
+ part. merge_depth
250
+ ) ;
251
+ }
252
+ assert ! (
253
+ parts
254
+ . iter( )
255
+ . flat_map( |v| v. elements. iter( ) )
256
+ . copied( )
257
+ . eq( 0 ..items) ,
258
+ "no parts lost"
259
+ ) ;
195
260
}
196
261
}
197
262
}
0 commit comments