Skip to content

Commit 54c314b

Browse files
authored
fix qkv ordering of GQA when tp_num_mapping > 1 (#277)
1 parent 4ad5912 commit 54c314b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

chatlearn/synchronizer/megatron_vllm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,9 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_division):
338338
# Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
339339
to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
340340
# pylint: disable=too-many-nested-blocks
341-
if "attention.query_key_value" in name or \
341+
if ("attention.query_key_value" in name or \
342342
"self_attention.query_key_value" in name or \
343-
"self_attention.linear_qkv" in name:
343+
"self_attention.linear_qkv" in name) and 'norm' not in name:
344344
src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
345345
dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
346346
heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
@@ -448,7 +448,7 @@ class MegatronVllmQWen2MCoreSync(MegatronVllmSync):
448448
"""qwen2-dense-mcore"""
449449

450450
def map_src_to_dst(self, src_names, src_pipe_layer_offset):
451-
self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
451+
self._to_fix_qkv_ordering_func = split_attn_state
452452
return MCore2Qwen2SyncMap(src_names, src_pipe_layer_offset)
453453

454454
class MegatronVllmLlamaSync(MegatronVllmSync):
@@ -544,5 +544,5 @@ def transform_parameters(self, params_to_sync_list):
544544
return params_to_sync_list
545545

546546
def map_src_to_dst(self, src_names, src_pipe_layer_offset):
547-
self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
547+
self._to_fix_qkv_ordering_func = split_attn_state
548548
return MCore2MoonlightSyncMap(src_names, src_pipe_layer_offset)

0 commit comments

Comments (0)