Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler][stream] Avoid circular dependency between partitions in execution scheduling #18217

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
Block *block) {
PartitionSet partitionSet;

// Per-operation bookkeeping tracked while ops are assigned to partitions.
// Both bitvectors are indexed by partition ordinal.
struct OpInfo {
// Which partitions the op is contained within.
// The bit for a partition is set when the op is inserted into it.
llvm::BitVector membership;
// Which partitions transitively depend on this operation.
// The bit for a partition is cleared when the op is inserted into it.
llvm::BitVector hazards;
};
DenseMap<Operation *, OpInfo> opInfos;

struct PartitionBuilder {
unsigned ordinal;
// Affinity of the partition.
Expand All @@ -52,24 +60,76 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
SetVector<Operation *> ops;
// Ops that were cloned and are known not to have their values escape.
DenseSet<Operation *> clonedOps;
void insert(Operation *op) {
// Which partitions transitively depend on this partition.
llvm::BitVector hazards;
void insert(Operation *op, OpInfo &opInfo) {
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr())
: affinityOp.getAffinityAttr();
}
opInfo.membership.set(ordinal);
if (opInfo.hazards.size() > ordinal)
opInfo.hazards.reset(ordinal);
ops.insert(op);
hazards |= opInfo.hazards;
}
};
SmallVector<std::unique_ptr<PartitionBuilder>> builders;
llvm::BitVector usableBuilders;

struct OpInfo {
// Which partitions the op is contained within.
llvm::BitVector membership;
// Which partitions transitively depend on this operation.
llvm::BitVector hazards;
auto willCreateCircularDependencyBetweenPartitions =
[&](unsigned sourceOrdinal, unsigned targetOrdinal) -> bool {
// Returns true if making the partition with ordinal `targetOrdinal`
// depend on the partition with ordinal `sourceOrdinal` would create a
// circular dependency between partitions.
// Concretely: true iff the bit for `targetOrdinal` is set in the hazards
// of the partition `sourceOrdinal`. A partition compared against itself
// never counts as a cycle.
// NOTE(review): partition hazards are documented as "partitions that
// transitively depend on this partition", so a set bit here means the
// target already depends on the source; confirm the dependency direction
// described above matches how callers pass the two ordinals.
if (sourceOrdinal == targetOrdinal)
return false;
// The size check guards the index: the hazards bitvector may hold fewer
// bits than `targetOrdinal + 1`, in which case no hazard is recorded.
return builders[sourceOrdinal]->hazards.size() > targetOrdinal &&
builders[sourceOrdinal]->hazards[targetOrdinal];
};

auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo,
unsigned partitionOrdinal) {
auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op);
if (!streamableOp)
return false;
IREE::Stream::AffinityAttr affinityAttr;
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op))
affinityAttr = affinityOp.getAffinityAttr();
if (!IREE::Stream::AffinityAttr::canExecuteTogether(
affinityAttr, builders[partitionOrdinal]->affinity))
return false;

bool preferCloneToConsumers = streamableOp.preferCloneToConsumers();
llvm::BitVector *opHazards = nullptr;
llvm::BitVector opHazardsInCandidatePartition;
if (preferCloneToConsumers) {
// If we are cloning we care only about users that are a part of the
// candidate partition.
opHazards = &opHazardsInCandidatePartition;
for (auto user : op.getUsers()) {
if (builders[partitionOrdinal]->ops.contains(user))
opHazardsInCandidatePartition |= opInfos[user].hazards;
}
Copy link
Contributor Author

@sogartar sogartar Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, here we would need to walk further down the chain of users whenever a user is itself cloned into the partition. This would be useful when we have a block of cloneable ops. If left as is, I don't think it will produce invalid partitioning; it will only be less efficient.

} else
opHazards = &opInfo.hazards;

for (auto opHazardOrdinal : opHazards->set_bits()) {
if (partitionOrdinal < opHazardOrdinal)
// Reject partition ordering that would require partition sorting.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use {} around multi-line if statements

// TODO: It is probably more optimal to reorder the partitions after
// their formation based on their dependency graph instead of rejecting
// here. Since this is not considered a good partitioning algorithm
// and will probably be removed, we leave it as is.
return false;
// Check for formation of circular dependency between partitions.
if (willCreateCircularDependencyBetweenPartitions(opHazardOrdinal,
partitionOrdinal))
return false;
}
return true;
};
DenseMap<Operation *, OpInfo> opInfos;

auto asmState = getRootAsmState(block);

