Skip to content

Commit bd1c3bf

Browse files
jreifferstensorflower-gardener
authored andcommitted
Add a utility for finding HLO use chains
This is needed for proper epilogue indexing computations, e.g. in fusions like this: ``` reduce / \ broadcast log | | neg bitcast \ / ROOT ``` The current assumption in `ComputeEpilogueInputToOutputIndexing` that we can just take the first user is incorrect here - the reduce is both part of the side output's computation and the hero of the fusion. Fusions like this make absolutely no sense, but they exist. PiperOrigin-RevId: 629994503
1 parent 4fa10ac commit bd1c3bf

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

third_party/xla/xla/service/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5732,6 +5732,7 @@ xla_cc_test(
57325732
":hlo_traversal",
57335733
"//xla/hlo/ir:hlo",
57345734
"//xla/tests:hlo_test_base",
5735+
"@com_google_absl//absl/strings:string_view",
57355736
"@com_google_googletest//:gtest_main",
57365737
],
57375738
)

third_party/xla/xla/service/gpu/hlo_traversal.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ limitations under the License.
1414
==============================================================================*/
1515
#include "xla/service/gpu/hlo_traversal.h"
1616

17+
#include <algorithm>
1718
#include <functional>
1819
#include <iterator>
1920
#include <memory>
2021
#include <optional>
2122
#include <queue>
2223
#include <sstream>
2324
#include <string>
25+
#include <vector>
2426

2527
#include "absl/algorithm/container.h"
2628
#include "absl/container/flat_hash_set.h"
@@ -482,5 +484,29 @@ std::optional<const HloInstruction*> HloFindIf(
482484
return std::nullopt;
483485
}
484486

487+
std::vector<HloInstructionAdaptor> HloFindUseChain(HloInstructionAdaptor parent,
488+
HloInstructionAdaptor root) {
489+
absl::flat_hash_set<HloInstructionAdaptor> visited;
490+
std::vector<HloInstructionAdaptor> result;
491+
std::function<bool(HloInstructionAdaptor)> visit;
492+
visit = [&](HloInstructionAdaptor node) {
493+
if (node == root) return true;
494+
for (const auto& user : node.GetUsers()) {
495+
if (visited.insert(user).second && visit(user)) {
496+
result.push_back(user);
497+
return true;
498+
}
499+
}
500+
return false;
501+
};
502+
if (visit(parent)) {
503+
result.push_back(parent);
504+
std::reverse(result.begin(), result.end());
505+
} else {
506+
result.clear();
507+
}
508+
return result;
509+
}
510+
485511
} // namespace gpu
486512
} // namespace xla

third_party/xla/xla/service/gpu/hlo_traversal.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,11 @@ void FindFusionArguments(
186186
const HloFusionAdaptor& fusion,
187187
const std::function<void(HloInstructionAdaptor producer)>& visit);
188188

189+
// Find a use chain from `parent` to `root`. Empty if no chain exists.
190+
// `[parent]` if `parent` is `root`.
191+
std::vector<HloInstructionAdaptor> HloFindUseChain(HloInstructionAdaptor parent,
192+
HloInstructionAdaptor root);
193+
189194
} // namespace gpu
190195
} // namespace xla
191196

third_party/xla/xla/service/gpu/hlo_traversal_test.cc

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121

2222
#include <gmock/gmock.h>
2323
#include <gtest/gtest.h>
24+
#include "absl/strings/string_view.h"
2425
#include "xla/hlo/ir/hlo_instruction.h"
2526
#include "xla/hlo/ir/hlo_opcode.h"
2627
#include "xla/tests/hlo_test_base.h"
@@ -30,6 +31,7 @@ namespace gpu {
3031
namespace {
3132

3233
using ::testing::ElementsAre;
34+
using ::testing::IsEmpty;
3335

3436
MATCHER_P(InstructionAdaptorName, name, "") { return arg.name() == name; }
3537

@@ -288,7 +290,7 @@ const char kTwoFusions[] = R"(
288290
sum = f32[128] add(p1, p1)
289291
negate = f32[128] negate(sum)
290292
fusion.1 = f32[] fusion(p0, negate), kind=kLoop, calls=fused_computation_1
291-
fusion.2 = f32[] fusion(fusion.1, negate), kind=kLoop, calls=fused_computation_2
293+
fusion.2 = f32[] fusion(fusion.1, negate), kind=kLoop, calls=fused_computation_2
292294
ROOT difference = f32[] subtract(fusion.2, p0)
293295
})";
294296

@@ -497,6 +499,49 @@ TEST_F(HloTraversalTest, MakeInstructionsPostOrder_TwoMultiOutputFusions) {
497499
InstructionAdaptorName("reduce.2")));
498500
}
499501

502+
TEST_F(HloTraversalTest, HloFindUseChain) {
503+
auto module = ParseAndReturnVerifiedModule(R"(
504+
fusion {
505+
p0 = f32[] parameter(0)
506+
p1 = f32[] parameter(1)
507+
negate = f32[] negate(p0)
508+
log = f32[] log(p0)
509+
sum = f32[] add(p0, log)
510+
exp = f32[] exponential(p1)
511+
ROOT call = f32[] custom-call(negate, exp, sum), custom_call_target="it"
512+
}
513+
514+
ENTRY main {
515+
p0 = f32[] parameter(0)
516+
p1 = f32[] parameter(1)
517+
ROOT fusion = f32[] fusion(p0, p1), kind=kLoop, calls=fusion
518+
}
519+
)")
520+
.value();
521+
522+
auto* fusion_computation = module->GetComputationWithName("fusion");
523+
auto fusion = HloFusionAdaptor::ForComputation(fusion_computation);
524+
auto get = [&](absl::string_view name) {
525+
return HloInstructionAdaptor{
526+
*fusion_computation->GetInstructionWithName(name), fusion.get()};
527+
};
528+
auto p0 = get("p0");
529+
auto p1 = get("p1");
530+
auto log = get("log");
531+
auto sum = get("sum");
532+
auto negate = get("negate");
533+
auto exp = get("exp");
534+
auto call = get("call");
535+
536+
EXPECT_THAT(HloFindUseChain(p0, p0), ElementsAre(p0));
537+
EXPECT_THAT(HloFindUseChain(p0, p1), IsEmpty());
538+
EXPECT_THAT(HloFindUseChain(p0, call), ElementsAre(p0, negate, call));
539+
EXPECT_THAT(HloFindUseChain(p0, sum), ElementsAre(p0, log, sum));
540+
EXPECT_THAT(HloFindUseChain(p1, exp), ElementsAre(p1, exp));
541+
EXPECT_THAT(HloFindUseChain(negate, exp), IsEmpty());
542+
EXPECT_THAT(HloFindUseChain(call, p0), IsEmpty());
543+
}
544+
500545
} // namespace
501546
} // namespace gpu
502547
} // namespace xla

0 commit comments

Comments
 (0)