Skip to content

Commit

Permalink
Fixing executable linking when other targets are present. (#19035)
Browse files Browse the repository at this point in the history
The existing linking code was all kinds of wrong when multiple
executables with disjoint entry points were present. Linking needs to be
reworked in general but this incremental change ensures that targets
only link executables that contain variants that use them. There are
definitely still corner cases that don't work.

This is a small step towards heterogeneous devices. The
following example now compiles and runs:
```mlir
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", native_vector_size = 4 : index}>
#executable_target_embedded_elf_x86_64_whatever = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", native_vector_size = 16 : index}>
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">

util.global private @device_a = #hal.device.target<"local", {ordinal = 0 : index}, [
  #executable_target_embedded_elf_x86_64_,
  #executable_target_embedded_elf_x86_64_whatever
]> : !hal.device
util.global private @device_b = #hal.device.target<"local", {ordinal = 1 : index}, [
  #executable_target_vmvx_bytecode_fb
]> : !hal.device

func.func public @mutli_device_mul_add(
  // Input argument is resident on device_a (tooling defaults to the first device).
  %input_a: tensor<4xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>}
) -> (
  // Output result is expected to be on device_a (though not required).
  tensor<4xf32> {iree.abi.affinity = #hal.device.affinity<@device_a>}
) {
  // Compute on device_a (input is there).
  %constant_a = arith.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
  %transient_a = arith.mulf %input_a, %constant_a : tensor<4xf32>
  // Transfer the result from device_a -> device_b.
  %transient_b = flow.tensor.transfer %transient_a : tensor<4xf32> to #hal.device.affinity<@device_b>
  // Compute on device_b.
  %constant_b = arith.constant dense<[4.0, 5.0, 6.0, 7.0]> : tensor<4xf32>
  %result_b = arith.mulf %transient_b, %constant_b : tensor<4xf32>
  // Transfer the result from device_b -> device_a.
  %result_a = flow.tensor.transfer %result_b : tensor<4xf32> to #hal.device.affinity<@device_a>
  // More compute on device_a - should produce into the result buffer.
  %result_a2 = arith.addf %result_a, %constant_a : tensor<4xf32>
  // Return the result on device_a (as required by ABI attr).
  func.return %result_a2 : tensor<4xf32>
}
```

```sh
$ iree-compile --iree-execution-model=async-external iree-run-module-multi.mlir -o module.vmfb
$ iree-run-module \
    --module=module.vmfb --function=mutli_device_mul_add --input=4xf32=10,11,12,13 \
    --device=local-task --device=local-task --task_topology_group_count=1
```

(testing this in-tree is hard right now due to ergonomics issues - this
is all experimental anyway)
  • Loading branch information
benvanik authored Nov 6, 2024
1 parent 8158a8c commit 44adc3a
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ struct LLVMCPULinkExecutablesPass
auto moduleOp = getOperation();
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());

auto sourceExecutableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
auto sourceExecutableOps = gatherExecutablesForTarget(moduleOp, target);
if (sourceExecutableOps.size() <= 1)
return;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ struct LLVMGPULinkExecutablesPass
auto moduleOp = getOperation();
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());

auto sourceExecutableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
auto sourceExecutableOps = gatherExecutablesForTarget(moduleOp, target);
if (sourceExecutableOps.size() <= 1)
return;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
// RUN: iree-opt --iree-llvmgpu-link-executables --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-TARGET
// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="cuda"},iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-MULTI

#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb">

// Expect a single executable with both exports and correct ordinals.
// CHECK: hal.executable private @link_executables_linked
// CHECK: hal.executable.variant public @rocm_hsaco_fb
// CHECK: hal.executable.export public @export0 ordinal(0)
// CHECK: hal.executable.export public @export1 ordinal(1)
// CHECK-TARGET: hal.executable private @link_executables_linked
// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb
// CHECK-TARGET: hal.executable.export public @export0 ordinal(0)
// CHECK-TARGET: hal.executable.export public @export1 ordinal(1)

// Expect one LLVM module with all globals and functions.
// Note that shared memory is duplicated but dynamic shared memory is not.
// CHECK: builtin.module
// CHECK-NEXT: llvm.mlir.global external @__dynamic_shared_memory__
// CHECK-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>>
// CHECK-NEXT: llvm.func @export0
// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
// CHECK-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
// CHECK: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>>
// CHECK-NEXT: llvm.func @export1
// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
// CHECK-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3>
// CHECK-TARGET: builtin.module
// CHECK-TARGET-NEXT: llvm.mlir.global external @__dynamic_shared_memory__
// CHECK-TARGET-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>>
// CHECK-TARGET-NEXT: llvm.func @export0
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
// CHECK-TARGET: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>>
// CHECK-TARGET-NEXT: llvm.func @export1
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
// CHECK-TARGET-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3>

