Skip to content

Commit b7af7d6

Browse files
authored
fix: Implement dedicated map growable (#4435)
## Changes Made Implements dedicated map growable to circumvent unwanted physical casting of child fields ## Related Issues Closes #4432 ## Checklist - [ ] Documented in API Docs (if applicable) - [ ] Documented in User Guide (if applicable) - [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent 00a302b commit b7af7d6

File tree

3 files changed

+495
-2
lines changed

3 files changed

+495
-2
lines changed
Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
use common_error::DaftResult;
2+
3+
use super::{list_growable::ListGrowable, Growable};
4+
use crate::{
5+
datatypes::{logical::MapArray, DataType, Field},
6+
series::{IntoSeries, Series},
7+
};
8+
9+
pub struct MapGrowable<'a> {
10+
name: String,
11+
dtype: DataType,
12+
list_growable: ListGrowable<'a>,
13+
}
14+
15+
impl<'a> MapGrowable<'a> {
16+
pub fn new(
17+
name: &str,
18+
dtype: &DataType,
19+
arrays: Vec<&'a MapArray>,
20+
use_validity: bool,
21+
capacity: usize,
22+
) -> Self {
23+
match dtype {
24+
DataType::Map { key, value } => {
25+
let physical_arrays: Vec<&crate::array::ListArray> =
26+
arrays.iter().map(|a| &a.physical).collect();
27+
28+
let list_growable = ListGrowable::new(
29+
name,
30+
// instead of doing dtype.to_physical(), which will recursively convert all children to physical types,
31+
// we just want the top level physical type, which is a list of structs, and the inner dtypes should remain
32+
// untouched and dealt with by the struct growable.
33+
&DataType::List(Box::new(DataType::Struct(vec![
34+
Field::new("key", *key.clone()),
35+
Field::new("value", *value.clone()),
36+
]))),
37+
physical_arrays,
38+
use_validity,
39+
capacity,
40+
0, // child_capacity - use default for now
41+
);
42+
43+
Self {
44+
name: name.to_string(),
45+
dtype: dtype.clone(),
46+
list_growable,
47+
}
48+
}
49+
_ => panic!("Cannot create MapGrowable from dtype: {}", dtype),
50+
}
51+
}
52+
}
53+
54+
impl Growable for MapGrowable<'_> {
55+
fn extend(&mut self, index: usize, start: usize, len: usize) {
56+
self.list_growable.extend(index, start, len);
57+
}
58+
59+
fn add_nulls(&mut self, additional: usize) {
60+
self.list_growable.add_nulls(additional);
61+
}
62+
63+
fn build(&mut self) -> DaftResult<Series> {
64+
let physical_series = self.list_growable.build()?;
65+
let physical_list = physical_series.list()?;
66+
67+
let map_array = MapArray::new(
68+
Field::new(self.name.clone(), self.dtype.clone()),
69+
physical_list.clone(),
70+
);
71+
Ok(map_array.into_series())
72+
}
73+
}
74+
75+
#[cfg(test)]
76+
mod tests {
77+
use super::*;
78+
use crate::{
79+
array::{ListArray, StructArray},
80+
datatypes::{DaftArrayType, DataType, Field, Int64Array, Utf8Array},
81+
series::IntoSeries,
82+
};
83+
84+
fn create_test_map_array(name: &str, keys: Vec<&str>, values: Vec<i64>) -> MapArray {
85+
assert_eq!(keys.len(), values.len());
86+
87+
let num_entries = keys.len();
88+
89+
// Create key and value series using proper from_iter methods
90+
let key_array = Utf8Array::from_iter("key", keys.into_iter().map(|s| Some(s.to_string())));
91+
let value_array = Int64Array::from_iter(
92+
Field::new("value", DataType::Int64),
93+
values.into_iter().map(Some),
94+
);
95+
96+
// Create struct array with key-value pairs
97+
let struct_array = StructArray::new(
98+
Field::new(
99+
"entries",
100+
DataType::Struct(vec![
101+
Field::new("key", DataType::Utf8),
102+
Field::new("value", DataType::Int64),
103+
]),
104+
),
105+
vec![key_array.into_series(), value_array.into_series()],
106+
None,
107+
);
108+
109+
// Create list array with one entry containing all key-value pairs
110+
let list_array = ListArray::new(
111+
Field::new(
112+
name,
113+
DataType::List(Box::new(DataType::Struct(vec![
114+
Field::new("key", DataType::Utf8),
115+
Field::new("value", DataType::Int64),
116+
]))),
117+
),
118+
struct_array.into_series(),
119+
arrow2::offset::OffsetsBuffer::try_from(vec![0i64, num_entries as i64]).unwrap(),
120+
None,
121+
);
122+
123+
// Create map array
124+
MapArray::new(
125+
Field::new(
126+
name,
127+
DataType::Map {
128+
key: Box::new(DataType::Utf8),
129+
value: Box::new(DataType::Int64),
130+
},
131+
),
132+
list_array,
133+
)
134+
}
135+
136+
fn create_nested_map_array(
137+
name: &str,
138+
outer_key: &str,
139+
inner_key: &str,
140+
inner_value: i64,
141+
) -> MapArray {
142+
// Create inner map
143+
let inner_map = create_test_map_array("inner_map", vec![inner_key], vec![inner_value]);
144+
145+
// Create outer map
146+
let outer_key_array =
147+
Utf8Array::from_iter("key", std::iter::once(Some(outer_key.to_string())));
148+
let outer_struct_array = StructArray::new(
149+
Field::new(
150+
"entries",
151+
DataType::Struct(vec![
152+
Field::new("key", DataType::Utf8),
153+
Field::new(
154+
"value",
155+
DataType::Map {
156+
key: Box::new(DataType::Utf8),
157+
value: Box::new(DataType::Int64),
158+
},
159+
),
160+
]),
161+
),
162+
vec![
163+
outer_key_array.into_series(),
164+
inner_map.into_series().rename("value"),
165+
],
166+
None,
167+
);
168+
169+
let outer_list_array = ListArray::new(
170+
Field::new(
171+
name,
172+
DataType::List(Box::new(DataType::Struct(vec![
173+
Field::new("key", DataType::Utf8),
174+
Field::new(
175+
"value",
176+
DataType::Map {
177+
key: Box::new(DataType::Utf8),
178+
value: Box::new(DataType::Int64),
179+
},
180+
),
181+
]))),
182+
),
183+
outer_struct_array.into_series(),
184+
arrow2::offset::OffsetsBuffer::try_from(vec![0i64, 1i64]).unwrap(),
185+
None,
186+
);
187+
188+
MapArray::new(
189+
Field::new(
190+
name,
191+
DataType::Map {
192+
key: Box::new(DataType::Utf8),
193+
value: Box::new(DataType::Map {
194+
key: Box::new(DataType::Utf8),
195+
value: Box::new(DataType::Int64),
196+
}),
197+
},
198+
),
199+
outer_list_array,
200+
)
201+
}
202+
203+
fn verify_map_entry(
204+
map: &crate::datatypes::logical::MapArray,
205+
index: usize,
206+
expected_keys: &[&str],
207+
expected_values: &[i64],
208+
) -> DaftResult<()> {
209+
let entry = map.get(index).unwrap();
210+
let struct_array = entry.struct_()?;
211+
let keys_series = struct_array.get("key")?;
212+
let values_series = struct_array.get("value")?;
213+
let keys = keys_series.utf8()?;
214+
let values = values_series.i64()?;
215+
216+
assert_eq!(keys.len(), expected_keys.len());
217+
assert_eq!(values.len(), expected_values.len());
218+
219+
for (i, (expected_key, expected_value)) in
220+
expected_keys.iter().zip(expected_values.iter()).enumerate()
221+
{
222+
assert_eq!(keys.get(i), Some(*expected_key));
223+
assert_eq!(values.get(i), Some(*expected_value));
224+
}
225+
Ok(())
226+
}
227+
228+
fn verify_nested_map_entry(
229+
map: &crate::datatypes::logical::MapArray,
230+
index: usize,
231+
outer_key: &str,
232+
inner_key: &str,
233+
inner_value: i64,
234+
) -> DaftResult<()> {
235+
let entry = map.get(index).unwrap();
236+
let struct_array = entry.struct_()?;
237+
let keys_series = struct_array.get("key")?;
238+
let values_series = struct_array.get("value")?;
239+
let keys = keys_series.utf8()?;
240+
let values_map = values_series.map()?;
241+
242+
assert_eq!(keys.get(0), Some(outer_key));
243+
244+
let nested_entry = values_map.get(0).unwrap();
245+
let nested_struct = nested_entry.struct_()?;
246+
let nested_keys_series = nested_struct.get("key")?;
247+
let nested_values_series = nested_struct.get("value")?;
248+
let nested_keys = nested_keys_series.utf8()?;
249+
let nested_values = nested_values_series.i64()?;
250+
251+
assert_eq!(nested_keys.get(0), Some(inner_key));
252+
assert_eq!(nested_values.get(0), Some(inner_value));
253+
Ok(())
254+
}
255+
256+
#[test]
257+
fn test_map_growable_basic() -> DaftResult<()> {
258+
let map1 = create_test_map_array("test_map", vec!["a", "b"], vec![1, 2]);
259+
let map2 = create_test_map_array("test_map", vec!["c", "d"], vec![3, 4]);
260+
261+
let map_dtype = DataType::Map {
262+
key: Box::new(DataType::Utf8),
263+
value: Box::new(DataType::Int64),
264+
};
265+
266+
let mut growable = MapGrowable::new("result", &map_dtype, vec![&map1, &map2], false, 2);
267+
growable.extend(0, 0, 1);
268+
growable.extend(1, 0, 1);
269+
270+
let result = growable.build()?;
271+
let result_map = result.map()?;
272+
273+
assert_eq!(result_map.len(), 2);
274+
assert_eq!(result_map.data_type(), &map_dtype);
275+
276+
verify_map_entry(result_map, 0, &["a", "b"], &[1, 2])?;
277+
verify_map_entry(result_map, 1, &["c", "d"], &[3, 4])?;
278+
279+
Ok(())
280+
}
281+
282+
#[test]
283+
fn test_map_growable_with_nulls() -> DaftResult<()> {
284+
let map1 = create_test_map_array("test_map", vec!["a"], vec![1]);
285+
let map2 = create_test_map_array("test_map", vec!["b"], vec![2]);
286+
287+
let map_dtype = DataType::Map {
288+
key: Box::new(DataType::Utf8),
289+
value: Box::new(DataType::Int64),
290+
};
291+
292+
let mut growable = MapGrowable::new("result", &map_dtype, vec![&map1, &map2], true, 3);
293+
growable.extend(0, 0, 1);
294+
growable.add_nulls(1);
295+
growable.extend(1, 0, 1);
296+
297+
let result = growable.build()?;
298+
let result_map = result.map()?;
299+
300+
assert_eq!(result_map.len(), 3);
301+
verify_map_entry(result_map, 0, &["a"], &[1])?;
302+
assert!(result_map.get(1).is_none());
303+
verify_map_entry(result_map, 2, &["b"], &[2])?;
304+
305+
Ok(())
306+
}
307+
308+
#[test]
309+
fn test_map_growable_multiple_extends() -> DaftResult<()> {
310+
let map1 = create_test_map_array("test_map", vec!["a", "b", "c"], vec![1, 2, 3]);
311+
let map2 = create_test_map_array("test_map", vec!["d", "e"], vec![4, 5]);
312+
313+
let map_dtype = DataType::Map {
314+
key: Box::new(DataType::Utf8),
315+
value: Box::new(DataType::Int64),
316+
};
317+
318+
let mut growable = MapGrowable::new("result", &map_dtype, vec![&map1, &map2], false, 5);
319+
for _ in 0..3 {
320+
growable.extend(0, 0, 1);
321+
growable.extend(1, 0, 1);
322+
}
323+
324+
let result = growable.build()?;
325+
let result_map = result.map()?;
326+
327+
assert_eq!(result_map.len(), 6);
328+
for i in [0, 2, 4] {
329+
verify_map_entry(result_map, i, &["a", "b", "c"], &[1, 2, 3])?;
330+
}
331+
for i in [1, 3, 5] {
332+
verify_map_entry(result_map, i, &["d", "e"], &[4, 5])?;
333+
}
334+
335+
Ok(())
336+
}
337+
338+
#[test]
339+
fn test_map_growable_empty() -> DaftResult<()> {
340+
let map1 = create_test_map_array("test_map", vec![], vec![]);
341+
let map_dtype = DataType::Map {
342+
key: Box::new(DataType::Utf8),
343+
value: Box::new(DataType::Int64),
344+
};
345+
346+
let mut growable = MapGrowable::new("result", &map_dtype, vec![&map1], false, 1);
347+
growable.extend(0, 0, 1);
348+
349+
let result = growable.build()?;
350+
let result_map = result.map()?;
351+
352+
assert_eq!(result_map.len(), 1);
353+
verify_map_entry(result_map, 0, &[], &[])?;
354+
355+
Ok(())
356+
}
357+
358+
#[test]
359+
fn test_map_growable_nested_maps() -> DaftResult<()> {
360+
let nested_map = create_nested_map_array("test_map", "outer_key", "inner_key", 42);
361+
362+
let map_dtype = DataType::Map {
363+
key: Box::new(DataType::Utf8),
364+
value: Box::new(DataType::Map {
365+
key: Box::new(DataType::Utf8),
366+
value: Box::new(DataType::Int64),
367+
}),
368+
};
369+
370+
let mut growable = MapGrowable::new("result", &map_dtype, vec![&nested_map], false, 1);
371+
growable.extend(0, 0, 1);
372+
373+
let result = growable.build()?;
374+
let result_map = result.map()?;
375+
376+
assert_eq!(result_map.len(), 1);
377+
verify_nested_map_entry(result_map, 0, "outer_key", "inner_key", 42)?;
378+
379+
Ok(())
380+
}
381+
}

src/daft-core/src/array/growable/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod bitmap_growable;
1212
mod fixed_size_list_growable;
1313
mod list_growable;
1414
mod logical_growable;
15+
mod map_growable;
1516
mod struct_growable;
1617

1718
#[cfg(feature = "python")]
@@ -223,4 +224,4 @@ impl_growable_array!(
223224
);
224225
impl_growable_array!(ImageArray, logical_growable::LogicalImageGrowable<'a>);
225226
impl_growable_array!(TensorArray, logical_growable::LogicalTensorGrowable<'a>);
226-
impl_growable_array!(MapArray, logical_growable::LogicalMapGrowable<'a>);
227+
impl_growable_array!(MapArray, map_growable::MapGrowable<'a>);

0 commit comments

Comments
 (0)