Skip to content

Commit 37cfc1c

Browse files
committed
Make is_contiguous check common (#3083)
1 parent 6fbd852 commit 37cfc1c

File tree

5 files changed

+61
-94
lines changed

5 files changed

+61
-94
lines changed

crates/burn-common/src/lib.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,46 @@ pub mod network;
2222

2323
/// Parallel utilities.
2424
pub mod parallel;
25+
26+
/// Tensor utilities.
pub mod tensor {
    use alloc::vec::Vec;

    /// Check if the current tensor is contiguous.
    ///
    /// A tensor is considered contiguous if its elements are stored in memory
    /// such that the stride at position `k` is equal to the product of the shapes
    /// of all dimensions greater than `k` (row-major layout).
    ///
    /// This means that strides increase as you move from the rightmost to the
    /// leftmost dimension.
    ///
    /// Returns `false` when `shape` and `strides` have different lengths, since
    /// such a layout description is ill-formed.
    pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
        if shape.is_empty() {
            return true;
        }

        if shape.len() != strides.len() {
            return false;
        }

        // Walk from the innermost (rightmost) dimension outwards, tracking the
        // stride a contiguous layout would require at each step. This avoids
        // allocating the expected-stride vector `contiguous_strides` builds.
        let mut expected = 1;
        for (&dim, &stride) in shape.iter().rev().zip(strides.iter().rev()) {
            if stride != expected {
                return false;
            }
            expected *= dim;
        }

        true
    }

    /// Computes the strides for a contiguous tensor with the given shape.
    ///
    /// In a contiguous row-major tensor, the stride for each dimension
    /// equals the product of all dimension sizes to its right.
    pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = Vec::with_capacity(shape.len());
        let mut current = 1;

        // Build strides right-to-left, then flip to match dimension order.
        for &dim in shape.iter().rev() {
            strides.push(current);
            current *= dim;
        }

        strides.reverse();
        strides
    }
}

crates/burn-cubecl-fusion/src/base.rs

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -137,33 +137,3 @@ impl<R: Runtime> CubeFusionHandle<R> {
137137
}
138138
}
139139
}
140-
141-
/// Check whether `strides` describe a contiguous row-major layout for `shape`.
///
/// Rank 0 is always contiguous; rank 1 requires a unit stride. For higher
/// ranks, scanning from the innermost dimension outwards, every stride past
/// the first must equal the number of elements covered so far AND be strictly
/// larger than the previous stride (so unit dimensions with repeated strides
/// are rejected).
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    match shape.len() {
        0 => true,
        1 => strides[0] == 1,
        _ => {
            let mut elems_so_far = 1;
            let mut last_stride = 1;

            for (position, (&stride, &dim)) in strides.iter().zip(shape).rev().enumerate() {
                if position > 0 && (elems_so_far != stride || last_stride >= stride) {
                    return false;
                }

                elems_so_far *= dim;
                last_stride = stride;
            }

            true
        }
    }
}

crates/burn-cubecl-fusion/src/shared/trace/output.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
use burn_common::tensor::is_contiguous;
12
use burn_fusion::stream::Context;
23
use burn_ir::{TensorId, TensorIr};
34
use burn_tensor::DType;
45
use cubecl::{CubeElement, Runtime, client::ComputeClient, ir::Elem};
56

67
use crate::{
7-
CubeFusionHandle, elem_dtype, is_contiguous,
8+
CubeFusionHandle, elem_dtype,
89
shared::ir::{Arg, FuseOp, LayoutInfo},
910
strides_dyn_rank,
1011
};

crates/burn-cubecl/src/tensor/base.rs

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::CubeRuntime;
22
use crate::element::CubeElement;
33
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
4+
use burn_common::tensor::is_contiguous;
45
use burn_tensor::quantization::QTensorPrimitive;
56
use burn_tensor::{DType, Shape, TensorMetadata};
67
use cubecl::client::ComputeClient;
@@ -440,8 +441,8 @@ where
440441

441442
/// Check if the current tensor is contiguous.
442443
///
443-
/// A tensor is contiguous if the elements are stored
444-
/// if the strides in strict decreasing order and the
444+
/// A tensor is contiguous if the elements are stored in memory
445+
/// if the strides in non-increasing order and the
445446
/// strides at position k is equal to the product of the shapes
446447
/// at all positions greater than k. However, all axes with a shape of 1 are ignored.
447448
pub fn is_contiguous(&self) -> bool {
@@ -455,39 +456,14 @@ where
455456
}
456457
}
457458

458-
/// Check whether `strides` describe a contiguous layout for `shape`, skipping
/// axes whose extent is 1.
///
/// Scans from the innermost dimension outwards; past the first visited
/// position, each stride must equal the element count accumulated so far and
/// be strictly greater than the previously seen stride. The position counter
/// advances even on skipped unit axes, so the "first position" exemption only
/// applies to the innermost axis itself.
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    let mut elems_so_far = 1;
    let mut last_stride = 1;
    let mut position = 0;

    for (&stride, &dim) in strides.iter().zip(shape).rev() {
        let at_innermost = position == 0;
        position += 1;

        // Unit axes carry no layout information; ignore them entirely
        // (note: `last_stride` is intentionally not updated here).
        if dim == 1 {
            continue;
        }

        if !at_innermost && (elems_so_far != stride || last_stride >= stride) {
            return false;
        }

        elems_so_far *= dim;
        last_stride = stride;
    }

    true
}
487-
488459
#[cfg(test)]
489460
mod tests {
490-
use crate::tensor::base::is_contiguous;
461+
use super::*;
462+
463+
#[test]
464+
fn is_contiguous_non_increasing() {
465+
assert!(is_contiguous(&[3, 1], &[1, 1]));
466+
}
491467

492468
#[test]
493469
fn is_contiguous_basic() {

crates/burn-ndarray/src/tensor.rs

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ macro_rules! execute_with_float_dtype {
190190
}
191191

192192
mod utils {
193+
use burn_common::tensor::is_contiguous;
194+
193195
use super::*;
194196

195197
impl<E> NdArrayTensor<E>
@@ -219,40 +221,15 @@ mod utils {
219221

220222
pub(crate) fn is_contiguous(&self) -> bool {
221223
let shape = self.array.shape();
222-
let strides = self.array.strides();
223-
224-
if shape.is_empty() {
225-
return true;
226-
}
227-
228-
if shape.len() == 1 {
229-
return strides[0] == 1;
230-
}
231-
232-
let mut prev_stride = 1;
233-
let mut current_num_elems_shape = 1;
224+
let mut strides = Vec::with_capacity(self.array.strides().len());
234225

235-
for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
236-
let stride = if *stride <= 0 {
226+
for &stride in self.array.strides() {
227+
if stride <= 0 {
237228
return false;
238-
} else {
239-
*stride as usize
240-
};
241-
if i > 0 {
242-
if current_num_elems_shape != stride {
243-
return false;
244-
}
245-
246-
if prev_stride > stride {
247-
return false;
248-
}
249229
}
250-
251-
current_num_elems_shape *= shape;
252-
prev_stride = stride;
230+
strides.push(stride as usize);
253231
}
254-
255-
true
232+
is_contiguous(shape, &strides)
256233
}
257234
}
258235
}

0 commit comments

Comments
 (0)