Skip to content

Commit ee13f74

Browse files
akuegelGoogle-ML-Automation
authored andcommitted
Use GetInPlaceInputOutputPairs from AliasInfo instead of HloDataflowAnalysis.
PiperOrigin-RevId: 837718510
1 parent 7d27a09 commit ee13f74

16 files changed

+222
-161
lines changed

xla/backends/cpu/transforms/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ cc_library(
140140
"//xla:xla_proto_cc",
141141
"//xla/backends/cpu:xnn_support",
142142
"//xla/backends/cpu/runtime/xnnpack:xnn_interop",
143+
"//xla/hlo/analysis:alias_info",
143144
"//xla/hlo/ir:hlo",
144145
"//xla/service:call_graph",
145146
"//xla/service:instruction_fusion",
@@ -158,6 +159,7 @@ xla_cc_test(
158159
"//xla:xla_data_proto_cc",
159160
"//xla:xla_proto_cc",
160161
"//xla/backends/cpu:xnn_support",
162+
"//xla/hlo/analysis:alias_info",
161163
"//xla/hlo/ir:hlo",
162164
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
163165
"//xla/hlo/utils:hlo_matchers",

xla/backends/cpu/transforms/xnn_graph_fusion.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "absl/container/flat_hash_set.h"
2222
#include "absl/strings/string_view.h"
23+
#include "xla/hlo/analysis/alias_info.h"
2324
#include "xla/hlo/ir/hlo_instruction.h"
2425
#include "xla/service/instruction_fusion.h"
2526

@@ -28,7 +29,8 @@ namespace cpu {
2829

2930
class XnnGraphFusion : public InstructionFusion {
3031
public:
31-
XnnGraphFusion() : InstructionFusion(XnnGraphFusion::IsExpensive) {}
32+
explicit XnnGraphFusion(const AliasInfo* alias_info)
33+
: InstructionFusion(XnnGraphFusion::IsExpensive, alias_info) {}
3234
~XnnGraphFusion() override = default;
3335

3436
private:

xla/backends/cpu/transforms/xnn_graph_fusion_test.cc

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <gmock/gmock.h>
2222
#include <gtest/gtest.h>
2323
#include "xla/backends/cpu/xnn_support.h"
24+
#include "xla/hlo/analysis/alias_info.h"
2425
#include "xla/hlo/ir/hlo_casting_utils.h"
2526
#include "xla/hlo/ir/hlo_instruction.h"
2627
#include "xla/hlo/ir/hlo_instructions.h"
@@ -37,7 +38,10 @@ namespace op = xla::testing::opcode_matchers;
3738
namespace xla::cpu {
3839
namespace {
3940

40-
using XnnGraphFusionTest = HloHardwareIndependentTestBase;
41+
class XnnGraphFusionTest : public HloHardwareIndependentTestBase {
42+
protected:
43+
AliasInfo alias_info_;
44+
};
4145

4246
TEST_F(XnnGraphFusionTest, BasicFusion) {
4347
std::string hlo_string = R"(
@@ -54,7 +58,8 @@ ENTRY entry {
5458

5559
TF_ASSERT_OK_AND_ASSIGN(auto module,
5660
ParseAndReturnVerifiedModule(hlo_string));
57-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
61+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
62+
XnnGraphFusion(&alias_info_).Run(module.get()));
5863
ASSERT_TRUE(changed);
5964
EXPECT_THAT(module.get()->entry_computation()->root_instruction(),
6065
op::Fusion());
@@ -82,7 +87,8 @@ ENTRY entry {
8287

8388
TF_ASSERT_OK_AND_ASSIGN(auto module,
8489
ParseAndReturnVerifiedModule(hlo_string));
85-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
90+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
91+
XnnGraphFusion(&alias_info_).Run(module.get()));
8692
ASSERT_FALSE(changed);
8793
}
8894

@@ -101,7 +107,8 @@ ENTRY entry {
101107

102108
TF_ASSERT_OK_AND_ASSIGN(auto module,
103109
ParseAndReturnVerifiedModule(hlo_string));
104-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
110+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
111+
XnnGraphFusion(&alias_info_).Run(module.get()));
105112
ASSERT_FALSE(changed);
106113
}
107114

@@ -128,7 +135,8 @@ ENTRY entry {
128135
ParseAndReturnVerifiedModule(hlo_string));
129136
SetFusionMode(module.get(),
130137
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
131-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
138+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
139+
XnnGraphFusion(&alias_info_).Run(module.get()));
132140
ASSERT_TRUE(changed);
133141
EXPECT_THAT(module.get()->entry_computation()->root_instruction(),
134142
op::Fusion());
@@ -158,7 +166,8 @@ ENTRY entry {
158166
ParseAndReturnVerifiedModule(hlo_string));
159167
SetFusionMode(module.get(),
160168
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
161-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
169+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
170+
XnnGraphFusion(&alias_info_).Run(module.get()));
162171
ASSERT_FALSE(changed);
163172
}
164173

@@ -179,7 +188,8 @@ ENTRY entry {
179188
ParseAndReturnVerifiedModule(hlo_string));
180189
SetFusionMode(module.get(),
181190
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
182-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
191+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
192+
XnnGraphFusion(&alias_info_).Run(module.get()));
183193
ASSERT_FALSE(changed);
184194
}
185195

@@ -199,7 +209,8 @@ ENTRY entry {
199209
ParseAndReturnVerifiedModule(hlo_string));
200210
SetFusionMode(module.get(),
201211
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
202-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
212+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
213+
XnnGraphFusion(&alias_info_).Run(module.get()));
203214
ASSERT_FALSE(changed);
204215
}
205216

@@ -215,7 +226,8 @@ ENTRY entry {
215226

216227
TF_ASSERT_OK_AND_ASSIGN(auto module,
217228
ParseAndReturnVerifiedModule(hlo_string));
218-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
229+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
230+
XnnGraphFusion(&alias_info_).Run(module.get()));
219231
ASSERT_FALSE(changed);
220232
}
221233

@@ -240,7 +252,8 @@ ENTRY main {
240252
ParseAndReturnVerifiedModule(hlo_string));
241253
SetFusionMode(module.get(),
242254
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
243-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
255+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
256+
XnnGraphFusion(&alias_info_).Run(module.get()));
244257
ASSERT_TRUE(changed);
245258
EXPECT_THAT(module.get()->entry_computation()->root_instruction(),
246259
op::Fusion());
@@ -274,7 +287,8 @@ ENTRY main {
274287
ParseAndReturnVerifiedModule(hlo_string));
275288
SetFusionMode(module.get(),
276289
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
277-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
290+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
291+
XnnGraphFusion(&alias_info_).Run(module.get()));
278292
ASSERT_FALSE(changed);
279293
}
280294