Expand Down Expand Up @@ -107,11 +167,6 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
opInfo.hazards.reserve(builders.size() + 1);
opInfo.hazards.resize(builders.size(), /*t=*/false);

IREE::Stream::AffinityAttr affinityAttr;
if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) {
affinityAttr = affinityOp.getAffinityAttr();
}

LLVM_DEBUG({
llvm::dbgs() << "====\nPartitioning op:\n";
op.print(llvm::dbgs(), *asmState);
Expand Down Expand Up @@ -149,8 +204,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,

// Prune candidates that do not have a compatible affinity.
for (auto ordinal : candidates.set_bits()) {
if (!IREE::Stream::AffinityAttr::canExecuteTogether(
affinityAttr, builders[ordinal]->affinity)) {
if (!canAddOpToPartition(op, opInfo, ordinal)) {
LLVM_DEBUG(llvm::dbgs()
<< "Candidate partition " << ordinal << " incompatible\n");
candidates.reset(ordinal);
Expand Down Expand Up @@ -181,19 +235,15 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition "
<< consumerOrdinal << "\n");
auto &consumerBuilder = builders[consumerOrdinal];
consumerBuilder->insert(&op);
consumerBuilder->insert(&op, opInfo);
consumerBuilder->clonedOps.insert(&op);
opInfo.membership.set(consumerOrdinal);
opInfo.hazards.reset(consumerOrdinal);
}
} else {
int consumerOrdinal = consumers.find_last();
LLVM_DEBUG(llvm::dbgs() << "Moving into consumer partition "
<< consumerOrdinal << "\n");
auto &consumerBuilder = builders[consumerOrdinal];
consumerBuilder->insert(&op);
opInfo.membership.set(consumerOrdinal);
opInfo.hazards.reset(consumerOrdinal);
consumerBuilder->insert(&op, opInfo);
}
LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n");
continue;
Expand All @@ -204,13 +254,13 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
if (firstCandidateOrdinal != -1) {
LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition "
<< firstCandidateOrdinal << " (continue)\n");
builders[firstCandidateOrdinal]->insert(&op);
opInfo.membership.set(firstCandidateOrdinal);
opInfo.hazards.reset(firstCandidateOrdinal);
builders[firstCandidateOrdinal]->insert(&op, opInfo);
continue;
}

// Mark the op as having hazards against all other partitions.
// Why are we so conservative here?
// Why don't we take the hazards from the users instead?
if (!builders.empty()) {
opInfo.hazards.set(0, builders.size() - 1);
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@benvanik, could you shed some light on what your reasoning was here when you wrote the algorithm?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to be safe than incorrect, especially with our current minimal test coverage. It's not always safe to reorder things - if anything we are unlikely to be conservative enough here - for example, if there's a stream.resource.load of a resource or a global we can't move anything that may affect that resource or global. This partitioning was designed to be conservative because debugging such issues is really difficult.

Expand All @@ -219,8 +269,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true);
auto builder = std::make_unique<PartitionBuilder>();
builder->ordinal = builders.size();
builder->affinity = affinityAttr;
builder->insert(&op);
builder->insert(&op, opInfo);
LLVM_DEBUG(llvm::dbgs()
<< "Created partition " << builder->ordinal << "\n");
builders.push_back(std::move(builder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,112 @@ util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource<e

// -----

// Partitioning with device assignment.
// Ops on different devices are interleaved and interdependent according to
// the following graph:
// arg0 arg1
// ↓ ↓
// device_0 0 1 device_1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// device_0 2 3 device_1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// device_0 4 5 device_1
//
// This will result in the following partition assignment:
// arg0 arg1
// ↓ ↓
// P0 0 1 P1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// P2 2 3 P1
// |\ /|
// | \/ |
// | /\ |
// |/ \|
// ↓ ↓
// P2 4 5 P3
//
// CHECK-LABEL: @partitionWithInterdependentInterleavedDeviceAffinites
util.func public @partitionWithInterdependentInterleavedDeviceAffinites(
// CHECK-SAME: (%[[ARG0:.+]]: !stream.resource<external>,
%arg0: !stream.resource<external>,
// CHECK-SAME: %[[ARG1:.+]]: !stream.resource<external>)
%arg1: !stream.resource<external>) -> (!stream.resource<external>, !stream.resource<external>) {
// CHECK: %[[C1:.+]] = arith.constant 1 : index
%c1 = arith.constant 1 : index

%0 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e00[%c1](
%arg0[%c1 to %c1 for %c1]
) : (!stream.resource<external>{%c1}) -> !stream.resource<transient>{%c1}
%1 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e01[%c1](
%arg1[%c1 to %c1 for %c1]
) : (!stream.resource<external>{%c1}) -> !stream.resource<transient>{%c1}

%2 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e10[%c1](
%0[%c1 to %c1 for %c1], %1[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<transient>{%c1}
%3 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e11[%c1](
%0[%c1 to %c1 for %c1], %1[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<transient>{%c1}

%4 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e20[%c1](
%2[%c1 to %c1 for %c1], %3[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<external>{%c1}
%5 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e21[%c1](
%2[%c1 to %c1 for %c1], %3[%c1 to %c1 for %c1]
) : (!stream.resource<transient>{%c1}, !stream.resource<transient>{%c1}) -> !stream.resource<external>{%c1}

// Partition 0
// CHECK: %[[RESULTS:.+]], %[[RESULT_TIMEPOINT:.+]] = stream.async.execute on(#hal.device.affinity<@device_0>)
// CHECK-SAME: with(%[[ARG0]] as %{{.+}}: !stream.resource<external>{%[[C1]]})
// CHECK: stream.async.dispatch @ex::@e00

// Partition 1
// CHECK: %[[RESULTS_0:.+]]:2, %[[RESULT_TIMEPOINT_1:.+]] = stream.async.execute on(#hal.device.affinity<@device_1>)
// CHECK-SAME: await(%[[RESULT_TIMEPOINT]]) => with(
// CHECK-SAME: %[[ARG1]] as %{{.+}}: !stream.resource<external>{%[[C1]]},
// CHECK-SAME: %[[RESULTS]] as %{{.+}}: !stream.resource<transient>{%[[C1]]})
// CHECK-DAG: stream.async.dispatch @ex::@e01
// CHECK-DAG: stream.async.dispatch @ex::@e11

// CHECK: %[[T0:.+]] = stream.timepoint.join max(%[[RESULT_TIMEPOINT]], %[[RESULT_TIMEPOINT_1]]) => !stream.timepoint

// Partition 2
// CHECK: %[[RESULTS_2:.+]]:2, %[[RESULT_TIMEPOINT_3:.+]] = stream.async.execute on(#hal.device.affinity<@device_0>)
// CHECK-SAME: await(%[[T0]]) => with(
// CHECK-SAME: %[[RESULTS]] as %{{[A-Za-z0-9_]+}}: !stream.resource<transient>{%[[C1]]},
// CHECK-SAME: %[[RESULTS_0]]#0 as %{{.+}}: !stream.resource<transient>{%[[C1]]},
// CHECK-SAME: %[[RESULTS_0]]#1 as %{{.+}}: !stream.resource<transient>{%[[C1]]})
// CHECK-DAG: stream.async.dispatch @ex::@e10
// CHECK-DAG: stream.async.dispatch @ex::@e20

// CHECK: %[[T1:.+]] = stream.timepoint.join max(%[[RESULT_TIMEPOINT_3]], %[[RESULT_TIMEPOINT_1]]) => !stream.timepoint

// Partition 3
// CHECK: %[[RESULTS_4:.+]], %[[RESULT_TIMEPOINT_5:.+]] = stream.async.execute on(#hal.device.affinity<@device_1>)
// CHECK-SAME: await(%[[T1]]) => with(
// CHECK-SAME: %[[RESULTS_2]]#0 as %{{.+}}: !stream.resource<transient>{%[[C1]]},
// CHECK-SAME: %[[RESULTS_0]]#1 as %{{.+}}: !stream.resource<transient>{%[[C1]]})
// CHECK: stream.async.dispatch @ex::@e21

// CHECK: %[[R4:.+]] = stream.timepoint.await %[[RESULT_TIMEPOINT_5]] => %[[RESULTS_4]] : !stream.resource<external>{%[[C1]]}
// CHECK: %[[R21:.+]] = stream.timepoint.await %[[RESULT_TIMEPOINT_3]] => %[[RESULTS_2]]#1 : !stream.resource<external>{%[[C1]]}
// CHECK: util.return %[[R21]], %[[R4]]
util.return %4, %5 : !stream.resource<external>, !stream.resource<external>
}

// -----

// Tests that ops in multiple blocks are partitioned independently and that
// timepoints are chained between the partitions. Note that the dispatches
// happen in-place on the splat and we expect the execution regions to be tied.
Expand Down
Loading