@@ -338,9 +338,9 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_division):
         # Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
         to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
         # pylint: disable=too-many-nested-blocks
-        if "attention.query_key_value" in name or \
+        if ("attention.query_key_value" in name or \
             "self_attention.query_key_value" in name or \
-            "self_attention.linear_qkv" in name:
+            "self_attention.linear_qkv" in name) and 'norm' not in name:
             src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
             dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
             heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
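In effect, the updated condition keeps the QKV regrouping path for fused attention weights while skipping any parameter whose name merely contains the fused-QKV prefix, such as a fused QK layernorm weight stored under `self_attention.linear_qkv.layer_norm_weight`. A minimal restatement of the filter (helper name and sample parameter names are illustrative, not taken from the repo):

```python
def is_fused_qkv_weight(name: str) -> bool:
    """Illustrative restatement of the updated condition: match fused QKV
    parameters but exclude anything with 'norm' in its name."""
    return ("attention.query_key_value" in name or
            "self_attention.query_key_value" in name or
            "self_attention.linear_qkv" in name) and "norm" not in name

# Hypothetical parameter names for illustration:
assert is_fused_qkv_weight("decoder.layers.0.self_attention.linear_qkv.weight")
assert not is_fused_qkv_weight("decoder.layers.0.self_attention.linear_qkv.layer_norm_weight")
```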
@@ -448,7 +448,7 @@ class MegatronVllmQWen2MCoreSync(MegatronVllmSync):
448448 """qwen2-dense-mcore"""
449449
450450 def map_src_to_dst (self , src_names , src_pipe_layer_offset ):
451- self ._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
451+ self ._to_fix_qkv_ordering_func = split_attn_state
452452 return MCore2Qwen2SyncMap (src_names , src_pipe_layer_offset )
453453
454454class MegatronVllmLlamaSync (MegatronVllmSync ):
@@ -544,5 +544,5 @@ def transform_parameters(self, params_to_sync_list):
         return params_to_sync_list
 
     def map_src_to_dst(self, src_names, src_pipe_layer_offset):
-        self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
+        self._to_fix_qkv_ordering_func = split_attn_state
         return MCore2MoonlightSyncMap(src_names, src_pipe_layer_offset)
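Both sync maps now point `_to_fix_qkv_ordering_func` at `split_attn_state` instead of `fix_qwen_query_key_value_ordering`. As a rough illustration only (not the repo's implementation; the function signature and the assumed fused layout are mine), splitting a Megatron-Core style fused `linear_qkv` weight into Q/K/V slices could look like this:

```python
import torch

def split_attn_state_sketch(qkv_weight, num_heads, num_query_groups, head_dim, hidden_size):
    """Hypothetical sketch: split a fused linear_qkv weight laid out as
    [q_0..q_{g-1}, k, v] per query group (Megatron-Core GQA ordering)
    into separate Q, K, V weight matrices."""
    q_per_group = num_heads // num_query_groups
    grouped = qkv_weight.view(num_query_groups, (q_per_group + 2) * head_dim, hidden_size)
    q, k, v = torch.split(grouped, [q_per_group * head_dim, head_dim, head_dim], dim=1)
    return (q.reshape(-1, hidden_size),
            k.reshape(-1, hidden_size),
            v.reshape(-1, hidden_size))

# Usage with toy sizes: 8 query heads, 2 KV groups, head_dim 16, hidden 128.
w = torch.randn((8 + 2 * 2) * 16, 128)
q, k, v = split_attn_state_sketch(w, num_heads=8, num_query_groups=2, head_dim=16, hidden_size=128)
assert q.shape == (8 * 16, 128) and k.shape == (2 * 16, 128) and v.shape == (2 * 16, 128)
```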