@@ -127,6 +127,10 @@ def parse_target(self, tgt_prop) -> dict:
127127        dev_prop ['has_subgroup_2d_block_io' ] =  tgt_prop .get ('has_subgroup_2d_block_io' , False )
128128        dev_prop ['has_bfloat16_conversions' ] =  tgt_prop .get ('has_bfloat16_conversions' , True )
129129
130+         if  self .device_arch  in  self .device_props :
131+             dev_prop .update (self .device_props [self .device_arch ])
132+             return  dev_prop 
133+ 
130134        return  dev_prop 
131135
132136    def  parse_options (self , opts ) ->  Any :
@@ -202,85 +206,27 @@ def get_split_barrier_scope(opt):
202206            split_barriers_scope  =  intel .SplitBarrierScope .Subgroup 
203207        return  split_barriers_scope 
204208
205-     @classmethod  
206-     def  create_pass_manager (cls , context , add_passes = []):
207-         pm  =  ir .pass_manager (context )
208-         pm .enable_debug ()
209-         for  p  in  add_passes :
210-             if  p  is  None :
211-                 continue 
212-             elif  isinstance (p , tuple ):
213-                 p [0 ](pm , * p [1 :])
214-             else :
215-                 p (pm )
216-         return  pm 
217- 
218-     @classmethod  
219-     def  get_ttir_passes (cls , opt ):
220-         return  [
221-             passes .common .add_inliner ,
222-             intel .passes .ttir .add_convert_tdesc_to_block_pointer ,
223-             passes .ttir .add_rewrite_tensor_descriptor_to_pointer ,
224-             passes .common .add_cse ,
225-             passes .common .add_licm ,
226-             intel .passes .ttir .add_remove_masks ,
227-             intel .passes .ttir .add_fuse_reshape ,
228-             passes .common .add_canonicalizer ,
229-             passes .ttir .add_combine ,
230-             passes .ttir .add_reorder_broadcast ,
231-             passes .common .add_cse ,
232-             passes .common .add_symbol_dce ,
233-             passes .ttir .add_loop_unroll ,
234-         ]
235- 
236209    @classmethod  
237210    @track  
238211    def  make_ttir (cls , mod , metadata , opt ):
239-         pm  =  cls .create_pass_manager (mod .context , cls .get_ttir_passes (opt ))
212+         pm  =  ir .pass_manager (mod .context )
213+         pm .enable_debug ()
214+         passes .common .add_inliner (pm )
215+         intel .passes .ttir .add_convert_tdesc_to_block_pointer (pm )
216+         passes .ttir .add_rewrite_tensor_descriptor_to_pointer (pm )
217+         passes .common .add_cse (pm )
218+         passes .common .add_licm (pm )
219+         intel .passes .ttir .add_remove_masks (pm )
220+         intel .passes .ttir .add_fuse_reshape (pm )
221+         passes .common .add_canonicalizer (pm )
222+         passes .ttir .add_combine (pm )
223+         passes .ttir .add_reorder_broadcast (pm )
224+         passes .common .add_cse (pm )
225+         passes .common .add_symbol_dce (pm )
226+         passes .ttir .add_loop_unroll (pm )
240227        pm .run (mod , 'make_ttir' )
241228        return  mod 
242229
243-     @classmethod  
244-     def  get_ttgir_passes (cls , opt ):
245-         # fmt: off 
246-         return  [
247-             (passes .ttir .add_convert_to_ttgpuir , "xpu" , opt .num_warps , opt .warp_size , opt .num_ctas ),
248-             # optimize TTGIR 
249-             intel .passes .ttgpuir .add_coalesce ,
250-             intel .passes .ttgpuir .add_remove_layout_conversions ,
251- 
252-             intel .passes .ttgpuir .add_accelerate_matmul ,
253-             intel .passes .ttgpuir .add_materialize_block_pointer ,
254-             intel .passes .ttgpuir .add_remove_layout_conversions ,
255-             intel .passes .ttgpuir .add_optimize_dot_operands ,
256-             (intel .passes .ttgpuir .add_pipeline , opt .num_stages , cls .get_split_barrier_scope (opt )),
257- 
258-             intel .passes .ttgpuir .add_reduce_variable_liveness  if  opt .reduce_variable_liveness  else  None ,
259- 
260-             passes .ttgpuir .add_fuse_nested_loops ,
261- 
262-             passes .common .add_canonicalizer ,
263-             passes .ttir .add_triton_licm ,
264-             passes .common .add_canonicalizer ,
265-             passes .ttgpuir .add_combine_tensor_select_and_if ,
266- 
267-             passes .ttgpuir .add_optimize_thread_locality ,
268-             (passes .ttgpuir .add_optimize_dot_operands , True ),
269-             passes .common .add_cse ,
270-             passes .ttgpuir .add_prefetch ,
271-             (passes .ttgpuir .add_optimize_dot_operands , True ),
272-             intel .passes .ttgpuir .add_remove_layout_conversions ,
273-             intel .passes .ttgpuir .add_reduce_data_duplication ,
274-             passes .ttgpuir .add_reorder_instructions ,
275-             passes .common .add_cse ,
276-             passes .common .add_symbol_dce ,
277-             passes .common .add_sccp ,
278-             passes .common .add_canonicalizer ,
279-             intel .passes .ttgpuir .add_optimize_reduction_locality  if  knobs .intel .opt_reduction_locality  else  None ,
280-             (intel .passes .arith .add_arith_emulate_unsupported_floats , ["bf16" ], "f32" )
281-         ]
282-         # fmt: on 
283- 
284230    @classmethod  
285231    @track  
286232    def  make_ttgir (cls , mod , metadata , opt , properties ):
@@ -291,7 +237,8 @@ def make_ttgir(cls, mod, metadata, opt, properties):
291237            cluster_info .clusterDimZ  =  opt .cluster_dims [2 ]
292238
293239        # Annotate module with information required by subsequent transformations. 
294-         pm  =  cls .create_pass_manager (mod .context )
240+         pm  =  ir .pass_manager (mod .context )
241+         pm .enable_debug ()
295242        module_opts  =  intel .passes .ttgpuir .AnnotateModuleOptions ()
296243        cls .annotate_module (module_opts , properties , opt )
297244        intel .passes .ttgpuir .add_triton_annotate_module (pm , module_opts )
@@ -301,7 +248,44 @@ def make_ttgir(cls, mod, metadata, opt, properties):
301248        opt .warp_size  =  intel .get_threads_per_warp (mod )
302249        cls .validate_options (opt , properties )
303250
304-         pm  =  cls .create_pass_manager (mod .context , cls .get_ttgir_passes (opt ))
251+         pm  =  ir .pass_manager (mod .context )
252+         pm .enable_debug ()
253+         passes .ttir .add_convert_to_ttgpuir (pm , "xpu" , opt .num_warps , opt .warp_size , opt .num_ctas )
254+         # optimize TTGIR 
255+         intel .passes .ttgpuir .add_coalesce (pm )
256+         intel .passes .ttgpuir .add_remove_layout_conversions (pm )
257+ 
258+         intel .passes .ttgpuir .add_accelerate_matmul (pm )
259+         intel .passes .ttgpuir .add_materialize_block_pointer (pm )
260+         intel .passes .ttgpuir .add_remove_layout_conversions (pm )
261+         intel .passes .ttgpuir .add_optimize_dot_operands (pm )
262+         intel .passes .ttgpuir .add_pipeline (pm , opt .num_stages , XPUBackend .get_split_barrier_scope (opt ))
263+ 
264+         if  (opt .reduce_variable_liveness ):
265+             intel .passes .ttgpuir .add_reduce_variable_liveness (pm )
266+ 
267+         passes .ttgpuir .add_fuse_nested_loops (pm )
268+ 
269+         passes .common .add_canonicalizer (pm )
270+         passes .ttir .add_triton_licm (pm )
271+         passes .common .add_canonicalizer (pm )
272+         passes .ttgpuir .add_combine_tensor_select_and_if (pm )
273+ 
274+         passes .ttgpuir .add_optimize_thread_locality (pm )
275+         passes .ttgpuir .add_optimize_dot_operands (pm , True )
276+         passes .common .add_cse (pm )
277+         passes .ttgpuir .add_prefetch (pm )
278+         passes .ttgpuir .add_optimize_dot_operands (pm , True )
279+         intel .passes .ttgpuir .add_remove_layout_conversions (pm )
280+         intel .passes .ttgpuir .add_reduce_data_duplication (pm )
281+         passes .ttgpuir .add_reorder_instructions (pm )
282+         passes .common .add_cse (pm )
283+         passes .common .add_symbol_dce (pm )
284+         passes .common .add_sccp (pm )
285+         passes .common .add_canonicalizer (pm )
286+         if  knobs .intel .opt_reduction_locality :
287+             intel .passes .ttgpuir .add_optimize_reduction_locality (pm )
288+         intel .passes .arith .add_arith_emulate_unsupported_floats (pm , ["bf16" ], "f32" )
305289        pm .run (mod , 'make_ttgir' )
306290        metadata ["cluster_dims" ] =  (cluster_info .clusterDimX , cluster_info .clusterDimY , cluster_info .clusterDimZ )
307291        return  mod 
@@ -322,31 +306,6 @@ def gluon_to_ttgir(self, src, metadata, options):
322306        metadata ["tensordesc_meta" ] =  mod .get_tensordesc_metadata ()
323307        return  mod 
324308
325-     @classmethod  
326-     def  get_llir_passes (cls , opt , mod ):
327-         # fmt: off 
328-         return  [
329-             passes .convert .add_scf_to_cf ,
330-             passes .gluon .add_inliner ,
331-             passes .convert .add_index_to_llvmir ,
332-             intel .passes .ttgpuir .add_allocate_shared_memory ,
333-             passes .ttgpuir .add_allocate_global_scratch_memory ,
334-             # instrumentation point here so we can override IRs above (e.g., ttir and ttgir) 
335-             lambda  pm : cls .instrumentation .patch ("ttgpuir_to_llvmir" , pm , mod .context ) if  cls .instrumentation  else  None ,
336-             intel .passes .ttgpuir .add_to_llvmir ,
337-             intel .passes .ttgpuir .add_gen_to_llvm ,
338-             passes .common .add_canonicalizer ,
339-             intel .passes .ttgpuir .add_rewrite_stack_ptr ,
340-             passes .common .add_cse ,
341-             passes .convert .add_arith_to_llvmir ,
342-             passes .common .add_canonicalizer ,
343-             passes .common .add_cse ,
344-             passes .common .add_symbol_dce ,
345-             None  if  knobs .compilation .disable_line_info  or  knobs .compilation .dump_ir_extract_di_local_variables  else  passes .llvmir .add_di_scope ,
346-             lambda  pm : cls .instrumentation .patch ("llvmir_to_llvm" , pm , mod .context ) if  cls .instrumentation  else  None ,
347-         ]
348-         # fmt: on 
349- 
350309    @classmethod  
351310    def  optimize_llvm_mod (cls , llvm_mod , options ):
352311        intel .set_spv_target_triple (llvm_mod )
@@ -358,21 +317,50 @@ def optimize_llvm_mod(cls, llvm_mod, options):
358317    def  make_llir (cls , src , metadata , options ):
359318        mod  =  src 
360319        # TritonGPU -> LLVM-IR (MLIR) 
361-         pm  =  cls .create_pass_manager (mod .context , cls .get_llir_passes (options , mod ))
320+         pm  =  ir .pass_manager (mod .context )
321+         pm .enable_debug ()
322+ 
323+         passes .convert .add_scf_to_cf (pm )
324+         passes .gluon .add_inliner (pm )
325+         passes .convert .add_index_to_llvmir (pm )
326+         intel .passes .ttgpuir .add_allocate_shared_memory (pm )
327+         passes .ttgpuir .add_allocate_global_scratch_memory (pm )
328+         # instrumentation point here so we can override IRs above (e.g., ttir and ttgir) 
329+         if  cls .instrumentation :
330+             cls .instrumentation .patch ("ttgpuir_to_llvmir" , pm , mod .context )
331+         intel .passes .ttgpuir .add_to_llvmir (pm )
332+         intel .passes .ttgpuir .add_gen_to_llvm (pm )
333+         passes .common .add_canonicalizer (pm )
334+         intel .passes .ttgpuir .add_rewrite_stack_ptr (pm )
335+         passes .common .add_cse (pm )
336+         passes .convert .add_arith_to_llvmir (pm )
337+         passes .common .add_canonicalizer (pm )
338+         passes .common .add_cse (pm )
339+         passes .common .add_symbol_dce (pm )
340+ 
341+         if  not  knobs .compilation .disable_line_info  and  not  knobs .compilation .dump_ir_extract_di_local_variables :
342+             passes .llvmir .add_di_scope (pm )
343+ 
344+         if  cls .instrumentation :
345+             cls .instrumentation .patch ("llvmir_to_llvm" , pm , mod .context )
362346        pm .run (mod , 'make_llir' )
363347
364348        if  knobs .compilation .dump_ir_extract_di_local_variables :
365349            # see the comments below on why this is run separately
366350            if  not  knobs .compilation .disable_line_info :
367-                 pm  =  cls .create_pass_manager (mod .context , [passes .llvmir .add_di_scope ])
351+                 pm  =  ir .pass_manager (mod .context )
352+                 pm .enable_debug ()
353+                 passes .llvmir .add_di_scope (pm )
368354                pm .run (mod , 'make_llir.disable_line_info' )
369355
370356            # Insert dbg intrinsics with several DI attributes, including source
371357            # var name and type info. Note: for reasons not yet understood, this
372358            # pass and add_di_scope have to be run separately; otherwise, if we
373359            # put them into the previous pipeline, it triggers a segfault without
374360            # any error message; could be due to a bug in MLIR or pybind11
375-             pm  =  cls .create_pass_manager (mod .context , [passes .llvmir .add_di_local_variable ])
361+             pm  =  ir .pass_manager (mod .context )
362+             pm .enable_debug ()
363+             passes .llvmir .add_di_local_variable (pm )
376364            pm .run (mod , 'make_llir.dump_ir_extract_di_local_variables' )
377365
378366        # LLVM-IR (MLIR) -> LLVM-IR (LLVM) 
0 commit comments