Skip to content

Commit

Permalink
[compiler][stream] Avoid circular dependency between partitions in ex…
Browse files Browse the repository at this point in the history
…ecution scheduling

Multi-device programs may produce circular dependency between
partitions, which is an invalid partitioning.
One such example is provided as a test case.

To fix this here we track partitions that transitively depend (hazards) on a partition.
Not just per operation.
  • Loading branch information
sogartar committed Aug 14, 2024
1 parent 3901e62 commit be1af16
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
Block *block) {
PartitionSet partitionSet;

struct OpInfo {
// Which partitions the op is contained within.
llvm::BitVector membership;
// Which partitions transitively depend on this operation.
llvm::BitVector hazards;
};
DenseMap<Operation *, OpInfo> opInfos;

struct PartitionBuilder {
unsigned ordinal;
// Affinity of the partition.
Expand All @@ -52,24 +60,76 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
SetVector<Operation *> ops;
// Ops that were cloned and are known not to have their values escape.
DenseSet<Operation *> clonedOps;
void insert(Operation *op) {
// Which partitions transitively depend on this partition.
llvm::BitVector hazards;
void insert(Operation *op, OpInfo &opInfo) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
: affinityOp.getAffinityAttr();
}
opInfo.membership.set(ordinal);
if (opInfo.hazards.size() > ordinal)
opInfo.hazards.reset(ordinal);
ops.insert(op);
hazards |= opInfo.hazards;
}
};
SmallVector<std::unique_ptr<PartitionBuilder>> builders;
llvm::BitVector usableBuilders;

struct OpInfo {
// Which partitions the op is contained within.
llvm::BitVector membership;
// Which partitions transitively depend on this operation.
llvm::BitVector hazards;
auto willCreateCircularDependencyBetweenPartitions =
[&](unsigned sourceOrdinal, unsigned targetOrdinal) -> bool {
// Returns:
// If we are to make partition with ordinal targetOrdinal to
// depend on partition with ordinal sourceOrdinal,
// will this create a circular dependency.
if (sourceOrdinal == targetOrdinal)
return false;
return builders[sourceOrdinal]->hazards.size() > targetOrdinal &&
builders[sourceOrdinal]->hazards[targetOrdinal];
};

auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo,
unsigned partitionOrdinal) {
auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
if (!streamableOp)
return false;
IREE::Stream::AffinityAttr affinityAttr;
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op))
affinityAttr = affinityOp.getAffinityAttr();
if (!IREE::Stream::AffinityAttr::canExecuteTogether(
affinityAttr, builders[partitionOrdinal]->affinity))
return false;

bool preferCloneToConsumers = streamableOp.preferCloneToConsumers();
llvm::BitVector *opHazards = nullptr;
llvm::BitVector opHazardsInCandidatePartition;
if (preferCloneToConsumers) {
// If we are cloning we care only about users that are a part of the
// candidate partition.
opHazards = &opHazardsInCandidatePartition;
for (auto user : op.getUsers()) {
if (builders[partitionOrdinal]->ops.contains(user))
opHazardsInCandidatePartition |= opInfos[user].hazards;
}
} else
opHazards = &opInfo.hazards;

for (auto opHazardOrdinal : opHazards->set_bits()) {
if (partitionOrdinal < opHazardOrdinal)
// Reject partition ordering that would require partition sorting.
// TODO: It is probably more optimal to reorder the partitions after
// their formation based on their dependency graph instead of rejecting
// here. Since this is considered not a good partitioning algorithm
// and will probably get removed, we leave it like that.
return false;
// Check for formation of circular dependency between partitions.
if (willCreateCircularDependencyBetweenPartitions(opHazardOrdinal,
partitionOrdinal))
return false;
}
return true;
};
DenseMap<Operation *, OpInfo> opInfos;

auto asmState = getRootAsmState(block);

Expand Down Expand Up @@ -107,11 +167,6 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
opInfo.hazards.reserve(builders.size() + 1);
opInfo.hazards.resize(builders.size(), /*t=*/false);

