Skip to content

Commit

Permalink
Integrate llvm-project@b358f21 (#19066)
Browse files Browse the repository at this point in the history
 - llvm::Type::getPointerTo was deprecated and replaced with
   llvm::PointerType::get
 - affine.delinearize_index now takes mixed attr/value operands

Includes two fixes:
 - SCF::TileAndFuseConsumerOfSlice was incorrectly determining all DPS
   ops to be tilable
 - A new PackOp + CastOp folder was dropping lowering configs
  • Loading branch information
qedawkins authored Nov 7, 2024
1 parent 8e5f218 commit 92526ea
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 103 deletions.
139 changes: 75 additions & 64 deletions compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "compiler/plugins/target/LLVMCPU/LibraryBuilder.h"

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"

// =============================================================================
Expand Down Expand Up @@ -42,13 +43,14 @@ static llvm::StructType *makeImportTableType(llvm::LLVMContext &context) {
}
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
auto *i8PtrType = llvm::PointerType::getUnqual(context);
auto *type = llvm::StructType::create(context,
{
i32Type,
i8PtrType->getPointerTo(),
},
"iree_hal_executable_import_table_v0_t",
/*isPacked=*/false);
auto *type =
llvm::StructType::create(context,
{
i32Type,
llvm::PointerType::get(i8PtrType, 0),
},
"iree_hal_executable_import_table_v0_t",
/*isPacked=*/false);
return type;
}

Expand Down Expand Up @@ -100,13 +102,14 @@ makeDispatchFunctionType(llvm::LLVMContext &context) {
auto *dispatchStateType = makeDispatchStateType(context);
auto *workgroupStateType = makeWorkgroupStateType(context);
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
return llvm::FunctionType::get(i32Type,
{
environmentType->getPointerTo(),
dispatchStateType->getPointerTo(),
workgroupStateType->getPointerTo(),
},
/*isVarArg=*/false);
return llvm::FunctionType::get(
i32Type,
{
llvm::PointerType::get(environmentType, 0),
llvm::PointerType::get(dispatchStateType, 0),
llvm::PointerType::get(workgroupStateType, 0),
},
/*isVarArg=*/false);
}

// %struct.iree_hal_executable_dispatch_attrs_v0_t = type {
Expand Down Expand Up @@ -181,15 +184,15 @@ makeStageLocationTableType(llvm::LLVMContext &context) {
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
auto *i8PtrType = llvm::PointerType::getUnqual(context);
auto *sourceLocationType = makeSourceLocationType(context);
auto *type =
llvm::StructType::create(context,
{
i32Type,
i8PtrType->getPointerTo(),
sourceLocationType->getPointerTo(),
},
"iree_hal_executable_stage_location_table_v0_t",
/*isPacked=*/false);
auto *type = llvm::StructType::create(
context,
{
i32Type,
llvm::PointerType::get(i8PtrType, 0),
llvm::PointerType::get(sourceLocationType, 0),
},
"iree_hal_executable_stage_location_table_v0_t",
/*isPacked=*/false);
return type;
}

