@@ -133,27 +133,35 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
     support_block_io = torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
 
     if block_ptr:
+        load_ops = f"""
+            %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            %store_val = tt.load %src_ptr {{boundaryCheck = array<i32: 0, 1>, padding = 1 : i32}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            """
         store_ops = f"""
-            %M_i64 = arith.constant {M} : i64
-            %N_i64 = arith.constant {N} : i64
-            %c1_i64 = arith.constant 1 : i64
-            %c0_i32 = arith.constant 0 : i32
-
-            %blk_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
-            tt.store %blk_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
+            %dst_ptr = tt.make_tensor_ptr %dst, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major", boundaryCheck = array<i32: 0, 1>}} : !tt.ptr<tensor<{M}x{N}x{ty}, #layout>>
             """
     else:
+        load_ops = f"""
+            %src_base = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %src_ptr = tt.addptr %src_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            %store_val = tt.load %src_ptr {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            """
         store_ops = f"""
-            %12 = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %13 = tt.addptr %12, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            tt.store %13, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_base = tt.splat %dst : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            %dst_ptr = tt.addptr %dst_base, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
+            tt.store %dst_ptr, %store_val {{ttig.block_io = "row_major"}} : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
             """
 
     ir = f"""
     #layout = {layout}
     module attributes {{{"ttig.support_sg_2d_block," if support_block_io else ""} "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = {threads_per_warp} : i32}} {{
         tt.func public @block_store(%src: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %dst: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
 
+            %M_i64 = arith.constant {M} : i64
+            %N_i64 = arith.constant {N} : i64
+            %c1_i64 = arith.constant 1 : i64
+            %c0_i32 = arith.constant 0 : i32
             %stride = arith.constant dense<{N}> : tensor<{M}x1xi32, #layout>
             %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>>
             %2 = tt.expand_dims %1 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #layout}}>> -> tensor<{M}x1xi32, #layout>
@@ -163,9 +171,7 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
             %6 = tt.broadcast %3 : tensor<{M}x1xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %7 = tt.broadcast %5 : tensor<1x{N}xi32, #layout> -> tensor<{M}x{N}xi32, #layout>
             %8 = arith.addi %6, %7 : tensor<{M}x{N}xi32, #layout>
-            %9 = tt.splat %src : !tt.ptr<{ty}> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
-            %10 = tt.addptr %9, %8 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>, tensor<{M}x{N}xi32, #layout>
-            %store_val = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #layout>
+            {load_ops}
 
             {store_ops}
 
@@ -191,3 +197,5 @@ def test_block_store(M, N, dtype_str, layout, block_ptr, device, tmp_path: pathl
 
     if support_block_io:
         assert 'spirv_Subgroup2DBlockStoreINTEL' in kernel.asm['llir'] or 'GenISA.LSC2DBlockWrite' in kernel.asm['llir']
+        if not block_ptr:
+            assert 'spirv_Subgroup2DBlockLoad' in kernel.asm['llir'] or 'GenISA.LSC2DBlockRead' in kernel.asm['llir']
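
For a concrete picture of what the new templates expand to, here is the tensor-of-pointers (`block_ptr=False`) load/store pair rendered with illustrative values M = 32, N = 32, ty = f16 (example values only; the doubled braces in the f-strings collapse to single braces in the emitted TTGIR):

    %src_base = tt.splat %src : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #layout>
    %src_ptr = tt.addptr %src_base, %8 : tensor<32x32x!tt.ptr<f16>, #layout>, tensor<32x32xi32, #layout>
    %store_val = tt.load %src_ptr {ttig.block_io = "row_major"} : tensor<32x32x!tt.ptr<f16>, #layout>
    %dst_base = tt.splat %dst : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #layout>
    %dst_ptr = tt.addptr %dst_base, %8 : tensor<32x32x!tt.ptr<f16>, #layout>, tensor<32x32xi32, #layout>
    tt.store %dst_ptr, %store_val {ttig.block_io = "row_major"} : tensor<32x32x!tt.ptr<f16>, #layout>

The `ttig.block_io = "row_major"` hint on the load is what lets the Intel backend lower it to a subgroup 2D block read, which is exactly what the new `spirv_Subgroup2DBlockLoad` / `GenISA.LSC2DBlockRead` assertion looks for.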
0 commit comments
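
The hunks above omit the harness between building `ir` and inspecting `kernel.asm`. As a minimal sketch of how generated-TTGIR tests like this are typically driven, assuming a single-workgroup launch and float16 data (the file name, grid, and tensor setup here are illustrative assumptions, not taken from this diff):

    import pathlib

    import torch
    import triton

    def run_generated_ir(ir: str, tmp_path: pathlib.Path, M: int, N: int):
        # Persist the templated TTGIR; triton.compile also accepts a path to a
        # pre-lowered IR file, not just a Python-level kernel.
        temp_file = tmp_path / "test_block_store.ttgir"
        temp_file.write_text(ir)
        kernel = triton.compile(str(temp_file))

        # One workgroup suffices: the IR indexes the full MxN tile itself.
        src = torch.randn((M, N), dtype=torch.float16, device='xpu')
        dst = torch.empty_like(src)
        kernel[(1, 1, 1)](src, dst)
        assert torch.equal(src, dst)
        return kernel

The round-trip equality check verifies the load/store pair is functionally correct, while the `kernel.asm['llir']` assertions verify that the block-IO fast path, not a fallback gather/scatter, was actually selected.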