diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index bfc0ee207569..65e9b6e7ccc0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -96,6 +96,7 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Codegen/Transforms",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
+        "//compiler/src/iree/compiler/Dialect/Encoding/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AMDGPUDialect",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index c3b265e6f804..3c29b055541c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -126,6 +126,7 @@ iree_cc_library(
     iree::compiler::Codegen::Transforms
     iree::compiler::Codegen::Utils
     iree::compiler::Codegen::Utils::VectorOpUtils
+    iree::compiler::Dialect::Encoding::IR
     iree::compiler::Dialect::HAL::IR
   PUBLIC
 )
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index fe74d62c832c..59ac6a04dbf0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -6,6 +6,11 @@
 
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
 #include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -35,14 +40,16 @@ static std::optional<TileMxNxK> getIntrinsicSize(TypeRange elementTypes) {
   return std::nullopt;
 }
 
-// TODO: Query the value from GPU attributes.
-// TODO: Define a struct with meaningful name for the pair.
-SmallVector<int64_t> getIntrinsicVectorSize(TypeRange elementTypes,
-                                            int64_t roleIdx) {
-  Type lhs = elementTypes[0];
-  Type rhs = elementTypes[1];
-  Type out = elementTypes[2];
-  if (lhs.isF32() && rhs.isF32() && out.isF32()) {
+/// Returns the corresponding native vector sizes defined by the `mma`
+/// intrinsic.
+static SmallVector<int64_t> getIntrinsicVectorSize(IREE::GPU::MMAAttr mma,
+                                                   int64_t roleIdx) {
+  if (mma.getIntrinsic().getValue() ==
+      IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x4_F32) {
+    // TODO: Query the value from GPU attributes.
+    if (roleIdx == 0 || roleIdx == 1) {
+      return {1, 1};
+    }
     if (roleIdx == 0 || roleIdx == 1) {
       return {1, 1};
     }
@@ -55,13 +62,11 @@ SmallVector<int64_t> getIntrinsicVectorSize(TypeRange elementTypes,
 
 // Given encoding's role index and element types, return the transpose
 // permutation used in GPU materialization.
-SmallVector<int64_t> getTransposePermutation(int64_t roleIdx,
-                                             TypeRange elementTypes) {
-  // For now, check that all types are f32:
-  Type lhs = elementTypes[0];
-  Type rhs = elementTypes[1];
-  Type out = elementTypes[2];
-  if (!lhs.isF32() || !rhs.isF32() || !out.isF32()) {
+static SmallVector<int64_t> getTransposePermutation(IREE::GPU::MMAAttr mma,
+                                                    int64_t roleIdx) {
+  // TODO: Support other intrinsics.
+  if (mma.getIntrinsic().getValue() !=
+      IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x4_F32) {
     return {};
   }
 
@@ -81,27 +86,33 @@ SmallVector<int64_t> getTransposePermutation(int64_t roleIdx,
   }
 }
 
-// TODO(hanchung): Pass an ExecutableTargetAttr attribute for the target
-// encoding. Here we assume that every mfma op is available.
-// TODO(hanchung): Handle wmma ops.
-static SmallVector<TileMxNxK> enumerateMatmulTileMxNxK(TypeRange elementTypes) {
+static std::optional<IREE::GPU::MMAAttr>
+enumerateMmaIntrinsic(TypeRange elementTypes, IREE::GPU::TargetAttr target) {
   assert(elementTypes.size() == 3);
   Type lhs = elementTypes[0];
   Type rhs = elementTypes[1];
   Type out = elementTypes[2];
-  if (lhs.isF32() && rhs.isF32() && out.isF32()) {
-    // TODO: Take subgroup_size into account, so we can have more unrolling.
-    // TODO: Take the bitwidth of load into account, so we can have correct
-    // unrolling factor for K-dimension.
-    return {TileMxNxK{16, 16, 4}}; // Aim to use mfma_f32_16x16x4_f32 intrinsic.
+  for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
+    IREE::GPU::MMAIntrinsic type = mma.getIntrinsic().getValue();
+    // TODO: Drop this once all intrinsics are supported.
+    if (type != IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x4_F32) {
+      continue;
+    }
+
+    auto [aType, bType, cType] = mma.getABCElementTypes();
+    if (lhs != aType || rhs != bType || out != cType) {
+      continue;
+    }
+    return mma;
   }
 
   // Fallback - no architecture-optimized tile size for this case.
-  return {};
+  return std::nullopt;
 }
 
 static FailureOr<MaterializeEncodingInfo>
-materializeEncodingForTarget(RankedTensorType tensorType) {
+materializeEncodingForTarget(RankedTensorType tensorType,
+                             IREE::HAL::ExecutableTargetAttr targetAttr) {
   auto encoding =
       dyn_cast_or_null<IREE::Encoding::EncodingAttr>(tensorType.getEncoding());
   if (!encoding) {
@@ -113,28 +124,31 @@ materializeEncodingForTarget(RankedTensorType tensorType) {
       cDims->n.size() > 1 || cDims->k.size() > 1) {
     return failure();
   }
+  // Enumerate available tile shapes for the given encoding and target.
+  IREE::GPU::TargetAttr gpuTargetAttr = getGPUTargetAttr(targetAttr);
   auto elementTypes = llvm::to_vector(
       llvm::map_range(encoding.getElementTypes().getValue(), [](Attribute a) {
         return cast<TypeAttr>(a).getValue();
       }));
-  SmallVector<TileMxNxK> enumeratedTileMxNxK =
-      enumerateMatmulTileMxNxK(elementTypes);
-  if (enumeratedTileMxNxK.empty()) {
+  std::optional<IREE::GPU::MMAAttr> mma =
+      enumerateMmaIntrinsic(elementTypes, gpuTargetAttr);
+  if (!mma) {
     return failure();
   }
 
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its operand index in the matmul.
+  // TODO: Support unrolling.
   auto rank = tensorType.getRank();
-
-  auto encodingInfo =
-      getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);
+  TileMxNxK innerTile;
+  std::tie(innerTile.M, innerTile.N, innerTile.K) = mma->getMNKShape();
+  auto encodingInfo = getEncodingInfoForMatmul(encoding, rank, innerTile);
 
   // insert inner tile shapes and permutation info
   auto roleIdx = encoding.getOperandIndex().getInt();
-  auto intrinsicVectorSizes = getIntrinsicVectorSize(elementTypes, roleIdx);
-  auto permutation = getTransposePermutation(roleIdx, elementTypes);
+  auto intrinsicVectorSizes = getIntrinsicVectorSize(*mma, roleIdx);
+  auto permutation = getTransposePermutation(*mma, roleIdx);
   encodingInfo.innerTileShapes = intrinsicVectorSizes;
   encodingInfo.permutation = permutation;
   return encodingInfo;
@@ -146,6 +160,11 @@ struct GPUMaterializeDeviceEncodingPass final
           GPUMaterializeDeviceEncodingPass> {
   using GPUMaterializeDeviceEncodingPassBase::
       GPUMaterializeDeviceEncodingPassBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, tensor::TensorDialect,
+                    linalg::LinalgDialect, IREE::Encoding::IREEEncodingDialect,
+                    IREE::GPU::IREEGPUDialect>();
+  }
   void runOnOperation() override;
 };
@@ -301,13 +320,25 @@ struct GPUSetEncodingOpLoweringConversion
 
 } // namespace
 
+// TODO(hanchung): Remove the wrapper after allowing the type converter to carry
+// the targetAttr. For now, follow what CPU is doing.
+static MaterializeEncodingFn
+getMaterializeEncodingFn(IREE::HAL::ExecutableTargetAttr targetAttr) {
+  return
+      [targetAttr](
+          RankedTensorType tensorType) -> FailureOr<MaterializeEncodingInfo> {
+        return materializeEncodingForTarget(tensorType, targetAttr);
+      };
+}
+
 void GPUMaterializeDeviceEncodingPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   FunctionOpInterface funcOp = getOperation();
+  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
   {
     RewritePatternSet patterns(ctx);
     MaterializeEncodingTypeConverter typeConverter(
-        materializeEncodingForTarget);
+        getMaterializeEncodingFn(targetAttr));
     MaterializeEncodingConversionTarget target(*funcOp.getContext());
     MaterializeEncodingValueFn materializeEncodingValueFn =
         [](RankedTensorType, OpBuilder,
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 14dbfda1848b..9a2b912d0a4a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -26,6 +26,7 @@ iree_lit_test_suite(
             "gpu_distribute_shared_memory.mlir",
            "gpu_generalize_named_ops.mlir",
            "gpu_lower_to_ukernels.mlir",
+            "gpu_materialize_encoding.mlir",
            "gpu_nested_layout_contract_amdgpu.mlir",
            "gpu_nested_layout_vector_distribution.mlir",
            "gpu_pipeline.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 9ccc268d0d79..7d1442abb131 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -22,6 +22,7 @@ iree_lit_test_suite(
     "gpu_distribute_shared_memory.mlir"
     "gpu_generalize_named_ops.mlir"
     "gpu_lower_to_ukernels.mlir"
+    "gpu_materialize_encoding.mlir"
     "gpu_nested_layout_contract_amdgpu.mlir"
     "gpu_nested_layout_vector_distribution.mlir"
     "gpu_pipeline.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
new file mode 100644
index 000000000000..821cac997f9d
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding.mlir
@@ -0,0 +1,95 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" --split-input-file %s | FileCheck %s
+
+//-----------------------------------------------------------------------------
+// 1. MFMA_F32_16x16x4_F32
+//-----------------------------------------------------------------------------
+
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>,
+                                    user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+                                    round_dims_to = array<i64: 16, 16, 16>>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
+  iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
+    subgroup_size_choices = [64],
+    max_workgroup_sizes = [1024, 1024, 1024],
+    max_thread_count_per_workgroup = 1024,
+    max_workgroup_memory_bytes = 65536>>
+}>
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+func.func @set_encoding_LHS() attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding_LHS
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
+// CHECK: %[[PACK:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
+// CHECK: %[[EXPAND_LHS:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
+// CHECK: %[[EMPTY_LHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[EXPAND_LHS]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_LHS2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
+// CHECK: %[[EXPAND_LHS_2:.*]] = tensor.expand_shape %[[COLLAPSE]]
+// CHECK: flow.dispatch.tensor.store %[[EXPAND_LHS_2]]
+
+func.func @set_encoding_RHS() attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding_RHS
+// CHECK: %[[EMPTY_RHS:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
+// CHECK: %[[PACK_RHS:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %3 : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
+// CHECK: %[[EXPAND_RHS:.*]] = tensor.expand_shape %[[PACK_RHS]]
+// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
+// CHECK: %[[EMPTY_RHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
+// CHECK: %[[TRANSPOSE_RHS:.*]] = linalg.transpose ins(%[[EXPAND_RHS]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_RHS2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
+// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_RHS]]
+// CHECK: %[[EXPAND_RHS_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
+// CHECK: flow.dispatch.tensor.store %[[EXPAND_RHS_2]]
+
+func.func @set_encoding_ACC() attributes {
+  hal.executable.target = #executable_target_rocm_hsaco_fb
+} {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding_ACC
+// CHECK: %[[EMPTY_ACC:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
+// CHECK: %[[PACK_ACC:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY_ACC]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
+// CHECK: %[[EXPAND_ACC:.*]] = tensor.expand_shape %[[PACK_ACC]]
+// CHECK: %[[EMPTY_ACC2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
+// CHECK: %[[TRANSPOSE_ACC:.*]] = linalg.transpose ins(%[[EXPAND_ACC]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_ACC2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
+// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_ACC]]
+// CHECK: %[[EXPAND_ACC_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
+// CHECK: flow.dispatch.tensor.store %[[EXPAND_ACC_2]]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 76b085ee4fac..67a84393a918 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -37,7 +37,6 @@ iree_lit_test_suite(
            "config_winograd.mlir",
            "extract_address_computation_gpu.mlir",
            "gpu_set_num_workgroups.mlir",
-            "gpu_materialize_encoding.mlir",
            "gpu_pipeline_generalize_named_ops.mlir",
            "nvvm_extract_address_computation.mlir",
            "nvvm_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 89a32173ecb2..46366b4d9fa6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -33,7 +33,6 @@ iree_lit_test_suite(
     "distribute_to_thread.mlir"
     "elementwise_pipeline.mlir"
     "extract_address_computation_gpu.mlir"
-    "gpu_materialize_encoding.mlir"
     "gpu_pipeline_generalize_named_ops.mlir"
     "gpu_set_num_workgroups.mlir"
     "illegal_configuration.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_materialize_encoding.mlir
deleted file mode 100644
index 7890dffed6c0..000000000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_materialize_encoding.mlir
+++ /dev/null
@@ -1,78 +0,0 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" --split-input-file %s | FileCheck %s
-
-#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>,
-                                    user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
-                                    round_dims_to = array<i64: 16, 16, 16>>
-
-#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
-  #hal.descriptor_set.layout<0, bindings = [
-    #hal.descriptor_set.binding<0, storage_buffer>,
-    #hal.descriptor_set.binding<1, storage_buffer>
-  ]>
-]>
-func.func @set_encoding_LHS() {
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
-  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
-  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
-  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
-  return
-}
-
-// CHECK-LABEL: func.func @set_encoding_LHS
-// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
-// CHECK: %[[PACK:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
-// CHECK: %[[EXPAND_LHS:.*]] = tensor.expand_shape %[[PACK]]
-// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
-// CHECK: %[[EMPTY_LHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
-// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[EXPAND_LHS]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_LHS2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
-// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
-// CHECK: %[[EXPAND_LHS_2:.*]] = tensor.expand_shape %[[COLLAPSE]]
-// CHECK: flow.dispatch.tensor.store %[[EXPAND_LHS_2]]
-
-//---------
-
-func.func @set_encoding_RHS() {
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
-  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
-  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
-  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
-  return
-}
-
-// CHECK-LABEL: func.func @set_encoding_RHS
-// CHECK: %[[EMPTY_RHS:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
-// CHECK: %[[PACK_RHS:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %3 : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
-// CHECK: %[[EXPAND_RHS:.*]] = tensor.expand_shape %[[PACK_RHS]]
-// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
-// CHECK: %[[EMPTY_RHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
-// CHECK: %[[TRANSPOSE_RHS:.*]] = linalg.transpose ins(%[[EXPAND_RHS]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_RHS2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
-// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_RHS]]
-// CHECK: %[[EXPAND_RHS_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
-// CHECK: flow.dispatch.tensor.store %[[EXPAND_RHS_2]]
-
-//---------
-
-func.func @set_encoding_ACC() {
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
-  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
-  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
-  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
-  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
-  return
-}
-
-// CHECK-LABEL: func.func @set_encoding_ACC
-// CHECK: %[[EMPTY_ACC:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
-// CHECK: %[[PACK_ACC:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY_ACC]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
-// CHECK: %[[EXPAND_ACC:.*]] = tensor.expand_shape %[[PACK_ACC]]
-// CHECK: %[[EMPTY_ACC2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
-// CHECK: %[[TRANSPOSE_ACC:.*]] = linalg.transpose ins(%[[EXPAND_ACC]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_ACC2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
-
-// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_ACC]]
-// CHECK: %[[EXPAND_ACC_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
-// CHECK: flow.dispatch.tensor.store %[[EXPAND_ACC_2]]
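
For reference, the tensor<33x64x16x4xf32> shapes asserted by the CHECK lines above follow from ceil-dividing the 255x513 source dims by the MFMA_F32_16x16x4_F32 tile sizes picked by the pass. Below is a minimal standalone C++ sketch of that arithmetic (not part of the patch; the ceilDiv helper is hypothetical), assuming tensor.pack's pad-to-a-multiple-of-the-inner-tile semantics.

// Illustrative only; mirrors the shapes checked in gpu_materialize_encoding.mlir.
#include <cstdint>
#include <cstdio>

// tensor.pack rounds each tiled dimension up to a multiple of its inner tile.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Source tensor in the test: tensor<255x513xf32>, packed with
  // inner_dims_pos = [1, 0], inner_tiles = [16, 4], outer_dims_perm = [1, 0].
  int64_t d0 = 255, d1 = 513;
  int64_t outerD1 = ceilDiv(d1, 16); // 33 outer tiles along the 513 dim
  int64_t outerD0 = ceilDiv(d0, 4);  // 64 outer tiles along the 255 dim
  // outer_dims_perm = [1, 0] lists the d1 tile count first, which gives the
  // tensor<33x64x16x4xf32> seen in the CHECK lines.
  std::printf("%lld x %lld x 16 x 4\n", (long long)outerD1, (long long)outerD0);
  return 0;
}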