@@ -299,7 +313,8 @@ ENTRY main {
299313
ParseAndReturnVerifiedModule(hlo_string));
300314
SetFusionMode(module.get(),
301315
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
302-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
316+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
317+
XnnGraphFusion(&alias_info_).Run(module.get()));
303318
ASSERT_FALSE(changed);
304319
}
305320

@@ -325,7 +340,8 @@ ENTRY main {
325340
ParseAndReturnVerifiedModule(hlo_string));
326341
SetFusionMode(module.get(),
327342
DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
328-
TF_ASSERT_OK_AND_ASSIGN(bool changed, XnnGraphFusion().Run(module.get()));
343+
TF_ASSERT_OK_AND_ASSIGN(bool changed,
344+
XnnGraphFusion(&alias_info_).Run(module.get()));
329345
ASSERT_FALSE(changed);
330346
}
331347

xla/service/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1976,7 +1976,6 @@ cc_library(
19761976
"//xla:shape_util",
19771977
"//xla:util",
19781978
"//xla/hlo/analysis:alias_info",
1979-
"//xla/hlo/analysis:hlo_dataflow_analysis",
19801979
"//xla/hlo/analysis:hlo_operand_index",
19811980
"//xla/hlo/analysis:hlo_reachability",
19821981
"//xla/hlo/ir:hlo",
@@ -2007,6 +2006,7 @@ xla_cc_test(
20072006
":instruction_fusion",
20082007
"//xla:shape_util",
20092008
"//xla:xla_data_proto_cc",
2009+
"//xla/hlo/analysis:alias_info",
20102010
"//xla/hlo/ir:hlo",
20112011
"//xla/hlo/parser:hlo_parser",
20122012
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
@@ -2093,6 +2093,7 @@ xla_cc_test(
20932093
deps = [
20942094
":fusion_node_indexing_evaluation",
20952095
":instruction_fusion",
2096+
"//xla/hlo/analysis:alias_info",
20962097
"//xla/hlo/ir:hlo",
20972098
"//xla/hlo/parser:hlo_parser",
20982099
"//xla/hlo/testlib:hlo_hardware_independent_test_base",

xla/service/cpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,7 @@ xla_cc_test(
10751075
"//xla:literal_util",
10761076
"//xla:shape_util",
10771077
"//xla:xla_data_proto_cc",
1078+
"//xla/hlo/analysis:alias_info",
10781079
"//xla/hlo/ir:hlo",
10791080
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
10801081
"//xla/hlo/utils:hlo_matchers",

xla/service/cpu/cpu_compiler.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,14 +1004,16 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
10041004
TF_RETURN_IF_ERROR(lib_pipeline.Run(module).status());
10051005
}
10061006

1007+
AliasInfo alias_info;
10071008
if (debug_options.xla_cpu_experimental_xnn_graph_fusion_mode() !=
10081009
DebugOptions::XNN_GRAPH_FUSION_MODE_DISABLED) {
1009-
pipeline.AddPass<XnnGraphFusion>();
1010+
pipeline.AddPass<XnnGraphFusion>(&alias_info);
10101011
}
10111012

10121013
bool use_multi_output_fusion =
10131014
options::UseMultiOutputFusion(module->config());
10141015
pipeline.AddPass<CpuInstructionFusion>(
1016+
&alias_info,
10151017
/*may_duplicate=*/!use_multi_output_fusion);
10161018

10171019
if (is_fusion_emitters) {
@@ -1020,7 +1022,6 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
10201022
pipeline.AddPass<FusionWrapper>(use_experimental_loop_fusion);
10211023
}
10221024

1023-
AliasInfo alias_info;
10241025
if (use_multi_output_fusion) {
10251026
pipeline.AddPass<CpuMultiOutputFusion>(&alias_info);
10261027
pipeline.AddPass<TupleSimplifier>();

xla/service/cpu/cpu_instruction_fusion.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ namespace cpu {
3131

3232
class CpuInstructionFusion : public InstructionFusion {
3333
public:
34-
explicit CpuInstructionFusion(bool may_duplicate = true)
35-
: InstructionFusion(CpuInstructionFusion::IsExpensive, may_duplicate) {}
34+
explicit CpuInstructionFusion(const AliasInfo* alias_info,
35+
bool may_duplicate = true)
36+
: InstructionFusion(CpuInstructionFusion::IsExpensive, alias_info,
37+
may_duplicate) {}
3638
~CpuInstructionFusion() override = default;
3739

3840
// Returns the threshold for a constant to be considered a large constant.

0 commit comments

Comments
 (0)