diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp index a4fff96c3016..9cb10137874c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -44,6 +44,14 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, Block *block) { PartitionSet partitionSet; + struct OpInfo { + // Which partitions the op is contained within. + llvm::BitVector membership; + // Which partitions transitively depend on this operation. + llvm::BitVector hazards; + }; + DenseMap<Operation *, OpInfo> opInfos; + struct PartitionBuilder { unsigned ordinal; // Affinity of the partition. @@ -52,24 +60,81 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, SetVector<Operation *> ops; // Ops that were cloned and are known not to have their values escape. DenseSet<Operation *> clonedOps; - void insert(Operation *op) { + // Which partitions transitively depend on this partition. + llvm::BitVector hazards; + void insert(Operation *op, OpInfo &opInfo) { if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) { affinity = affinity ? affinity.joinAND(affinityOp.getAffinityAttr()) : affinityOp.getAffinityAttr(); } + opInfo.membership.set(ordinal); + if (opInfo.hazards.size() > ordinal) + opInfo.hazards.reset(ordinal); ops.insert(op); + hazards |= opInfo.hazards; } }; SmallVector<std::unique_ptr<PartitionBuilder>> builders; llvm::BitVector usableBuilders; - struct OpInfo { - // Which partitions the op is contained within. - llvm::BitVector membership; - // Which partitions transitively depend on this operation. 
- llvm::BitVector hazards; + auto willCreateCircularDependencyBetweenPartitions = + [&](unsigned sourceOrdinal, unsigned targetOrdinal) -> bool { + // Returns: + // If we are to make partition with ordinal targetOrdinal to + // depend on partition with ordinal sourceOrdinal, + // will this create a circular dependency. + if (sourceOrdinal == targetOrdinal) + return false; + return builders[sourceOrdinal]->hazards.size() > targetOrdinal && + builders[sourceOrdinal]->hazards[targetOrdinal]; + }; + + auto canAddOpToPartition = [&](Operation &op, OpInfo &opInfo, + unsigned partitionOrdinal) { + auto streamableOp = dyn_cast<IREE::Stream::StreamableOpInterface>(op); + if (!streamableOp) + return false; + IREE::Stream::AffinityAttr affinityAttr; + if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) + affinityAttr = affinityOp.getAffinityAttr(); + if (!IREE::Stream::AffinityAttr::canExecuteTogether( + affinityAttr, builders[partitionOrdinal]->affinity)) + return false; + + bool preferCloneToConsumers = streamableOp.preferCloneToConsumers(); + llvm::BitVector *opHazards = nullptr; + llvm::BitVector opHazardsInCandidatePartition; + if (preferCloneToConsumers) { + // If we are cloning we care only about users that are a part of the + // candidate partition. + // Here we would need to walk further down the users if a user is also + // cloned into the partition. This will be useful if we have a block of + // cloneable ops. If left like that, other than the inefficiency, + // it should not produce invalid partitioning. + opHazards = &opHazardsInCandidatePartition; + for (auto user : op.getUsers()) { + if (builders[partitionOrdinal]->ops.contains(user)) + opHazardsInCandidatePartition |= opInfos[user].hazards; + } + } else + opHazards = &opInfo.hazards; + + for (auto opHazardOrdinal : opHazards->set_bits()) { + if (partitionOrdinal < opHazardOrdinal) { + // Reject partition ordering that would require partition sorting. 
+ // TODO: It is probably more optimal to reorder the partitions after + // their formation based on their dependency graph instead of rejecting + // here. Since this is considered not a good partitioning algorithm + // and will probably get removed, we leave it like that. + return false; + } + // Check for formation of circular dependency between partitions. + if (willCreateCircularDependencyBetweenPartitions(opHazardOrdinal, + partitionOrdinal)) + return false; + } + return true; }; - DenseMap<Operation *, OpInfo> opInfos; auto asmState = getRootAsmState(block); @@ -107,11 +172,6 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, opInfo.hazards.reserve(builders.size() + 1); opInfo.hazards.resize(builders.size(), /*t=*/false); - IREE::Stream::AffinityAttr affinityAttr; - if (auto affinityOp = dyn_cast<IREE::Stream::AffinityOpInterface>(op)) { - affinityAttr = affinityOp.getAffinityAttr(); - } - LLVM_DEBUG({ llvm::dbgs() << "====\nPartitioning op:\n"; op.print(llvm::dbgs(), *asmState); @@ -149,8 +209,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, // Prune candidates that do not have a compatible affinity. 
for (auto ordinal : candidates.set_bits()) { - if (!IREE::Stream::AffinityAttr::canExecuteTogether( - affinityAttr, builders[ordinal]->affinity)) { + if (!canAddOpToPartition(op, opInfo, ordinal)) { LLVM_DEBUG(llvm::dbgs() << "Candidate partition " << ordinal << " incompatible\n"); candidates.reset(ordinal); @@ -181,19 +240,15 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, LLVM_DEBUG(llvm::dbgs() << "Cloning into consumer partition " << consumerOrdinal << "\n"); auto &consumerBuilder = builders[consumerOrdinal]; - consumerBuilder->insert(&op); + consumerBuilder->insert(&op, opInfo); consumerBuilder->clonedOps.insert(&op); - opInfo.membership.set(consumerOrdinal); - opInfo.hazards.reset(consumerOrdinal); } } else { int consumerOrdinal = consumers.find_last(); LLVM_DEBUG(llvm::dbgs() << "Moving into consumer partition " << consumerOrdinal << "\n"); auto &consumerBuilder = builders[consumerOrdinal]; - consumerBuilder->insert(&op); - opInfo.membership.set(consumerOrdinal); - opInfo.hazards.reset(consumerOrdinal); + consumerBuilder->insert(&op, opInfo); } LLVM_DEBUG(llvm::dbgs() << "Handled streamable (continue)\n"); continue; @@ -204,13 +259,18 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, if (firstCandidateOrdinal != -1) { LLVM_DEBUG(llvm::dbgs() << "Moving to first candidate partition " << firstCandidateOrdinal << " (continue)\n"); - builders[firstCandidateOrdinal]->insert(&op); - opInfo.membership.set(firstCandidateOrdinal); - opInfo.hazards.reset(firstCandidateOrdinal); + builders[firstCandidateOrdinal]->insert(&op, opInfo); continue; } // Mark the op as having hazards against all other partitions. + // It is better to be safe than incorrect, especially with our current + // minimal test coverage. 
It's not always safe to reorder things - if + anything we are unlikely to be conservative enough here - for example, + if there's a stream.resource.load of a resource or a global we can't + move anything that may affect that resource or global. This partitioning + was designed to be conservative because debugging such issues is really + difficult. if (!builders.empty()) { opInfo.hazards.set(0, builders.size() - 1); } @@ -219,8 +279,7 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, opInfo.membership.resize(opInfo.membership.size() + 1, /*t=*/true); auto builder = std::make_unique<PartitionBuilder>(); builder->ordinal = builders.size(); - builder->affinity = affinityAttr; - builder->insert(&op); + builder->insert(&op, opInfo); LLVM_DEBUG(llvm::dbgs() << "Created partition " << builder->ordinal << "\n"); builders.push_back(std::move(builder)); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir index dcdc586184be..b5ecc53df47a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -139,6 +139,112 @@ util.func public @partitioningWithConcurrentAffinities(%arg0: !stream.resource, + %arg0: !stream.resource, // CHECK-SAME: %[[ARG1:.+]]: !stream.resource) + %arg1: !stream.resource) -> (!stream.resource, !stream.resource) { + // CHECK: %[[C1:.+]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + + %0 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e00[%c1]( + %arg0[%c1 to %c1 for %c1] + ) : (!stream.resource{%c1}) -> !stream.resource{%c1} + %1 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e01[%c1]( + %arg1[%c1 to %c1 for %c1] + ) : (!stream.resource{%c1}) -> !stream.resource{%c1} + + %2 = stream.async.dispatch 
on(#hal.device.affinity<@device_0>) @ex::@e10[%c1]( + %0[%c1 to %c1 for %c1], %1[%c1 to %c1 for %c1] + ) : (!stream.resource{%c1}, !stream.resource{%c1}) -> !stream.resource{%c1} + %3 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e11[%c1]( + %0[%c1 to %c1 for %c1], %1[%c1 to %c1 for %c1] + ) : (!stream.resource{%c1}, !stream.resource{%c1}) -> !stream.resource{%c1} + + %4 = stream.async.dispatch on(#hal.device.affinity<@device_0>) @ex::@e20[%c1]( + %2[%c1 to %c1 for %c1], %3[%c1 to %c1 for %c1] + ) : (!stream.resource{%c1}, !stream.resource{%c1}) -> !stream.resource{%c1} + %5 = stream.async.dispatch on(#hal.device.affinity<@device_1>) @ex::@e21[%c1]( + %2[%c1 to %c1 for %c1], %3[%c1 to %c1 for %c1] + ) : (!stream.resource{%c1}, !stream.resource{%c1}) -> !stream.resource{%c1} + + // Partition 0 + // CHECK: %[[RESULTS:.+]], %[[RESULT_TIMEPOINT:.+]] = stream.async.execute on(#hal.device.affinity<@device_0>) + // CHECK-SAME: with(%[[ARG0]] as %{{.+}}: !stream.resource{%[[C1]]}) + // CHECK: stream.async.dispatch @ex::@e00 + + // Partition 1 + // CHECK: %[[RESULTS_0:.+]]:2, %[[RESULT_TIMEPOINT_1:.+]] = stream.async.execute on(#hal.device.affinity<@device_1>) + // CHECK-SAME: await(%[[RESULT_TIMEPOINT]]) => with( + // CHECK-SAME: %[[ARG1]] as %{{.+}}: !stream.resource{%[[C1]]}, + // CHECK-SAME: %[[RESULTS]] as %{{.+}}: !stream.resource{%[[C1]]}) + // CHECK-DAG: stream.async.dispatch @ex::@e01 + // CHECK-DAG: stream.async.dispatch @ex::@e11 + + // CHECK: %[[T0:.+]] = stream.timepoint.join max(%[[RESULT_TIMEPOINT]], %[[RESULT_TIMEPOINT_1]]) => !stream.timepoint + + // Partition 2 + // CHECK: %[[RESULTS_2:.+]]:2, %[[RESULT_TIMEPOINT_3:.+]] = stream.async.execute on(#hal.device.affinity<@device_0>) + // CHECK-SAME: await(%[[T0]]) => with( + // CHECK-SAME: %[[RESULTS]] as %{{[A-Za-z0-9_]+}}: !stream.resource{%[[C1]]}, + // CHECK-SAME: %[[RESULTS_0]]#0 as %{{.+}}: !stream.resource{%[[C1]]}, + // CHECK-SAME: %[[RESULTS_0]]#1 as %{{.+}}: 
!stream.resource{%[[C1]]}) + // CHECK-DAG: stream.async.dispatch @ex::@e10 + // CHECK-DAG: stream.async.dispatch @ex::@e20 + + // CHECK: %[[T1:.+]] = stream.timepoint.join max(%[[RESULT_TIMEPOINT_3]], %[[RESULT_TIMEPOINT_1]]) => !stream.timepoint + + // Partition 3 + // CHECK: %[[RESULTS_4:.+]], %[[RESULT_TIMEPOINT_5:.+]] = stream.async.execute on(#hal.device.affinity<@device_1>) + // CHECK-SAME: await(%[[T1]]) => with( + // CHECK-SAME: %[[RESULTS_2]]#0 as %{{.+}}: !stream.resource{%[[C1]]}, + // CHECK-SAME: %[[RESULTS_0]]#1 as %{{.+}}: !stream.resource{%[[C1]]}) + // CHECK: stream.async.dispatch @ex::@e21 + + // CHECK: %[[R4:.+]] = stream.timepoint.await %[[RESULT_TIMEPOINT_5]] => %[[RESULTS_4]] : !stream.resource{%[[C1]]} + // CHECK: %[[R21:.+]] = stream.timepoint.await %[[RESULT_TIMEPOINT_3]] => %[[RESULTS_2]]#1 : !stream.resource{%[[C1]]} + // CHECK: util.return %[[R21]], %[[R4]] + util.return %4, %5 : !stream.resource, !stream.resource +} + +// ----- + // Tests that ops in multiple blocks are partitioned independently and that // timepoints are chained between the partitions. Note that the dispatches // happen in-place on the splat and we expect the execution regions to be tied.