11use crate :: CubeRuntime ;
22use crate :: element:: CubeElement ;
33use crate :: kernel:: { NumericUnaryOp , NumericUnaryOpFamily , launch_unary_numeric} ;
4+ use burn_common:: tensor:: is_contiguous;
45use burn_tensor:: quantization:: QTensorPrimitive ;
56use burn_tensor:: { DType , Shape , TensorMetadata } ;
67use cubecl:: client:: ComputeClient ;
@@ -435,8 +436,8 @@ where
435436
436437 /// Check if the current tensor is contiguous.
437438 ///
438- /// A tensor is contiguous if the elements are stored
439- /// if the strides in strict decreasing order and the
439+ /// A tensor is contiguous if the elements are stored in memory
440+ /// if the strides in non-increasing order and the
440441 /// strides at position k is equal to the product of the shapes
441442 /// at all positions greater than k. However, all axes with a shape of 1 are ignored.
442443 pub fn is_contiguous ( & self ) -> bool {
@@ -450,41 +451,14 @@ where
450451 }
451452}
452453
453- pub ( crate ) fn is_contiguous ( shape : & [ usize ] , strides : & [ usize ] ) -> bool {
454- if shape. is_empty ( ) {
455- return true ;
456- }
457-
458- let mut prev_stride = 1 ;
459- let mut current_num_elems_shape = 1 ;
460-
461- for ( i, ( stride, shape) ) in strides. iter ( ) . zip ( shape) . rev ( ) . enumerate ( ) {
462- if * shape == 1 {
463- continue ;
464- }
465-
466- if i > 0 {
467- if current_num_elems_shape != * stride {
468- return false ;
469- }
470-
471- if prev_stride >= * stride {
472- return false ;
473- }
474- } else if * stride != 1 {
475- return false ;
476- }
477-
478- current_num_elems_shape *= shape;
479- prev_stride = * stride;
480- }
481-
482- true
483- }
484-
485454#[ cfg( test) ]
486455mod tests {
487- use crate :: tensor:: base:: is_contiguous;
456+ use super :: * ;
457+
458+ #[ test]
459+ fn is_contiguous_non_increasing ( ) {
460+ assert ! ( is_contiguous( & [ 3 , 1 ] , & [ 1 , 1 ] ) ) ;
461+ }
488462
489463 #[ test]
490464 fn is_contiguous_basic ( ) {
0 commit comments