IREE::Stream::AffinityAttr affinityAttr;
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
affinityAttr = affinityOp.getAffinityAttr();
}

LLVM_DEBUG({
llvm::dbgs() << "====\nPartitioning op:\n";
op.print(llvm::dbgs(), *asmState);
Expand Down Expand Up @@ -149,8 +204,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,

// Prune candidates that do not have a compatible affinity.
for (auto ordinal : candidates.set_bits()) {
if (!IREE::Stream::AffinityAttr::canExecuteTogether(
affinityAttr, builders[ordinal]->affinity)) {
if (!canAddOpToPartition(op, opInfo, ordinal)) {
LLVM_DEBUG(llvm::dbgs()
<< "Candidate partition " << ordinal << " incompatible\n");
candidates.reset(ordinal);
Expand Down Expand Up @@ -181,19 +235,15 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
<< consumerOrdinal << "\n");
auto &consumerBuilder = builders[consumerOrdinal];
consumerBuilder->insert(&op);
consumerBuilder->insert(&op, opInfo);
consumerBuilder->clonedOps.insert(&op);
opInfo.membership.set(consumerOrdinal);
opInfo.hazards.reset(consumerOrdinal);
}
} else {
int consumerOrdinal = consumers.find_last();
LLVM_DEBUG(llvm::dbgs() << "Moving into consumer partition "
<< consumerOrdinal << "\n");
auto &consumerBuilder = builders[consumerOrdinal];
consumerBuilder->insert(&op);
opInfo.membership.set(consumerOrdinal);
opInfo.hazards.reset(consumerOrdinal);
consumerBuilder->insert(&op, opInfo);
}
LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n");
continue;
Expand All @@ -204,13 +254,13 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
if (firstCandidateOrdinal != -1) {
LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
<< firstCandidateOrdinal << " (continue)\n");
builders[firstCandidateOrdinal]->insert(&op);
opInfo.membership.set(firstCandidateOrdinal);
opInfo.hazards.reset(firstCandidateOrdinal);
builders[firstCandidateOrdinal]->insert(&op, opInfo);
continue;
}

// Mark the op as having hazards against all other partitions.
// Why are we that conservative?
// Why we don't take the hazards for the users?
if (!builders.empty()) {
opInfo.hazards.set(0, builders.size() - 1);
}
Expand All @@ -219,8 +269,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
auto builder = std::make_unique<PartitionBuilder>();
builder->ordinal = builders.size();
builder->affinity = affinityAttr;
builder->insert(&op);
builder->insert(&op, opInfo);
LLVM_DEBUG(llvm::dbgs()
<< "Created partition " << builder->ordinal << "\n");
builders.push_back(std::move(builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,112 @@ util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource<e

// -----

// Partitioning with device assignment.
// Ops on different devices are interleaved and interdependent according to
// arg0 arg1
// ↓ ↓
// device_0 0 1 device_1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// device_0 2 3 device_1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// device_0 4 5 device_1
//
// This will result in partition assignment
// arg0 arg1
// ↓ ↓
// P0 0 1 P1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// P2 2 3 P1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// P2 4 5 P3
//
// CHECK-LABEL: @partitionWithInterdependentInterleavedDeviceAffinites
util.func public @partitionWithInterdependentInterleavedDeviceAffinites(
// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>,
%arg0: !stream.resource<external>,
// CHECK-SAME: %[[ARG1:.+]]: !stream.resource<external>)
%arg1: !stream.resource<external>) -> (!stream.resource<external>, !stream.resource<external>) {
// CHECK: %[[C1:.+]] = arith.constant 1 : index
%c1 = arith.constant 1 : index

%0 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e00[%c1](
%arg0[%c1 to %c1 for %c1]
) : (!stream.resource<external>{%c1}) -> !stream.resource<transient>{%c1}
%1 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e01[%c1](
%arg1[%c1 to %c1 for %c1]
) : (!stream.resource<external>{%c1}) -> !stream.resource<transient>{%c1}

%2 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e10[%c1](
%0[%c1 to %c1 for %c1], %1[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<transient>{%c1}
%3 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e11[%c1](
%0[%c1 to %c1 for %c1], %1[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<transient>{%c1}

%4 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e20[%c1](
%2[%c1 to %c1 for %c1], %3[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<external>{%c1}
%5 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e21[%c1](
%2[%c1 to %c1 for %c1], %3[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<external>{%c1}

// Partition 0
// CHECK: %[[RESULTS:.+]], %[[RESULT_TIMEPOINT:.+]] = stream.async.execute on(#hal.device.affinity<@device_0>)
// CHECK-SAME: with(%[[ARG0]] as %{{.+}}: !stream.resource<external>{%[[C1]]})
// CHECK: stream.async.dispatch @ex::@e00

// Partition 1
// CHECK: %[[RESULTS_0:.+]]:2, %[[RESULT_TIMEPOINT_1:.+]] = stream.async.execute on(#hal.device.affinity<@device_1>)
// CHECK-SAME: await(%[[RESULT_TIMEPOINT]]) => with(
// CHECK-SAME: %[[ARG1]] as %{{.+}}: !stream.resource<external>{%[[C1]]},
// CHECK-SAME: %[[RESULTS]] as %{{.+}}: !stream.resource<transient>{%[[C1]]})
// CHECK-DAG: stream.async.dispatch @ex::@e01
// CHECK-DAG: stream.async.dispatch @ex::@e11

// CHECK: %[[T0:.+]] = stream.timepoint.join max(%[[RESULT_TIMEPOINT]], %[[RESULT_TIMEPOINT_1]]) => !stream.timepoint

// Partition 2
// CHECK: %[[RESULTS_2:.+]]:2, %[[RESULT_TIMEPOINT_3:.+]] = stream.async.execute on(#hal.device.affinity<@device_0>)
// CHECK-SAME: await(%[[T0]]) => with(
// CHECK-SAME: %[[RESULTS]] as %{{[A-Za-z0-9_]+}}: !stream.resource<transient>{%[[C1]]},
// CHECK-SAME: %[[RESULTS_0]]#0 as %{{.+}}: !stream.resource<transient>{%[[C1]]},
// CHECK-SAME: %[[RESULTS_0]]#1 as %{{.+}}: !stream.resource<transient>{%[[C1]]})
// CHECK-DAG: stream.async.dispatch @ex::@e10
// CHECK-DAG: stream.async.dispatch @ex::@e20

// CHECK: %[[T1:.+]] = stream.timepoint.join max(%[[RESULT_TIMEPOINT_3]], %[[RESULT_TIMEPOINT_1]]) => !stream.timepoint

// Partition 3
// CHECK: %[[RESULTS_4:.+]], %[[RESULT_TIMEPOINT_5:.+]] = stream.async.execute on(#hal.device.affinity<@device_1>)
// CHECK-SAME: await(%[[T1]]) => with(
// CHECK-SAME: %[[RESULTS_2]]#0 as %{{.+}}: !stream.resource<transient>{%[[C1]]},
// CHECK-SAME: %[[RESULTS_0]]#1 as %{{.+}}: !stream.resource<transient>{%[[C1]]})
// CHECK: stream.async.dispatch @ex::@e21

// CHECK: %[[R4:.+]] = stream.timepoint.await %[[RESULT_TIMEPOINT_5]] => %[[RESULTS_4]] : !stream.resource<external>{%[[C1]]}
// CHECK: %[[R21:.+]] = stream.timepoint.await %[[RESULT_TIMEPOINT_3]] => %[[RESULTS_2]]#1 : !stream.resource<external>{%[[C1]]}
// CHECK: util.return %[[R21]], %[[R4]]
util.return %4, %5 : !stream.resource<external>, !stream.resource<external>
}

// -----

// Tests that ops in multiple blocks are partitioned independently and that
// timepoints are chained between the partitions. Note that the dispatches
// happen in-place on the splat and we expect the execution regions to be tied.
Expand Down

0 comments on commit be1af16

Please sign in to comment.