Expand All @@ -209,6 +212,8 @@ static llvm::StructType *makeExportTableType(llvm::LLVMContext &context) {
}
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
auto *dispatchFunctionType = makeDispatchFunctionType(context);
auto *dispatchFunctionPointerType =
llvm::PointerType::get(dispatchFunctionType, 0);
auto *dispatchAttrsType = makeDispatchAttrsType(context);
auto *i8PtrType = llvm::PointerType::getUnqual(context);
auto *sourceLocationType = makeSourceLocationType(context);
Expand All @@ -217,12 +222,12 @@ static llvm::StructType *makeExportTableType(llvm::LLVMContext &context) {
context,
{
i32Type,
dispatchFunctionType->getPointerTo()->getPointerTo(),
dispatchAttrsType->getPointerTo(),
i8PtrType->getPointerTo(),
i8PtrType->getPointerTo(),
sourceLocationType->getPointerTo(),
stageLocationTableType->getPointerTo(),
llvm::PointerType::get(dispatchFunctionPointerType, 0),
llvm::PointerType::get(dispatchAttrsType, 0),
llvm::PointerType::get(i8PtrType, 0),
llvm::PointerType::get(i8PtrType, 0),
llvm::PointerType::get(sourceLocationType, 0),
llvm::PointerType::get(stageLocationTableType, 0),
},
"iree_hal_executable_export_table_v0_t",
/*isPacked=*/false);
Expand Down Expand Up @@ -288,7 +293,7 @@ static llvm::StructType *makeSourceTableType(llvm::LLVMContext &context) {
llvm::StructType::create(context,
{
i32Type,
sourceFileType->getPointerTo(),
llvm::PointerType::get(sourceFileType, 0),
},
"iree_hal_executable_source_file_table_v0_t",
/*isPacked=*/false);
Expand Down Expand Up @@ -335,16 +340,17 @@ static llvm::StructType *makeLibraryType(llvm::StructType *libraryHeaderType) {
auto *exportTableType = makeExportTableType(context);
auto *constantTableType = makeConstantTableType(context);
auto *sourceTableType = makeSourceTableType(context);
auto *type = llvm::StructType::create(context,
{
libraryHeaderType->getPointerTo(),
importTableType,
exportTableType,
constantTableType,
sourceTableType,
},
"iree_hal_executable_library_v0_t",
/*isPacked=*/false);
auto *type =
llvm::StructType::create(context,
{
llvm::PointerType::get(libraryHeaderType, 0),
importTableType,
exportTableType,
constantTableType,
sourceTableType,
},
"iree_hal_executable_library_v0_t",
/*isPacked=*/false);
return type;
}

Expand Down Expand Up @@ -379,7 +385,7 @@ static llvm::Constant *createStringConstantOrNull(StringRef value,
llvm::Module *module) {
if (value.empty()) {
auto i8Type = llvm::IntegerType::getInt8Ty(module->getContext());
return llvm::ConstantPointerNull::get(i8Type->getPointerTo());
return llvm::ConstantPointerNull::get(llvm::PointerType::get(i8Type, 0));
}
return createStringConstant(value, module);
}
Expand Down Expand Up @@ -427,13 +433,14 @@ static llvm::Constant *createArrayConstant(StringRef name,
llvm::Function *LibraryBuilder::build(StringRef queryFuncName) {
auto &context = module->getContext();
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
auto *environmentType = makeEnvironmentType(context)->getPointerTo();
auto *environmentStructType = makeEnvironmentType(context);
auto *environmentType = llvm::PointerType::get(environmentStructType, 0);
auto *libraryHeaderType = makeLibraryHeaderType(context);

// %struct.iree_hal_executable_library_header_t**
// @iree_hal_library_query(i32, %struct.iree_hal_executable_environment_v0_t*)
auto *queryFuncType =
llvm::FunctionType::get(libraryHeaderType->getPointerTo(),
llvm::FunctionType::get(llvm::PointerType::get(libraryHeaderType, 0),
{
i32Type,
environmentType,
Expand All @@ -454,8 +461,10 @@ llvm::Function *LibraryBuilder::build(StringRef queryFuncName) {
builder.CreateICmpEQ(func->getArg(0),
llvm::ConstantInt::get(
i32Type, static_cast<int64_t>(Version::LATEST))),
builder.CreatePointerCast(v0, libraryHeaderType->getPointerTo()),
llvm::ConstantPointerNull::get(libraryHeaderType->getPointerTo())));
builder.CreatePointerCast(v0,
llvm::PointerType::get(libraryHeaderType, 0)),
llvm::ConstantPointerNull::get(
llvm::PointerType::get(libraryHeaderType, 0))));

return func;
}
Expand All @@ -467,7 +476,7 @@ LibraryBuilder::buildLibraryV0ImportTable(std::string libraryName) {
auto *i8Type = llvm::IntegerType::getInt8Ty(context);
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
llvm::Constant *symbolNames =
llvm::Constant::getNullValue(i8Type->getPointerTo());
llvm::Constant::getNullValue(llvm::PointerType::get(i8Type, 0));
if (!imports.empty()) {
SmallVector<llvm::Constant *> symbolNameValues;
for (auto &import : imports) {
Expand All @@ -476,9 +485,9 @@ LibraryBuilder::buildLibraryV0ImportTable(std::string libraryName) {
symbolName = "?" + symbolName;
symbolNameValues.push_back(createStringConstant(symbolName, module));
}
symbolNames =
createArrayConstant(libraryName + "_import_names",
i8Type->getPointerTo(), symbolNameValues, module);
symbolNames = createArrayConstant(libraryName + "_import_names",
llvm::PointerType::get(i8Type, 0),
symbolNameValues, module);
}
return llvm::ConstantStruct::get(
importTableType, {
Expand Down Expand Up @@ -507,12 +516,12 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
for (auto dispatch : exports)
exportPtrValues.push_back(dispatch.func);
llvm::Constant *exportPtrs = createArrayConstant(
libraryName + "_funcs", dispatchFunctionType->getPointerTo(),
libraryName + "_funcs", llvm::PointerType::get(dispatchFunctionType, 0),
exportPtrValues, module);

// iree_hal_executable_export_table_v0_t::attrs
llvm::Constant *exportAttrs =
llvm::Constant::getNullValue(i32Type->getPointerTo());
llvm::Constant::getNullValue(llvm::PointerType::get(i32Type, 0));
bool hasNonDefaultAttrs = llvm::any_of(exports, [](const auto &dispatch) {
return !dispatch.attrs.isDefault();
});
Expand Down Expand Up @@ -557,33 +566,35 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {

// iree_hal_executable_export_table_v0_t::names
llvm::Constant *exportNames =
llvm::Constant::getNullValue(i8Type->getPointerTo()->getPointerTo());
llvm::Constant::getNullValue(llvm::PointerType::get(i8Type, 0));
if (mode == Mode::INCLUDE_REFLECTION_ATTRS) {
SmallVector<llvm::Constant *> exportNameValues;
for (auto dispatch : exports)
exportNameValues.push_back(createStringConstant(dispatch.name, module));
exportNames =
createArrayConstant(libraryName + "_names", i8Type->getPointerTo(),
exportNameValues, module);
exportNames = createArrayConstant(libraryName + "_names",
llvm::PointerType::get(i8Type, 0),
exportNameValues, module);
}

// iree_hal_executable_export_table_v0_t::tags
auto *i8PtrType = llvm::PointerType::get(i8Type, 0);
llvm::Constant *exportTags =
llvm::Constant::getNullValue(i8Type->getPointerTo()->getPointerTo());
llvm::Constant::getNullValue(llvm::PointerType::get(i8PtrType, 0));
bool hasAnyTags = llvm::any_of(
exports, [](auto &dispatch) { return !dispatch.tag.empty(); });
if (mode == Mode::INCLUDE_REFLECTION_ATTRS && hasAnyTags) {
SmallVector<llvm::Constant *> exportTagValues;
for (auto dispatch : exports)
exportTagValues.push_back(
createStringConstantOrNull(dispatch.tag, module));
exportTags = createArrayConstant(
libraryName + "_tags", i8Type->getPointerTo(), exportTagValues, module);
exportTags = createArrayConstant(libraryName + "_tags",
llvm::PointerType::get(i8Type, 0),
exportTagValues, module);
}

// iree_hal_executable_export_table_v0_t::source_locations
llvm::Constant *exportSourceLocations =
llvm::Constant::getNullValue(sourceLocationType->getPointerTo());
llvm::Constant *exportSourceLocations = llvm::Constant::getNullValue(
llvm::PointerType::get(sourceLocationType, 0));
if (mode == Mode::INCLUDE_REFLECTION_ATTRS) {
SmallVector<llvm::Constant *> exportSourceLocationValues;
for (auto dispatch : exports) {
Expand All @@ -605,8 +616,8 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
}

// iree_hal_executable_export_table_v0_t::stage_locations
llvm::Constant *exportStageLocations =
llvm::Constant::getNullValue(stageLocationTableType->getPointerTo());
llvm::Constant *exportStageLocations = llvm::Constant::getNullValue(
llvm::PointerType::get(stageLocationTableType, 0));
if (mode == Mode::INCLUDE_REFLECTION_ATTRS) {
SmallVector<llvm::Constant *> exportStageTableValues;
for (auto dispatch : exports) {
Expand All @@ -628,7 +639,7 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
}
llvm::Constant *stageNamesPtr = createArrayConstant(
libraryName + "_" + dispatch.name + "_stage_names",
i8Type->getPointerTo(), exportStageNameValues, module);
llvm::PointerType::get(i8Type, 0), exportStageNameValues, module);
llvm::Constant *sourceLocationsPtr = createArrayConstant(
libraryName + "_" + dispatch.name + "_stage_source_locations",
sourceLocationType, exportSourceLocationValues, module);
Expand Down Expand Up @@ -688,7 +699,7 @@ LibraryBuilder::buildLibraryV0SourceTable(std::string libraryName) {
auto *sourceTableType = makeSourceTableType(context);
auto *i32Type = llvm::IntegerType::getInt32Ty(context);
llvm::Constant *sourceFilesValue =
llvm::Constant::getNullValue(sourceFileType->getPointerTo());
llvm::Constant::getNullValue(llvm::PointerType::get(sourceFileType, 0));
if (!sourceFiles.empty()) {
SmallVector<llvm::Constant *> sourceFileValues;
for (auto &sourceFile : sourceFiles) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ func.func @distribute_thread_forall(%out : memref<?xi32>)
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]])
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[DELIN:.+]] = affine.delinearize_index %[[LINID]] into (%c1024) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]]
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----

Expand All @@ -44,8 +43,7 @@ func.func @distribute_warp_forall(%out : memref<?xi32>)
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 2 + s2 * 4 + s0 floordiv 32)>(%[[I]])
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[DELIN:.+]] = affine.delinearize_index %[[LINID]] into (%c32) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]]
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----

