Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions crates/burn-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,46 @@ pub mod network;

/// Parallel utilities.
pub mod parallel;

/// Tensor utilities.
pub mod tensor {
use alloc::vec::Vec;

/// Check if the current tensor is contiguous.
///
/// A tensor is considered contiguous if its elements are stored in memory
/// such that the stride at position `k` is equal to the product of the shapes
/// of all dimensions greater than `k`.
///
/// This means that strides increase as you move from the rightmost to the leftmost dimension.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think the check could be simplified to

pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    let mut expected_stride = 1;

    for (&shape, &stride) in shape.iter().zip(strides).rev() {
        if shape == 1 {
            continue;
        }

        if stride != expected_stride {
            return false;
        }

        expected_stride *= shape;
    }

    true
}

But leaving this here as a note and not a required change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer that strategy actually, we could even create a function that could be used to create the strides AND be used for the checks:

pub fn contiguous_strides(shape: &[usize]) -> impl Iterator<Item = usize> {
    let mut current = 1;
    shape
            .iter()
            .rev()
            .map(|val| {
                current *= val;
                current
            })
            .rev()
}
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { 
    if shape.is_empty() {
         return true;
     }
     
     for (i, expected) in contiguous_strides(shape).enumerate() {
         if expected != strides[i] {
             return false;
         }
     }
     
     true

Copy link
Member Author

@laggui laggui Apr 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two .rev() calls cancel out for lazy iterator operations. Also, the first computed stride value should be 1.

I changed the suggested contiguous_strides implementation to match these criteria.

if shape.is_empty() {
return true;
}

for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
if expected != stride {
return false;
}
}

true
}

/// Computes the strides for a contiguous tensor with the given shape.
///
/// In a contiguous row-major tensor, the stride for each dimension
/// equals the product of all dimension sizes to its right.
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1;

    // Accumulate right-to-left so each stride is the product of the
    // dimension sizes that come after it, then restore original order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let stride = running;
            running *= dim;
            stride
        })
        .collect();

    strides.reverse();
    strides
}
}
30 changes: 0 additions & 30 deletions crates/burn-cubecl-fusion/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,33 +137,3 @@ impl<R: Runtime> CubeFusionHandle<R> {
}
}
}

/// Checks whether a tensor with the given `shape` and `strides` is contiguous.
///
/// A tensor is contiguous (row-major) when the stride of each dimension equals
/// the number of elements covered by all dimensions to its right. Dimensions of
/// size 1 are ignored, since their stride cannot affect the memory layout.
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    // A rank-0 tensor holds a single element and is trivially contiguous.
    if shape.is_empty() {
        return true;
    }

    // Walk the dimensions right-to-left, tracking the stride a contiguous
    // layout would require at each position.
    //
    // The previous implementation compared `prev_stride >= *stride` strictly,
    // which wrongly rejected contiguous tensors whose trailing size-1 axes
    // repeat a stride (e.g. shape [3, 1] with strides [1, 1]).
    let mut expected_stride = 1;

    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        // A size-1 axis can carry any stride without changing the layout.
        if dim == 1 {
            continue;
        }

        if stride != expected_stride {
            return false;
        }

        expected_stride *= dim;
    }

    true
}
3 changes: 2 additions & 1 deletion crates/burn-cubecl-fusion/src/shared/trace/output.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use burn_common::tensor::is_contiguous;
use burn_fusion::stream::Context;
use burn_ir::{TensorId, TensorIr};
use burn_tensor::DType;
use cubecl::{CubeElement, Runtime, client::ComputeClient, ir::Elem};

use crate::{
CubeFusionHandle, elem_dtype, is_contiguous,
CubeFusionHandle, elem_dtype,
shared::ir::{Arg, FuseOp, LayoutInfo},
strides_dyn_rank,
};
Expand Down
44 changes: 9 additions & 35 deletions crates/burn-cubecl/src/tensor/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::CubeRuntime;
use crate::element::CubeElement;
use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric};
use burn_common::tensor::is_contiguous;
use burn_tensor::quantization::QTensorPrimitive;
use burn_tensor::{DType, Shape, TensorMetadata};
use cubecl::client::ComputeClient;
Expand Down Expand Up @@ -435,8 +436,8 @@ where

/// Check if the current tensor is contiguous.
///
/// A tensor is contiguous if the elements are stored
/// if the strides in strict decreasing order and the
A tensor is contiguous if the elements are stored in memory
with the strides in non-increasing order and the
/// strides at position k is equal to the product of the shapes
/// at all positions greater than k. However, all axes with a shape of 1 are ignored.
pub fn is_contiguous(&self) -> bool {
Expand All @@ -450,41 +451,14 @@ where
}
}

/// Checks whether a tensor with the given `shape` and `strides` is contiguous.
///
/// A tensor is contiguous (row-major) when the stride of each dimension equals
/// the number of elements covered by all dimensions to its right. Dimensions of
/// size 1 are ignored, since their stride cannot affect the memory layout.
pub(crate) fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    // A rank-0 tensor holds a single element and is trivially contiguous.
    if shape.is_empty() {
        return true;
    }

    // Walk the dimensions right-to-left, tracking the stride a contiguous
    // layout would require at each position.
    //
    // The previous implementation skipped size-1 axes with `continue` but the
    // `enumerate` counter still advanced past them, so the subsequent strict
    // `prev_stride >= *stride` comparison fired on the next real dimension and
    // wrongly rejected e.g. shape [3, 1] with strides [1, 1].
    let mut expected_stride = 1;

    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        // A size-1 axis can carry any stride without changing the layout.
        if dim == 1 {
            continue;
        }

        if stride != expected_stride {
            return false;
        }

        expected_stride *= dim;
    }

    true
}

#[cfg(test)]
mod tests {
use crate::tensor::base::is_contiguous;
use super::*;

#[test]
fn is_contiguous_non_increasing() {
assert!(is_contiguous(&[3, 1], &[1, 1]));
}

#[test]
fn is_contiguous_basic() {
Expand Down
37 changes: 7 additions & 30 deletions crates/burn-ndarray/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ macro_rules! execute_with_float_dtype {
}

mod utils {
use burn_common::tensor::is_contiguous;

use super::*;

impl<E> NdArrayTensor<E>
Expand Down Expand Up @@ -219,40 +221,15 @@ mod utils {

pub(crate) fn is_contiguous(&self) -> bool {
let shape = self.array.shape();
let strides = self.array.strides();

if shape.is_empty() {
return true;
}

if shape.len() == 1 {
return strides[0] == 1;
}

let mut prev_stride = 1;
let mut current_num_elems_shape = 1;
let mut strides = Vec::with_capacity(self.array.strides().len());

for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
let stride = if *stride <= 0 {
for &stride in self.array.strides() {
if stride <= 0 {
return false;
} else {
*stride as usize
};
if i > 0 {
if current_num_elems_shape != stride {
return false;
}

if prev_stride > stride {
return false;
}
}

current_num_elems_shape *= shape;
prev_stride = stride;
strides.push(stride as usize);
}

true
is_contiguous(shape, &strides)
}
}
}
Expand Down