Skip to content

Commit 44adc3a

Browse files
authored
Fixing executable linking when other targets are present. (#19035)
The existing linking code was all kinds of wrong when multiple executables with disjoint entry points were present. Linking needs to be reworked in general, but this incremental change ensures that targets only link executables that contain variants that use them. There are definitely still corner cases that don't work. This is a small step towards heterogeneous devices. The following example now compiles and runs: ```mlir #executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", native_vector_size = 4 : index}> #executable_target_embedded_elf_x86_64_whatever = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", native_vector_size = 16 : index}> #executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb"> util.global private @device_a = #hal.device.target<"local", {ordinal = 0 : index}, [ #executable_target_embedded_elf_x86_64_, #executable_target_embedded_elf_x86_64_whatever ]> : !hal.device util.global private @device_b = #hal.device.target<"local", {ordinal = 1 : index}, [ #executable_target_vmvx_bytecode_fb ]> : !hal.device func.func public @mutli_device_mul_add( // Input argument is resident on device_a (tooling default to first device). %input_a: tensor<4xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>} ) -> ( // Output result is expected to be on device_a (though not required). tensor<4xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>} ) { // Compute on device_a (input is there). %constant_a = arith.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32> %transient_a = arith.mulf %input_a, %constant_a : tensor<4xf32> // Transfer the result from device_a -> device_b. %transient_b = flow.tensor.transfer %transient_a : tensor<4xf32> to #hal.device.affinity<@device_b> // Compute on device_b. 
%constant_b = arith.constant dense<[4.0, 5.0, 6.0, 7.0]> : tensor<4xf32> %result_b = arith.mulf %transient_b, %constant_b : tensor<4xf32> // Transfer the result from device_b -> device_a. %result_a = flow.tensor.transfer %result_b : tensor<4xf32> to #hal.device.affinity<@device_a> // More compute on device_a - should produce into the result buffer. %result_a2 = arith.addf %result_a, %constant_a : tensor<4xf32> // Return the result on device_a (as required by ABI attr). func.return %result_a2 : tensor<4xf32> } ``` ```sh $ iree-compile --iree-execution-model=async-external iree-run-module-multi.mlir -o module.vmfb $ iree-run-module \ --module=module.vmfb --function=mutli_device_mul_add --input=4xf32=10,11,12,13 \ --device=local-task --device=local-task --task_topology_group_count=1 ``` (testing this in-tree is hard right now due to ergonomics issues - this is all experimental anyway)
1 parent 8158a8c commit 44adc3a

File tree

7 files changed

+70
-56
lines changed

7 files changed

+70
-56
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ struct LLVMCPULinkExecutablesPass
2525
auto moduleOp = getOperation();
2626
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
2727

28-
auto sourceExecutableOps =
29-
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
28+
auto sourceExecutableOps = gatherExecutablesForTarget(moduleOp, target);
3029
if (sourceExecutableOps.size() <= 1)
3130
return;
3231

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ struct LLVMGPULinkExecutablesPass
6464
auto moduleOp = getOperation();
6565
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
6666

67-
auto sourceExecutableOps =
68-
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
67+
auto sourceExecutableOps = gatherExecutablesForTarget(moduleOp, target);
6968
if (sourceExecutableOps.size() <= 1)
7069
return;
7170

compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
1-
// RUN: iree-opt --iree-llvmgpu-link-executables --split-input-file %s | FileCheck %s
21
// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-TARGET
32
// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="cuda"},iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-MULTI
43

54
#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb">
65

76
// Expect a single executable with both exports and correct ordinals.
8-
// CHECK: hal.executable private @link_executables_linked
9-
// CHECK: hal.executable.variant public @rocm_hsaco_fb
10-
// CHECK: hal.executable.export public @export0 ordinal(0)
11-
// CHECK: hal.executable.export public @export1 ordinal(1)
7+
// CHECK-TARGET: hal.executable private @link_executables_linked
8+
// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb
9+
// CHECK-TARGET: hal.executable.export public @export0 ordinal(0)
10+
// CHECK-TARGET: hal.executable.export public @export1 ordinal(1)
1211

1312
// Expect one LLVM module with all globals and functions.
1413
// Note that shared memory is duplicated but dynamic shared memory is not.
15-
// CHECK: builtin.module
16-
// CHECK-NEXT: llvm.mlir.global external @__dynamic_shared_memory__
17-
// CHECK-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>>
18-
// CHECK-NEXT: llvm.func @export0
19-
// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
20-
// CHECK-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
21-
// CHECK: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>>
22-
// CHECK-NEXT: llvm.func @export1
23-
// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
24-
// CHECK-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3>
14+
// CHECK-TARGET: builtin.module
15+
// CHECK-TARGET-NEXT: llvm.mlir.global external @__dynamic_shared_memory__
16+
// CHECK-TARGET-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>>
17+
// CHECK-TARGET-NEXT: llvm.func @export0
18+
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
19+
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
20+
// CHECK-TARGET: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>>
21+
// CHECK-TARGET-NEXT: llvm.func @export1
22+
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
23+
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3>
2524

2625
hal.executable private @executable0 {
2726
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) {
@@ -65,15 +64,6 @@ hal.executable private @executable1 {
6564
#executable_target_cuda = #hal.executable.target<"cuda", "cuda-nvptx-fb">
6665
#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb">
6766

68-
// Expect a single executable with multiple variants when not specifying target.
69-
// CHECK: hal.executable private @link_executables_linked
70-
// CHECK: hal.executable.variant public @cuda_nvptx_fb_0
71-
// CHECK: hal.executable.export public @export0 ordinal(0)
72-
// CHECK: hal.executable.export public @export1 ordinal(1)
73-
// CHECK: hal.executable.variant public @rocm_hsaco_fb_1
74-
// CHECK: hal.executable.export public @export0 ordinal(0)
75-
// CHECK: hal.executable.export public @export1 ordinal(1)
76-
7767
// Expect only one target to be linked when specified.
7868
// CHECK-TARGET: hal.executable private @link_executables_linked
7969
// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb_1
@@ -88,7 +78,7 @@ hal.executable private @executable1 {
8878

8979
// Multiple applications of the pass per target should not conflict.
9080
// CHECK-MULTI: hal.executable private @link_executables_linked_0
91-
// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb_1
81+
// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb
9282
// CHECK-MULTI: hal.executable.export public @export0 ordinal(0)
9383
// CHECK-MULTI: hal.executable.export public @export1 ordinal(1)
9484
// CHECK-MULTI: hal.executable private @link_executables_linked

compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@ namespace mlir::iree_compiler {
2323
#include "iree/compiler/Codegen/SPIRV/Passes.h.inc"
2424

2525
namespace IREE::HAL {
26+
2627
// Compares two ExecutableTargetAttr according to the alphabetical order of used
2728
// SPIR-V features.
2829
//
2930
// Note that this is a very specific ordering per the needs of this pass--we
3031
// guarantee that input ExecutableTargetAttr only differ w.r.t. their used
3132
// SPIR-V features, and we want a deterministic order when mutating the IR.
32-
bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) {
33+
static bool operator<(const ExecutableTargetAttr &a,
34+
const ExecutableTargetAttr &b) {
3335
auto aFeatures = a.getConfiguration().getAs<ArrayAttr>("iree.spirv.features");
3436
auto bFeatures = b.getConfiguration().getAs<ArrayAttr>("iree.spirv.features");
3537
for (unsigned i = 0; i < std::min(aFeatures.size(), bFeatures.size()); ++i) {
@@ -40,36 +42,37 @@ bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) {
4042
}
4143
return aFeatures.size() < bFeatures.size();
4244
}
43-
} // namespace IREE::HAL
4445

4546
namespace {
4647

47-
using IREE::HAL::ExecutableTargetAttr;
48+
// Returns all executables that have one or more variants that use SPIR-V
49+
// codegen. Executables that contain object references are currently ignored as
50+
// we only support full replacement of the modules and not yet linking.
51+
static SmallVector<IREE::HAL::ExecutableOp>
52+
gatherExecutablesForSPIRVCodegen(mlir::ModuleOp moduleOp) {
53+
SmallVector<IREE::HAL::ExecutableOp> result;
54+
for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
55+
if (llvm::any_of(executableOp.getOps<IREE::HAL::ExecutableVariantOp>(),
56+
[&](IREE::HAL::ExecutableVariantOp variantOp) {
57+
return usesSPIRVCodeGen(variantOp) &&
58+
!variantOp.getObjects().has_value();
59+
})) {
60+
result.push_back(executableOp);
61+
}
62+
}
63+
return result;
64+
}
4865

4966
struct SPIRVLinkExecutablesPass final
5067
: impl::SPIRVLinkExecutablesPassBase<SPIRVLinkExecutablesPass> {
5168
void runOnOperation() override {
5269
mlir::ModuleOp moduleOp = getOperation();
5370

5471
// Collect all source executable ops.
55-
SmallVector<IREE::HAL::ExecutableOp, 8> sourceExecutableOps =
56-
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
72+
auto sourceExecutableOps = gatherExecutablesForSPIRVCodegen(moduleOp);
5773
if (sourceExecutableOps.size() <= 1)
5874
return;
5975

60-
// Retain only non-external source executables. Linking right now happens as
61-
// placing spirv.module ops into the same hal.executable.variant ops.
62-
// External source executables won't have any spirv.modules inside.
63-
int retainSize = 0;
64-
for (int i = 0, e = sourceExecutableOps.size(); i < e; ++i) {
65-
IREE::HAL::ExecutableOp executable = sourceExecutableOps[i];
66-
if (llvm::none_of(executable.getOps<IREE::HAL::ExecutableVariantOp>(),
67-
[](auto op) { return op.getObjects().has_value(); })) {
68-
sourceExecutableOps[retainSize++] = executable;
69-
}
70-
}
71-
sourceExecutableOps.resize(retainSize);
72-
7376
// Note that at runtime, for a particular executable, only one variant of it
7477
// will be loaded. So, all variants of an executable are expected to provide
7578
// the exact same set of entry points; this way we can guarantee no matter
@@ -213,4 +216,6 @@ struct SPIRVLinkExecutablesPass final
213216
};
214217

215218
} // namespace
219+
220+
} // namespace IREE::HAL
216221
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,21 @@ gatherExecutableTargets(ArrayRef<IREE::HAL::ExecutableOp> executableOps) {
2828
return result;
2929
}
3030

31+
SmallVector<IREE::HAL::ExecutableOp>
32+
gatherExecutablesForTarget(mlir::ModuleOp moduleOp, StringRef targetName) {
33+
SmallVector<IREE::HAL::ExecutableOp> result;
34+
for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
35+
if (llvm::any_of(executableOp.getOps<IREE::HAL::ExecutableVariantOp>(),
36+
[&](IREE::HAL::ExecutableVariantOp variantOp) {
37+
return variantOp.getTarget().getBackend().getValue() ==
38+
targetName;
39+
})) {
40+
result.push_back(executableOp);
41+
}
42+
}
43+
return result;
44+
}
45+
3146
// Renames |op| within |moduleOp| with a new name that is unique within both
3247
// |moduleOp| and |optionalSymbolTable| (if one is provided).
3348
static void
@@ -221,8 +236,9 @@ LogicalResult linkExecutablesInto(
221236
// TODO(benvanik): allow for grouping when multi-versioning is supported?
222237
// We could, for example, link all aarch64 variants together and then
223238
// use function multi-versioning to let LLVM insert runtime switches.
224-
if (variantOp.getTarget() != linkedTargetOp.getTarget())
239+
if (variantOp.getTarget() != linkedTargetOp.getTarget()) {
225240
continue;
241+
}
226242

227243
// Add any required object files to the set we will link in the target.
228244
if (auto objectsAttr = variantOp.getObjectsAttr()) {

compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ namespace mlir::iree_compiler {
1616
SetVector<IREE::HAL::ExecutableTargetAttr>
1717
gatherExecutableTargets(ArrayRef<IREE::HAL::ExecutableOp> executableOps);
1818

19-
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version.
20-
// Only difference is one has the symbol map that we don't even need.
19+
// Returns a set of executables that contain one or more variants for the given
20+
// target backend name.
21+
SmallVector<IREE::HAL::ExecutableOp>
22+
gatherExecutablesForTarget(mlir::ModuleOp moduleOp, StringRef targetName);
2123

2224
static inline bool allowRenamingPrivateSymbols(Operation *op) {
2325
return SymbolTable::getSymbolVisibility(op) ==
@@ -32,6 +34,9 @@ static inline bool allowRenamingPrivateSymbols(Operation *op) {
3234
//
3335
// Fails if a public symbol in |sourceModuleOp| conflicts with another public
3436
// symbol tracked in |targetSymbolMap|.
37+
//
38+
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version.
39+
// Only difference is one has the symbol map that we don't even need.
3540
LogicalResult
3641
mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp,
3742
DenseMap<StringRef, Operation *> &targetSymbolMap,

compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ struct VMVXLinkExecutablesPass
2424
auto moduleOp = getOperation();
2525
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
2626

27-
auto sourceExecutableOps =
28-
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
27+
auto sourceExecutableOps = gatherExecutablesForTarget(moduleOp, "vmvx");
2928
if (sourceExecutableOps.size() <= 1)
3029
return;
3130

@@ -44,15 +43,16 @@ struct VMVXLinkExecutablesPass
4443

4544
// Gather all unique executable targets - we may have multiple.
4645
auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps);
47-
for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) {
46+
for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) {
4847
// Add our VMVX hal.executable.variant with an empty module.
4948
std::string linkedVariantName =
5049
executableTargetAttrs.size() == 1
51-
? attr.getSymbolNameFragment()
52-
: llvm::formatv("{0}_{1}", attr.getSymbolNameFragment(), index);
50+
? targetAttr.getSymbolNameFragment()
51+
: llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(),
52+
index);
5353
auto linkedTargetOp =
5454
executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
55-
moduleOp.getLoc(), linkedVariantName, attr);
55+
moduleOp.getLoc(), linkedVariantName, targetAttr);
5656
auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
5757
auto linkedModuleOp = targetBuilder.create<ModuleOp>(moduleOp.getLoc());
5858

0 commit comments

Comments
 (0)