11use crate :: CubeRuntime ;
22use crate :: element:: CubeElement ;
33use crate :: kernel:: { NumericUnaryOp , NumericUnaryOpFamily , launch_unary_numeric} ;
4+ use burn_common:: tensor:: is_contiguous;
45use burn_tensor:: quantization:: QTensorPrimitive ;
56use burn_tensor:: { DType , Shape , TensorMetadata } ;
67use cubecl:: client:: ComputeClient ;
@@ -440,8 +441,8 @@ where
440441
441442 /// Check if the current tensor is contiguous.
442443 ///
443- /// A tensor is contiguous if the elements are stored
444- /// if the strides in strict decreasing order and the
444+ /// A tensor is contiguous if the elements are stored in memory
445+ /// if the strides in non-increasing order and the
445446 /// strides at position k is equal to the product of the shapes
446447 /// at all positions greater than k. However, all axes with a shape of 1 are ignored.
447448 pub fn is_contiguous ( & self ) -> bool {
@@ -455,39 +456,14 @@ where
455456 }
456457}
457458
458- pub ( crate ) fn is_contiguous ( shape : & [ usize ] , strides : & [ usize ] ) -> bool {
459- if shape. is_empty ( ) {
460- return true ;
461- }
462-
463- let mut prev_stride = 1 ;
464- let mut current_num_elems_shape = 1 ;
465-
466- for ( i, ( stride, shape) ) in strides. iter ( ) . zip ( shape) . rev ( ) . enumerate ( ) {
467- if * shape == 1 {
468- continue ;
469- }
470-
471- if i > 0 {
472- if current_num_elems_shape != * stride {
473- return false ;
474- }
475-
476- if prev_stride >= * stride {
477- return false ;
478- }
479- }
480-
481- current_num_elems_shape *= shape;
482- prev_stride = * stride;
483- }
484-
485- true
486- }
487-
488459#[ cfg( test) ]
489460mod tests {
490- use crate :: tensor:: base:: is_contiguous;
461+ use super :: * ;
462+
463+ #[ test]
464+ fn is_contiguous_non_increasing ( ) {
465+ assert ! ( is_contiguous( & [ 3 , 1 ] , & [ 1 , 1 ] ) ) ;
466+ }
491467
492468 #[ test]
493469 fn is_contiguous_basic ( ) {
0 commit comments