Commit

Get rid of that weird zero basis hack
Now that there's an upstream PR that allows affine.delinearize_index
to clamp, use that instead of the hack I had.
krzysz00 committed Nov 13, 2024
1 parent f42cf1a commit 5207e94
Showing 3 changed files with 8 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1028,17 +1028,14 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
         getSubgroupSize() / intrinsicLayoutThreadBound);
   }

-  // Add a `0` at the front of the distribution sizes so that
-  // `affine.delinearize_index` clamp its output (we'll throw away the first
-  // result).
-  distributionThreadSizes.insert(distributionThreadSizes.begin(), 0);
-
   // Obtain the offsets from delinearization along the distributionThreadSizes.
+  // Use a delinearize without outer bound and throw away its initial result
+  // to get clamping behavior.
   SmallVector<OpFoldResult> tileOffsets =
       builder
           .create<affine::AffineDelinearizeIndexOp>(
               loc, getValueOrCreateConstantIndexOp(builder, loc, threadId),
-              distributionThreadSizes)
+              distributionThreadSizes, /*hasOuterBound=*/false)
           ->getResults()
           .drop_front();

Original file line number Diff line number Diff line change
@@ -392,7 +392,7 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<1x1x4x16xf32>, %rhs: t
 // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
 // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
 // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x4x16x4xf32>)
-// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (0, 4, 16)
+// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
 // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1]
 // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1]
 // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
@@ -426,7 +426,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled(%lhs: tensor<1x1x2x4x16x4x
 // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
 // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
 // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>)
-// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (0, 4, 16)
+// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
 // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
 // CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1]
 // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
@@ -462,12 +462,12 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<
 // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
 // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
 // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (256) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>)
-// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (0, 2, 4, 16)
+// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (2, 4, 16)
 // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
 // CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1]
 // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
 // CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1]
-// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (0, 2, 2, 4, 16)
+// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (2, 2, 4, 16)
 // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
 // CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
2 changes: 1 addition & 1 deletion third_party/llvm-project