Expand Down Expand Up @@ -85,8 +83,7 @@ func.func @distribute_thread_forall_drop_for_loop(%out : memref<?xi32>)
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[DELIN:.+]] = affine.delinearize_index %[[LINID]] into (%c128) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]]
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----

Expand All @@ -102,15 +99,14 @@ func.func @distribute_thread_forall_single_thread(%out : memref<?xi32>)
}

// CHECK-LABEL: func @distribute_thread_forall_single_thread
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: scf.for %[[I:.+]] = %[[LINID]] to %c1 step %c128 {
// CHECK: memref.store {{.*}}[%[[C0]]]
// CHECK: memref.store {{.*}}[%[[I]]]

// -----

Expand All @@ -133,7 +129,7 @@ func.func @distribute_thread_forall_multi_dim(%out : memref<?x?x?xi32>)
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]])
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (%c16, %c8, %c4) : index
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]


Expand All @@ -157,5 +153,4 @@ func.func @distribute_thread_forall_small_workgroup(%out : memref<?xi32>)
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 7 + s2 * 7)>
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[DELIN:.+]] = affine.delinearize_index %[[LINID]] into (%c7) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]]
// CHECK: memref.store {{.*}}[%[[LINID]]]
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,6 @@ module {

// CHECK-LABEL: func @simple_nd_write(
// CHECK: %[[RD:.+]] = vector.transfer_read {{.*}} vector<1x128xf32>
// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (4, 8) : index, index
// CHECK: %[[INNER_ID:.+]] = affine.apply #[[$MAP]]()[%[[IDS]]#1]
// CHECK: vector.transfer_write %[[RD]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func.func @batch_matmul_with_padding_strategy(%arg0: tensor<1x?x1280xf16>, %arg1
%4 = tensor.empty() : tensor<1x64x128xf16>
%5 = vector.transfer_write %cst, %4[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x128xf16>, tensor<1x64x128xf16>
%6 = scf.for %arg2 = %c0 to %c20 step %c1 iter_args(%arg3 = %5) -> (tensor<1x64x128xf16>) {
%7 = affine.delinearize_index %arg2 into (%c20) : index
%7 = affine.delinearize_index %arg2 into (20) : index
%8 = affine.apply #map()[%7]
%extracted_slice_1 = tensor.extract_slice %arg1[0, %8, 0] [1, 64, 128] [1, 1, 1] : tensor<1x1280x128xf16> to tensor<1x64x128xf16>
%extracted_slice_2 = tensor.extract_slice %arg0[0, 0, %8] [1, %3, 64] [1, 1, 1] : tensor<1x?x1280xf16> to tensor<1x?x64xf16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,11 @@ hal.executable private @scf_forall_4D_static_interchange {
// CHECK-DAG: %[[C160:.+]] = arith.constant 160 : index
// CHECK: hal.return %[[C6]], %[[C7]], %[[C160]]
// CHECK: func @scf_forall_4D_static_interchange()
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
// CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
// CHECK-DAG: %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
// CHECK-NOT: scf.forall
// CHECK: %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (%[[C5]], %[[C8]], %[[C4]])
// CHECK: %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[WG_ID_Z]] into (5, 8, 4)
// CHECK: %[[I:.+]] = affine.apply #[[MAP0]]()[%[[DELINEARIZE]]#0]
// CHECK: %[[J:.+]] = affine.apply #[[MAP1]]()[%[[WG_ID_X]]]
// CHECK: %[[K:.+]] = affine.apply #[[MAP2]]()[%[[WG_ID_Y]]]
Expand Down
Loading

0 comments on commit 92526ea

Please sign in to comment.