hal.executable private @executable0 {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) {
Expand Down Expand Up @@ -65,15 +64,6 @@ hal.executable private @executable1 {
#executable_target_cuda = #hal.executable.target<"cuda", "cuda-nvptx-fb">
#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb">

// Expect a single executable with multiple variants when not specifying target.
// CHECK: hal.executable private @link_executables_linked
// CHECK: hal.executable.variant public @cuda_nvptx_fb_0
// CHECK: hal.executable.export public @export0 ordinal(0)
// CHECK: hal.executable.export public @export1 ordinal(1)
// CHECK: hal.executable.variant public @rocm_hsaco_fb_1
// CHECK: hal.executable.export public @export0 ordinal(0)
// CHECK: hal.executable.export public @export1 ordinal(1)

// Expect only one target to be linked when specified.
// CHECK-TARGET: hal.executable private @link_executables_linked
// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb_1
Expand All @@ -88,7 +78,7 @@ hal.executable private @executable1 {

// Multiple applications of the pass per target should not conflict.
// CHECK-MULTI: hal.executable private @link_executables_linked_0
// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb_1
// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb
// CHECK-MULTI: hal.executable.export public @export0 ordinal(0)
// CHECK-MULTI: hal.executable.export public @export1 ordinal(1)
// CHECK-MULTI: hal.executable private @link_executables_linked
Expand Down
41 changes: 23 additions & 18 deletions compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ namespace mlir::iree_compiler {
#include "iree/compiler/Codegen/SPIRV/Passes.h.inc"

namespace IREE::HAL {

// Compares two ExecutableTargetAttr according to the alphabetical order of used
// SPIR-V features.
//
// Note that this is a very specific ordering per the needs of this pass--we
// guarantee that input ExecutableTargetAttr only differ w.r.t. their used
// SPIR-V features, and we want a deterministic order when mutating the IR.
bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) {
static bool operator<(const ExecutableTargetAttr &a,
const ExecutableTargetAttr &b) {
auto aFeatures = a.getConfiguration().getAs<ArrayAttr>("iree.spirv.features");
auto bFeatures = b.getConfiguration().getAs<ArrayAttr>("iree.spirv.features");
for (unsigned i = 0; i < std::min(aFeatures.size(), bFeatures.size()); ++i) {
Expand All @@ -40,36 +42,37 @@ bool operator<(const ExecutableTargetAttr &a, const ExecutableTargetAttr &b) {
}
return aFeatures.size() < bFeatures.size();
}
} // namespace IREE::HAL

namespace {

using IREE::HAL::ExecutableTargetAttr;
// Collects the executables in |moduleOp| that are candidates for SPIR-V
// linking. An executable is kept when at least one of its variants uses SPIR-V
// codegen and carries no objects. Variants that reference objects do not count
// as we only support full replacement of the modules and not yet linking.
static SmallVector<IREE::HAL::ExecutableOp>
gatherExecutablesForSPIRVCodegen(mlir::ModuleOp moduleOp) {
  // A variant qualifies only if it is SPIR-V codegen'd and has no objects.
  auto isLinkableVariant = [](IREE::HAL::ExecutableVariantOp variantOp) {
    if (!usesSPIRVCodeGen(variantOp))
      return false;
    return !variantOp.getObjects().has_value();
  };

  SmallVector<IREE::HAL::ExecutableOp> linkableOps;
  for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
    auto variantOps = executableOp.getOps<IREE::HAL::ExecutableVariantOp>();
    if (!llvm::any_of(variantOps, isLinkableVariant))
      continue;
    linkableOps.push_back(executableOp);
  }
  return linkableOps;
}

struct SPIRVLinkExecutablesPass final
: impl::SPIRVLinkExecutablesPassBase<SPIRVLinkExecutablesPass> {
void runOnOperation() override {
mlir::ModuleOp moduleOp = getOperation();

// Collect all source executable ops.
SmallVector<IREE::HAL::ExecutableOp, 8> sourceExecutableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
auto sourceExecutableOps = gatherExecutablesForSPIRVCodegen(moduleOp);
if (sourceExecutableOps.size() <= 1)
return;

// Retain only non-external source executables. Linking right now happens as
// placing spirv.module ops into the same hal.executable.variant ops.
// External source executables won't have any spirv.modules inside.
int retainSize = 0;
for (int i = 0, e = sourceExecutableOps.size(); i < e; ++i) {
IREE::HAL::ExecutableOp executable = sourceExecutableOps[i];
if (llvm::none_of(executable.getOps<IREE::HAL::ExecutableVariantOp>(),
[](auto op) { return op.getObjects().has_value(); })) {
sourceExecutableOps[retainSize++] = executable;
}
}
sourceExecutableOps.resize(retainSize);

// Note that at runtime, for a particular executable, only one variant of it
// will be loaded. So, all variants of an executable are expected to provide
// the exact same set of entry points; this way we can guarantee no matter
Expand Down Expand Up @@ -213,4 +216,6 @@ struct SPIRVLinkExecutablesPass final
};

} // namespace

} // namespace IREE::HAL
} // namespace mlir::iree_compiler
18 changes: 17 additions & 1 deletion compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ gatherExecutableTargets(ArrayRef<IREE::HAL::ExecutableOp> executableOps) {
return result;
}

// Collects every hal.executable in |moduleOp| that contains at least one
// variant whose target backend name equals |targetName|.
SmallVector<IREE::HAL::ExecutableOp>
gatherExecutablesForTarget(mlir::ModuleOp moduleOp, StringRef targetName) {
  // True when the variant's target backend is the one requested.
  auto targetsBackend = [&](IREE::HAL::ExecutableVariantOp variantOp) {
    return variantOp.getTarget().getBackend().getValue() == targetName;
  };

  SmallVector<IREE::HAL::ExecutableOp> matchingOps;
  for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
    auto variantOps = executableOp.getOps<IREE::HAL::ExecutableVariantOp>();
    if (!llvm::any_of(variantOps, targetsBackend))
      continue;
    matchingOps.push_back(executableOp);
  }
  return matchingOps;
}

// Renames |op| within |moduleOp| with a new name that is unique within both
// |moduleOp| and |optionalSymbolTable| (if one is provided).
static void
Expand Down Expand Up @@ -221,8 +236,9 @@ LogicalResult linkExecutablesInto(
// TODO(benvanik): allow for grouping when multi-versioning is supported?
// We could, for example, link all aarch64 variants together and then
// use function multi-versioning to let LLVM insert runtime switches.
if (variantOp.getTarget() != linkedTargetOp.getTarget())
if (variantOp.getTarget() != linkedTargetOp.getTarget()) {
continue;
}

// Add any required object files to the set we will link in the target.
if (auto objectsAttr = variantOp.getObjectsAttr()) {
Expand Down
9 changes: 7 additions & 2 deletions compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ namespace mlir::iree_compiler {
SetVector<IREE::HAL::ExecutableTargetAttr>
gatherExecutableTargets(ArrayRef<IREE::HAL::ExecutableOp> executableOps);

// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version.
// Only difference is one has the symbol map that we don't even need.
// Returns a set of executables that contain one or more variants for the given
// target backend name.
SmallVector<IREE::HAL::ExecutableOp>
gatherExecutablesForTarget(mlir::ModuleOp moduleOp, StringRef targetName);

static inline bool allowRenamingPrivateSymbols(Operation *op) {
return SymbolTable::getSymbolVisibility(op) ==
Expand All @@ -32,6 +34,9 @@ static inline bool allowRenamingPrivateSymbols(Operation *op) {
//
// Fails if a public symbol in |sourceModuleOp| conflicts with another public
// symbol tracked in |targetSymbolMap|.
//
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version.
// Only difference is one has the symbol map that we don't even need.
LogicalResult
mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp,
DenseMap<StringRef, Operation *> &targetSymbolMap,
Expand Down
12 changes: 6 additions & 6 deletions compiler/src/iree/compiler/Codegen/VMVX/VMVXLinkExecutables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ struct VMVXLinkExecutablesPass
auto moduleOp = getOperation();
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());

auto sourceExecutableOps =
llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
auto sourceExecutableOps = gatherExecutablesForTarget(moduleOp, "vmvx");
if (sourceExecutableOps.size() <= 1)
return;

Expand All @@ -44,15 +43,16 @@ struct VMVXLinkExecutablesPass

// Gather all unique executable targets - we may have multiple.
auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps);
for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) {
for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) {
// Add our VMVX hal.executable.variant with an empty module.
std::string linkedVariantName =
executableTargetAttrs.size() == 1
? attr.getSymbolNameFragment()
: llvm::formatv("{0}_{1}", attr.getSymbolNameFragment(), index);
? targetAttr.getSymbolNameFragment()
: llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(),
index);
auto linkedTargetOp =
executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
moduleOp.getLoc(), linkedVariantName, attr);
moduleOp.getLoc(), linkedVariantName, targetAttr);
auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
auto linkedModuleOp = targetBuilder.create<ModuleOp>(moduleOp.getLoc());

Expand Down

0 comments on commit 44adc3a

Please sign in to comment.