Skip to content

Bug in cases like eltwise_mult_CSRxDense_oCSR #71

@pthomadakis

Description

@pthomadakis

Running the test case in ops/eltwise_mult_CSRxDense_oCSR.ta reveals a bug in the generated IR that is hard to notice because it does not affect the results.
Specifically, lowering to loops (--emit-loops) produces the following IR:


// Output of lowering ops/eltwise_mult_CSRxDense_oCSR.ta to loops (--emit-loops).
// A sparse input read from SPARSE_FILE_NAME0 is multiplied element-wise with a
// dense matrix filled with 2.7, and the result is printed.
module {
  func.func @main() {
    %idx-1 = index.constant -1
    %idx1 = index.constant 1
    %idx0 = index.constant 0
    %cst = arith.constant 2.700000e+00 : f64
    %cst_0 = arith.constant 0.000000e+00 : f64
    %c10 = arith.constant 10 : index
    %c9 = arith.constant 9 : index
    %c8 = arith.constant 8 : index
    %c7 = arith.constant 7 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c3 = arith.constant 3 : index
    %c2 = arith.constant 2 : index
    %c-1 = arith.constant -1 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    // Query the sizes of the sparse input's storage arrays; entries 0..10 are loaded below.
    %alloc = memref.alloc() : memref<13xindex>
    %cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
    call @read_input_sizes_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
    %0 = memref.load %alloc[%c0] : memref<13xindex>
    %1 = memref.load %alloc[%c1] : memref<13xindex>
    %2 = memref.load %alloc[%c2] : memref<13xindex>
    %3 = memref.load %alloc[%c3] : memref<13xindex>
    %4 = memref.load %alloc[%c4] : memref<13xindex>
    %5 = memref.load %alloc[%c5] : memref<13xindex>
    %6 = memref.load %alloc[%c6] : memref<13xindex>
    %7 = memref.load %alloc[%c7] : memref<13xindex>
    %8 = memref.load %alloc[%c8] : memref<13xindex>
    %9 = memref.load %alloc[%c9] : memref<13xindex>
    %10 = memref.load %alloc[%c10] : memref<13xindex>
    // Allocate and zero-initialize the input's storage: eight index arrays
    // (%alloc_1 .. %alloc_15) followed by the f64 values array (%alloc_17).
    %alloc_1 = memref.alloc(%0) : memref<?xindex>
    scf.for %arg0 = %c0 to %0 step %c1 {
      memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
    }
    %cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
    %alloc_3 = memref.alloc(%1) : memref<?xindex>
    scf.for %arg0 = %c0 to %1 step %c1 {
      memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
    }
    %cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
    %alloc_5 = memref.alloc(%2) : memref<?xindex>
    scf.for %arg0 = %c0 to %2 step %c1 {
      memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
    }
    %cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
    %alloc_7 = memref.alloc(%3) : memref<?xindex>
    scf.for %arg0 = %c0 to %3 step %c1 {
      memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
    }
    %cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
    %alloc_9 = memref.alloc(%4) : memref<?xindex>
    scf.for %arg0 = %c0 to %4 step %c1 {
      memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
    }
    %cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
    %alloc_11 = memref.alloc(%5) : memref<?xindex>
    scf.for %arg0 = %c0 to %5 step %c1 {
      memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
    }
    %cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
    %alloc_13 = memref.alloc(%6) : memref<?xindex>
    scf.for %arg0 = %c0 to %6 step %c1 {
      memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
    }
    %cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
    %alloc_15 = memref.alloc(%7) : memref<?xindex>
    scf.for %arg0 = %c0 to %7 step %c1 {
      memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
    }
    %cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
    %alloc_17 = memref.alloc(%8) : memref<?xf64>
    scf.for %arg0 = %c0 to %8 step %c1 {
      memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
    }
    %cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
    call @read_input_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
    // Tensor views of the input's data, as used by the loop nest below:
    //   %11 (%alloc_9)  - indexed by row %arg0 / %arg0+1 to bound the inner loop
    //   %12 (%alloc_11) - column coordinate of each stored element
    //   %13 (%alloc_17) - value of each stored element
    %11 = bufferization.to_tensor %alloc_9 restrict writable : memref<?xindex>
    %12 = bufferization.to_tensor %alloc_11 restrict writable : memref<?xindex>
    %13 = bufferization.to_tensor %alloc_17 restrict writable : memref<?xf64>
    // Dense operand: a %9 x %10 matrix filled with 2.7.
    %alloc_19 = memref.alloc(%9, %10) {alignment = 32 : i64} : memref<?x?xf64>
    linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<?x?xf64>)
    %14 = bufferization.to_tensor %alloc_19 restrict writable : memref<?x?xf64>
    // Single-element buffer holding %9; it is the first thing printed below.
    %alloc_20 = memref.alloc() : memref<1xindex>
    memref.store %9, %alloc_20[%idx0] : memref<1xindex>
    %15 = bufferization.to_tensor %alloc_20 restrict writable : memref<1xindex>
    // Freshly allocated, zero-initialized output values (sized by %5, like %alloc_11).
    %alloc_21 = memref.alloc(%5) : memref<?xf64>
    scf.for %arg0 = %idx0 to %5 step %idx1 {
      memref.store %cst_0, %alloc_21[%arg0] : memref<?xf64>
    }
    %16 = bufferization.to_tensor %alloc_21 restrict writable : memref<?xf64>
    // BUG: the output coordinate tensor (%arg2) is seeded with %12, the *input's*
    // coordinate tensor, rather than with a freshly allocated tensor like %16.
    %17:3 = scf.for %arg0 = %c0 to %9 step %c1 iter_args(%arg1 = %idx0, %arg2 = %12, %arg3 = %16) -> (index, tensor<?xindex>, tensor<?xf64>) {
      %19 = arith.addi %arg0, %c1 : index
      %extracted = tensor.extract %11[%arg0] : tensor<?xindex>
      %extracted_22 = tensor.extract %11[%19] : tensor<?xindex>
      %20:3 = scf.for %arg4 = %extracted to %extracted_22 step %c1 iter_args(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (index, tensor<?xindex>, tensor<?xf64>) {
        %extracted_23 = tensor.extract %12[%arg4] : tensor<?xindex>
        %extracted_24 = tensor.extract %13[%arg4] : tensor<?xf64>
        %extracted_25 = tensor.extract %14[%arg0, %extracted_23] : tensor<?x?xf64>
        %21 = arith.mulf %extracted_24, %extracted_25 : f64
        %inserted_26 = tensor.insert %extracted_23 into %arg6[%arg5] : tensor<?xindex>  // <-- We insert into %arg6, which is %12 = %alloc_11
        %22 = index.add %arg5, %idx1
        %inserted_27 = tensor.insert %21 into %arg7[%arg5] : tensor<?xf64>
        scf.yield %22, %inserted_26, %inserted_27 : index, tensor<?xindex>, tensor<?xf64>
      }
      scf.yield %20#0, %20#1, %20#2 : index, tensor<?xindex>, tensor<?xf64>
    }
    %18 = bufferization.alloc_tensor() : tensor<1xindex>
    %inserted = tensor.insert %idx-1 into %18[%idx0] : tensor<1xindex>
    "ta.print"(%15) : (tensor<1xindex>) -> ()
    "ta.print"(%inserted) : (tensor<1xindex>) -> ()
    // BUG: %11 is a view of the input's %alloc_9, printed as part of the output.
    "ta.print"(%11) : (tensor<?xindex>) -> ()
    "ta.print"(%17#1) : (tensor<?xindex>) -> ()
    "ta.print"(%17#2) : (tensor<?xf64>) -> ()
    return
  }
  func.func private @read_input_2D_f64(i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32)
  func.func private @read_input_sizes_2D_f64(i32, index, index, index, index, memref<*xindex>, i32)
}

Here, we can see two issues:

  1. SSA value %11, which refers to data belonging to one of the input sparse tensors, is printed when I try to print the result/output tensor. This means that the input and output tensors share a reference to the same data structure, which should not be the case.
  2. Even worse, SSA value %17#1 is a tensor produced by a tensor.insert into one of the underlying data arrays of that same input tensor, specifically this operation: %inserted_26 = tensor.insert %extracted_23 into %arg6[%arg5] : tensor<?xindex>.

In the original issue, the relevant lines of IR were formatted in bold to make them easier to track. The key lines are also marked with `<--` annotations.
The problem does not show up because we never check the input tensors. In addition, one-shot-bufferize saves us by creating a copy of the tensor we insert into before the insertion happens.

If I let bufferization happen (--convert-ta-to-it --convert-to-loops), here is what we get:


  // The same function after bufferization (--convert-ta-to-it --convert-to-loops).
  func.func @main() {
    %idx-1 = index.constant -1
    %idx1 = index.constant 1
    %idx0 = index.constant 0
    %cst = arith.constant 2.700000e+00 : f64
    %cst_0 = arith.constant 0.000000e+00 : f64
    %c10 = arith.constant 10 : index
    %c9 = arith.constant 9 : index
    %c8 = arith.constant 8 : index
    %c7 = arith.constant 7 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c3 = arith.constant 3 : index
    %c2 = arith.constant 2 : index
    %c-1 = arith.constant -1 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    // Query the sizes of the sparse input's storage arrays; entries 0..10 are loaded below.
    %alloc = memref.alloc() : memref<13xindex>
    %cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
    call @read_input_sizes_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
    %0 = memref.load %alloc[%c0] : memref<13xindex>
    %1 = memref.load %alloc[%c1] : memref<13xindex>
    %2 = memref.load %alloc[%c2] : memref<13xindex>
    %3 = memref.load %alloc[%c3] : memref<13xindex>
    %4 = memref.load %alloc[%c4] : memref<13xindex>
    %5 = memref.load %alloc[%c5] : memref<13xindex>
    %6 = memref.load %alloc[%c6] : memref<13xindex>
    %7 = memref.load %alloc[%c7] : memref<13xindex>
    %8 = memref.load %alloc[%c8] : memref<13xindex>
    %9 = memref.load %alloc[%c9] : memref<13xindex>
    %10 = memref.load %alloc[%c10] : memref<13xindex>
    // Allocate and zero-initialize the input's storage: eight index arrays
    // (%alloc_1 .. %alloc_15) followed by the f64 values array (%alloc_17).
    %alloc_1 = memref.alloc(%0) : memref<?xindex>
    scf.for %arg0 = %c0 to %0 step %c1 {
      memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
    }
    %cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
    %alloc_3 = memref.alloc(%1) : memref<?xindex>
    scf.for %arg0 = %c0 to %1 step %c1 {
      memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
    }
    %cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
    %alloc_5 = memref.alloc(%2) : memref<?xindex>
    scf.for %arg0 = %c0 to %2 step %c1 {
      memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
    }
    %cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
    %alloc_7 = memref.alloc(%3) : memref<?xindex>
    scf.for %arg0 = %c0 to %3 step %c1 {
      memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
    }
    %cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
    %alloc_9 = memref.alloc(%4) : memref<?xindex>
    scf.for %arg0 = %c0 to %4 step %c1 {
      memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
    }
    %cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
    %alloc_11 = memref.alloc(%5) : memref<?xindex>
    scf.for %arg0 = %c0 to %5 step %c1 {
      memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
    }
    %cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
    %alloc_13 = memref.alloc(%6) : memref<?xindex>
    scf.for %arg0 = %c0 to %6 step %c1 {
      memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
    }
    %cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
    %alloc_15 = memref.alloc(%7) : memref<?xindex>
    scf.for %arg0 = %c0 to %7 step %c1 {
      memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
    }
    %cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
    %alloc_17 = memref.alloc(%8) : memref<?xf64>
    scf.for %arg0 = %c0 to %8 step %c1 {
      memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
    }
    %cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
    call @read_input_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
    // Dense operand: a %9 x %10 matrix filled with 2.7.
    %alloc_19 = memref.alloc(%9, %10) {alignment = 32 : i64} : memref<?x?xf64>
    linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<?x?xf64>)
    // Single-element buffer holding %9; it is the first thing printed below.
    %alloc_20 = memref.alloc() : memref<1xindex>
    memref.store %9, %alloc_20[%idx0] : memref<1xindex>
    // Freshly allocated, zero-initialized output values (sized by %5, like %alloc_11).
    %alloc_21 = memref.alloc(%5) : memref<?xf64>
    scf.for %arg0 = %idx0 to %5 step %idx1 {
      memref.store %cst_0, %alloc_21[%arg0] : memref<?xf64>
    }
    // The output coordinates are written into %alloc_22, a copy of the input's
    // %alloc_11, so the stores in the loop below leave the input untouched.
    %alloc_22 = memref.alloc(%5) {alignment = 64 : i64} : memref<?xindex>
    memref.copy %alloc_11, %alloc_22 : memref<?xindex> to memref<?xindex>  // <-- Inserted by one-shot-bufferize
    %11 = scf.for %arg0 = %c0 to %9 step %c1 iter_args(%arg1 = %idx0) -> (index) {
      %12 = arith.addi %arg0, %c1 : index
      %13 = memref.load %alloc_9[%arg0] : memref<?xindex>
      %14 = memref.load %alloc_9[%12] : memref<?xindex>
      %15 = scf.for %arg2 = %13 to %14 step %c1 iter_args(%arg3 = %arg1) -> (index) {
        %16 = memref.load %alloc_11[%arg2] : memref<?xindex>
        %17 = memref.load %alloc_17[%arg2] : memref<?xf64>
        %18 = memref.load %alloc_19[%arg0, %16] : memref<?x?xf64>
        %19 = arith.mulf %17, %18 : f64
        memref.store %16, %alloc_22[%arg3] : memref<?xindex>
        %20 = index.add %arg3, %idx1
        memref.store %19, %alloc_21[%arg3] : memref<?xf64>
        scf.yield %20 : index
      }
      scf.yield %15 : index
    }
    %alloc_23 = memref.alloc() {alignment = 64 : i64} : memref<1xindex>
    memref.store %idx-1, %alloc_23[%idx0] : memref<1xindex>
    %cast_24 = memref.cast %alloc_20 : memref<1xindex> to memref<*xindex>
    call @comet_print_memref_i64(%cast_24) : (memref<*xindex>) -> ()
    %cast_25 = memref.cast %alloc_23 : memref<1xindex> to memref<*xindex>
    call @comet_print_memref_i64(%cast_25) : (memref<*xindex>) -> ()
    // BUG (still present): %cast_10 views the input's %alloc_9, printed as part of the output.
    call @comet_print_memref_i64(%cast_10) : (memref<*xindex>) -> ()
    %cast_26 = memref.cast %alloc_22 : memref<?xindex> to memref<*xindex>
    call @comet_print_memref_i64(%cast_26) : (memref<*xindex>) -> ()
    %cast_27 = memref.cast %alloc_21 : memref<?xf64> to memref<*xf64>
    call @comet_print_memref_f64(%cast_27) : (memref<*xf64>) -> ()
    return
  }

Notice the memref.copy operation inserted by one-shot-bufferize (marked with `<-- Inserted by one-shot-bufferize`). However, we cannot rely on this behavior.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions