@@ -21,6 +21,7 @@ limitations under the License.
2121#include < gmock/gmock.h>
2222#include < gtest/gtest.h>
2323#include " xla/backends/cpu/xnn_support.h"
24+ #include " xla/hlo/analysis/alias_info.h"
2425#include " xla/hlo/ir/hlo_casting_utils.h"
2526#include " xla/hlo/ir/hlo_instruction.h"
2627#include " xla/hlo/ir/hlo_instructions.h"
@@ -37,7 +38,10 @@ namespace op = xla::testing::opcode_matchers;
3738namespace xla ::cpu {
3839namespace {
3940
40- using XnnGraphFusionTest = HloHardwareIndependentTestBase;
41+ class XnnGraphFusionTest : public HloHardwareIndependentTestBase {
42+ protected:
43+ AliasInfo alias_info_;
44+ };
4145
4246TEST_F (XnnGraphFusionTest, BasicFusion) {
4347 std::string hlo_string = R"(
@@ -54,7 +58,8 @@ ENTRY entry {
5458
5559 TF_ASSERT_OK_AND_ASSIGN (auto module ,
5660 ParseAndReturnVerifiedModule (hlo_string));
57- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
61+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
62+ XnnGraphFusion (&alias_info_).Run (module .get ()));
5863 ASSERT_TRUE (changed);
5964 EXPECT_THAT (module .get ()->entry_computation ()->root_instruction (),
6065 op::Fusion ());
@@ -82,7 +87,8 @@ ENTRY entry {
8287
8388 TF_ASSERT_OK_AND_ASSIGN (auto module ,
8489 ParseAndReturnVerifiedModule (hlo_string));
85- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
90+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
91+ XnnGraphFusion (&alias_info_).Run (module .get ()));
8692 ASSERT_FALSE (changed);
8793}
8894
@@ -101,7 +107,8 @@ ENTRY entry {
101107
102108 TF_ASSERT_OK_AND_ASSIGN (auto module ,
103109 ParseAndReturnVerifiedModule (hlo_string));
104- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
110+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
111+ XnnGraphFusion (&alias_info_).Run (module .get ()));
105112 ASSERT_FALSE (changed);
106113}
107114
@@ -128,7 +135,8 @@ ENTRY entry {
128135 ParseAndReturnVerifiedModule (hlo_string));
129136 SetFusionMode (module .get (),
130137 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
131- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
138+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
139+ XnnGraphFusion (&alias_info_).Run (module .get ()));
132140 ASSERT_TRUE (changed);
133141 EXPECT_THAT (module .get ()->entry_computation ()->root_instruction (),
134142 op::Fusion ());
@@ -158,7 +166,8 @@ ENTRY entry {
158166 ParseAndReturnVerifiedModule (hlo_string));
159167 SetFusionMode (module .get (),
160168 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
161- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
169+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
170+ XnnGraphFusion (&alias_info_).Run (module .get ()));
162171 ASSERT_FALSE (changed);
163172}
164173
@@ -179,7 +188,8 @@ ENTRY entry {
179188 ParseAndReturnVerifiedModule (hlo_string));
180189 SetFusionMode (module .get (),
181190 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
182- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
191+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
192+ XnnGraphFusion (&alias_info_).Run (module .get ()));
183193 ASSERT_FALSE (changed);
184194}
185195
@@ -199,7 +209,8 @@ ENTRY entry {
199209 ParseAndReturnVerifiedModule (hlo_string));
200210 SetFusionMode (module .get (),
201211 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
202- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
212+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
213+ XnnGraphFusion (&alias_info_).Run (module .get ()));
203214 ASSERT_FALSE (changed);
204215}
205216
@@ -215,7 +226,8 @@ ENTRY entry {
215226
216227 TF_ASSERT_OK_AND_ASSIGN (auto module ,
217228 ParseAndReturnVerifiedModule (hlo_string));
218- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
229+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
230+ XnnGraphFusion (&alias_info_).Run (module .get ()));
219231 ASSERT_FALSE (changed);
220232}
221233
@@ -240,7 +252,8 @@ ENTRY main {
240252 ParseAndReturnVerifiedModule (hlo_string));
241253 SetFusionMode (module .get (),
242254 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
243- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
255+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
256+ XnnGraphFusion (&alias_info_).Run (module .get ()));
244257 ASSERT_TRUE (changed);
245258 EXPECT_THAT (module .get ()->entry_computation ()->root_instruction (),
246259 op::Fusion ());
@@ -274,7 +287,8 @@ ENTRY main {
274287 ParseAndReturnVerifiedModule (hlo_string));
275288 SetFusionMode (module .get (),
276289 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
277- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
290+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
291+ XnnGraphFusion (&alias_info_).Run (module .get ()));
278292 ASSERT_FALSE (changed);
279293}
280294
@@ -299,7 +313,8 @@ ENTRY main {
299313 ParseAndReturnVerifiedModule (hlo_string));
300314 SetFusionMode (module .get (),
301315 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
302- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
316+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
317+ XnnGraphFusion (&alias_info_).Run (module .get ()));
303318 ASSERT_FALSE (changed);
304319}
305320
@@ -325,7 +340,8 @@ ENTRY main {
325340 ParseAndReturnVerifiedModule (hlo_string));
326341 SetFusionMode (module .get (),
327342 DebugOptions::XNN_GRAPH_FUSION_MODE_GREEDY_SLINKY);
328- TF_ASSERT_OK_AND_ASSIGN (bool changed, XnnGraphFusion ().Run (module .get ()));
343+ TF_ASSERT_OK_AND_ASSIGN (bool changed,
344+ XnnGraphFusion (&alias_info_).Run (module .get ()));
329345 ASSERT_FALSE (changed);
330346}
331347
0 commit comments