-
Notifications
You must be signed in to change notification settings - Fork 9
Open
Labels
bug: Something isn't working
Description
Running the test case in `ops/eltwise_mult_CSRxDense_oCSR.ta`
unveils a bug in the generated IR that is hard to notice because it does not affect the results.
Specifically, lowering to loops (`--emit-loops`) produces the following IR:
// IR produced for ops/eltwise_mult_CSRxDense_oCSR.ta with --emit-loops
// (tensor form, before one-shot-bufferize runs).
module {
  func.func @main() {
    %idx-1 = index.constant -1
    %idx1 = index.constant 1
    %idx0 = index.constant 0
    %cst = arith.constant 2.700000e+00 : f64
    %cst_0 = arith.constant 0.000000e+00 : f64
    %c10 = arith.constant 10 : index
    %c9 = arith.constant 9 : index
    %c8 = arith.constant 8 : index
    %c7 = arith.constant 7 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c3 = arith.constant 3 : index
    %c2 = arith.constant 2 : index
    %c-1 = arith.constant -1 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    // Sizes of the sparse input's buffers are read into %alloc, then loaded as %0..%10.
    %alloc = memref.alloc() : memref<13xindex>
    %cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
    call @read_input_sizes_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
    %0 = memref.load %alloc[%c0] : memref<13xindex>
    %1 = memref.load %alloc[%c1] : memref<13xindex>
    %2 = memref.load %alloc[%c2] : memref<13xindex>
    %3 = memref.load %alloc[%c3] : memref<13xindex>
    %4 = memref.load %alloc[%c4] : memref<13xindex>
    %5 = memref.load %alloc[%c5] : memref<13xindex>
    %6 = memref.load %alloc[%c6] : memref<13xindex>
    %7 = memref.load %alloc[%c7] : memref<13xindex>
    %8 = memref.load %alloc[%c8] : memref<13xindex>
    %9 = memref.load %alloc[%c9] : memref<13xindex>
    %10 = memref.load %alloc[%c10] : memref<13xindex>
    // Zero-initialized buffers that @read_input_2D_f64 fills with the sparse input.
    %alloc_1 = memref.alloc(%0) : memref<?xindex>
    scf.for %arg0 = %c0 to %0 step %c1 {
      memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
    }
    %cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
    %alloc_3 = memref.alloc(%1) : memref<?xindex>
    scf.for %arg0 = %c0 to %1 step %c1 {
      memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
    }
    %cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
    %alloc_5 = memref.alloc(%2) : memref<?xindex>
    scf.for %arg0 = %c0 to %2 step %c1 {
      memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
    }
    %cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
    %alloc_7 = memref.alloc(%3) : memref<?xindex>
    scf.for %arg0 = %c0 to %3 step %c1 {
      memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
    }
    %cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
    %alloc_9 = memref.alloc(%4) : memref<?xindex>
    scf.for %arg0 = %c0 to %4 step %c1 {
      memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
    }
    %cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
    %alloc_11 = memref.alloc(%5) : memref<?xindex>
    scf.for %arg0 = %c0 to %5 step %c1 {
      memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
    }
    %cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
    %alloc_13 = memref.alloc(%6) : memref<?xindex>
    scf.for %arg0 = %c0 to %6 step %c1 {
      memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
    }
    %cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
    %alloc_15 = memref.alloc(%7) : memref<?xindex>
    scf.for %arg0 = %c0 to %7 step %c1 {
      memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
    }
    %cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
    %alloc_17 = memref.alloc(%8) : memref<?xf64>
    scf.for %arg0 = %c0 to %8 step %c1 {
      memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
    }
    %cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
    call @read_input_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
    // %11, %12 and %13 are tensor views of the input buffers %alloc_9, %alloc_11 and %alloc_17.
    %11 = bufferization.to_tensor %alloc_9 restrict writable : memref<?xindex>
    %12 = bufferization.to_tensor %alloc_11 restrict writable : memref<?xindex>
    %13 = bufferization.to_tensor %alloc_17 restrict writable : memref<?xf64>
    // Dense operand, filled with 2.7.
    %alloc_19 = memref.alloc(%9, %10) {alignment = 32 : i64} : memref<?x?xf64>
    linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<?x?xf64>)
    %14 = bufferization.to_tensor %alloc_19 restrict writable : memref<?x?xf64>
    %alloc_20 = memref.alloc() : memref<1xindex>
    memref.store %9, %alloc_20[%idx0] : memref<1xindex>
    %15 = bufferization.to_tensor %alloc_20 restrict writable : memref<1xindex>
    // Output values, zero-initialized.
    %alloc_21 = memref.alloc(%5) : memref<?xf64>
    scf.for %arg0 = %idx0 to %5 step %idx1 {
      memref.store %cst_0, %alloc_21[%arg0] : memref<?xf64>
    }
    %16 = bufferization.to_tensor %alloc_21 restrict writable : memref<?xf64>
    // BUG: %arg2 is seeded with %12, the view of the *input* buffer %alloc_11,
    // so the output's coordinate tensor aliases the input's data.
    %17:3 = scf.for %arg0 = %c0 to %9 step %c1 iter_args(%arg1 = %idx0, %arg2 = %12, %arg3 = %16) -> (index, tensor<?xindex>, tensor<?xf64>) {
      %19 = arith.addi %arg0, %c1 : index
      %extracted = tensor.extract %11[%arg0] : tensor<?xindex>
      %extracted_22 = tensor.extract %11[%19] : tensor<?xindex>
      %20:3 = scf.for %arg4 = %extracted to %extracted_22 step %c1 iter_args(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (index, tensor<?xindex>, tensor<?xf64>) {
        %extracted_23 = tensor.extract %12[%arg4] : tensor<?xindex>
        %extracted_24 = tensor.extract %13[%arg4] : tensor<?xf64>
        %extracted_25 = tensor.extract %14[%arg0, %extracted_23] : tensor<?x?xf64>
        %21 = arith.mulf %extracted_24, %extracted_25 : f64
        // <-- We insert into %arg6, which is %12 = %alloc_11 (input data).
        %inserted_26 = tensor.insert %extracted_23 into %arg6[%arg5] : tensor<?xindex>
        %22 = index.add %arg5, %idx1
        %inserted_27 = tensor.insert %21 into %arg7[%arg5] : tensor<?xf64>
        scf.yield %22, %inserted_26, %inserted_27 : index, tensor<?xindex>, tensor<?xf64>
      }
      scf.yield %20#0, %20#1, %20#2 : index, tensor<?xindex>, tensor<?xf64>
    }
    %18 = bufferization.alloc_tensor() : tensor<1xindex>
    %inserted = tensor.insert %idx-1 into %18[%idx0] : tensor<1xindex>
    "ta.print"(%15) : (tensor<1xindex>) -> ()
    "ta.print"(%inserted) : (tensor<1xindex>) -> ()
    // BUG: %11 is the view of the input buffer %alloc_9, printed as part of the output.
    "ta.print"(%11) : (tensor<?xindex>) -> ()
    "ta.print"(%17#1) : (tensor<?xindex>) -> ()
    "ta.print"(%17#2) : (tensor<?xf64>) -> ()
    return
  }
  func.func private @read_input_2D_f64(i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32)
  func.func private @read_input_sizes_2D_f64(i32, index, index, index, index, memref<*xindex>, i32)
}
Here, we can see two issues:
- SSA value `%11`, which refers to data belonging to one of the input sparse tensors, is printed when I try to print the result/output tensor. This means that the input and output tensors share a reference to the same data structure, which should not be the case.
- Even worse, SSA value `%17#1` is a tensor produced by a `tensor.insert` into one of the underlying data buffers of that same input tensor, specifically this operation: `%inserted_26 = tensor.insert %extracted_23 into %arg6[%arg5] : tensor<?xindex>`.
I have formatted the related pieces of IR in bold to make it easier to track what I'm referring to.
The problem does not show up because we never check the input tensors, and also because one-shot-bufferize saves us by creating a copy of the tensor we insert into before inserting into it.
If I let bufferization happen, here is what we get (`--convert-ta-to-it --convert-to-loops`):
// Same test after --convert-ta-to-it --convert-to-loops (one-shot-bufferize applied).
func.func @main() {
  %idx-1 = index.constant -1
  %idx1 = index.constant 1
  %idx0 = index.constant 0
  %cst = arith.constant 2.700000e+00 : f64
  %cst_0 = arith.constant 0.000000e+00 : f64
  %c10 = arith.constant 10 : index
  %c9 = arith.constant 9 : index
  %c8 = arith.constant 8 : index
  %c7 = arith.constant 7 : index
  %c6 = arith.constant 6 : index
  %c5 = arith.constant 5 : index
  %c4 = arith.constant 4 : index
  %c1_i32 = arith.constant 1 : i32
  %c0_i32 = arith.constant 0 : i32
  %c3 = arith.constant 3 : index
  %c2 = arith.constant 2 : index
  %c-1 = arith.constant -1 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %alloc = memref.alloc() : memref<13xindex>
  %cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
  call @read_input_sizes_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
  %0 = memref.load %alloc[%c0] : memref<13xindex>
  %1 = memref.load %alloc[%c1] : memref<13xindex>
  %2 = memref.load %alloc[%c2] : memref<13xindex>
  %3 = memref.load %alloc[%c3] : memref<13xindex>
  %4 = memref.load %alloc[%c4] : memref<13xindex>
  %5 = memref.load %alloc[%c5] : memref<13xindex>
  %6 = memref.load %alloc[%c6] : memref<13xindex>
  %7 = memref.load %alloc[%c7] : memref<13xindex>
  %8 = memref.load %alloc[%c8] : memref<13xindex>
  %9 = memref.load %alloc[%c9] : memref<13xindex>
  %10 = memref.load %alloc[%c10] : memref<13xindex>
  %alloc_1 = memref.alloc(%0) : memref<?xindex>
  scf.for %arg0 = %c0 to %0 step %c1 {
    memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
  }
  %cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
  %alloc_3 = memref.alloc(%1) : memref<?xindex>
  scf.for %arg0 = %c0 to %1 step %c1 {
    memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
  }
  %cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
  %alloc_5 = memref.alloc(%2) : memref<?xindex>
  scf.for %arg0 = %c0 to %2 step %c1 {
    memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
  }
  %cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
  %alloc_7 = memref.alloc(%3) : memref<?xindex>
  scf.for %arg0 = %c0 to %3 step %c1 {
    memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
  }
  %cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
  %alloc_9 = memref.alloc(%4) : memref<?xindex>
  scf.for %arg0 = %c0 to %4 step %c1 {
    memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
  }
  %cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
  %alloc_11 = memref.alloc(%5) : memref<?xindex>
  scf.for %arg0 = %c0 to %5 step %c1 {
    memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
  }
  %cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
  %alloc_13 = memref.alloc(%6) : memref<?xindex>
  scf.for %arg0 = %c0 to %6 step %c1 {
    memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
  }
  %cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
  %alloc_15 = memref.alloc(%7) : memref<?xindex>
  scf.for %arg0 = %c0 to %7 step %c1 {
    memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
  }
  %cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
  %alloc_17 = memref.alloc(%8) : memref<?xf64>
  scf.for %arg0 = %c0 to %8 step %c1 {
    memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
  }
  %cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
  call @read_input_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
  %alloc_19 = memref.alloc(%9, %10) {alignment = 32 : i64} : memref<?x?xf64>
  linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<?x?xf64>)
  %alloc_20 = memref.alloc() : memref<1xindex>
  memref.store %9, %alloc_20[%idx0] : memref<1xindex>
  %alloc_21 = memref.alloc(%5) : memref<?xf64>
  scf.for %arg0 = %idx0 to %5 step %idx1 {
    memref.store %cst_0, %alloc_21[%arg0] : memref<?xf64>
  }
  // The copy below is the only thing keeping the input buffer %alloc_11 unmodified:
  // the stores in the loop go to %alloc_22, not to %alloc_11.
  %alloc_22 = memref.alloc(%5) {alignment = 64 : i64} : memref<?xindex>
  memref.copy %alloc_11, %alloc_22 : memref<?xindex> to memref<?xindex> // <-- Inserted by one-shot-bufferize
  %11 = scf.for %arg0 = %c0 to %9 step %c1 iter_args(%arg1 = %idx0) -> (index) {
    %12 = arith.addi %arg0, %c1 : index
    %13 = memref.load %alloc_9[%arg0] : memref<?xindex>
    %14 = memref.load %alloc_9[%12] : memref<?xindex>
    %15 = scf.for %arg2 = %13 to %14 step %c1 iter_args(%arg3 = %arg1) -> (index) {
      %16 = memref.load %alloc_11[%arg2] : memref<?xindex>
      %17 = memref.load %alloc_17[%arg2] : memref<?xf64>
      %18 = memref.load %alloc_19[%arg0, %16] : memref<?x?xf64>
      %19 = arith.mulf %17, %18 : f64
      memref.store %16, %alloc_22[%arg3] : memref<?xindex>
      %20 = index.add %arg3, %idx1
      memref.store %19, %alloc_21[%arg3] : memref<?xf64>
      scf.yield %20 : index
    }
    scf.yield %15 : index
  }
  %alloc_23 = memref.alloc() {alignment = 64 : i64} : memref<1xindex>
  memref.store %idx-1, %alloc_23[%idx0] : memref<1xindex>
  %cast_24 = memref.cast %alloc_20 : memref<1xindex> to memref<*xindex>
  call @comet_print_memref_i64(%cast_24) : (memref<*xindex>) -> ()
  %cast_25 = memref.cast %alloc_23 : memref<1xindex> to memref<*xindex>
  call @comet_print_memref_i64(%cast_25) : (memref<*xindex>) -> ()
  // %cast_10 is the input buffer %alloc_9, still printed as part of the output.
  call @comet_print_memref_i64(%cast_10) : (memref<*xindex>) -> ()
  %cast_26 = memref.cast %alloc_22 : memref<?xindex> to memref<*xindex>
  call @comet_print_memref_i64(%cast_26) : (memref<*xindex>) -> ()
  %cast_27 = memref.cast %alloc_21 : memref<?xf64> to memref<*xf64>
  call @comet_print_memref_f64(%cast_27) : (memref<*xf64>) -> ()
  return
}
Notice the `memref.copy` operation that one-shot-bufferize inserts (highlighted above). However, we cannot rely on this.
Metadata
Assignees
Labels
bug: Something isn't working