Skip to content

Commit d01f798

Browse files
committed
impl Aggregate for [T;N]
1 parent 98b0ce3 commit d01f798

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

luisa_compute/src/lang.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,23 @@ pub trait Aggregate: Sized {
154154
fn from_nodes<I: Iterator<Item = SafeNodeRef>>(iter: &mut I) -> Self;
155155
}
156156

157+
impl<const N: usize, T: Aggregate> Aggregate for [T; N] {
158+
fn to_nodes(&self, nodes: &mut Vec<SafeNodeRef>) {
159+
for x in self {
160+
x.to_nodes(nodes);
161+
}
162+
}
163+
fn from_nodes<I: Iterator<Item = SafeNodeRef>>(iter: &mut I) -> Self {
164+
unsafe {
165+
let mut ret = std::mem::MaybeUninit::<[T; N]>::uninit();
166+
for i in 0..N {
167+
let x = T::from_nodes(iter);
168+
ret.as_mut_ptr().cast::<T>().add(i).write(x);
169+
}
170+
ret.assume_init()
171+
}
172+
}
173+
}
157174
impl<T: Aggregate> Aggregate for Vec<T> {
158175
fn to_nodes(&self, nodes: &mut Vec<SafeNodeRef>) {
159176
let len_node = __new_user_node(nodes.len());
@@ -427,7 +444,8 @@ impl FnRecorder {
427444
device: None,
428445
block_size: None,
429446
pools: pools.clone(),
430-
arena: parent.as_ref()
447+
arena: parent
448+
.as_ref()
431449
.map(|p| p.borrow().arena.clone())
432450
.unwrap_or_else(|| Rc::new(Bump::new())),
433451
building_kernel: false,

0 commit comments

Comments
 (0)