
Commit 2348eed

mosheisland (Moshe Island) and co-author authored
universal-ckp: support llama model (bigscience-workshop#287)
Modify universal checkpoint parameter patterns based on the specific model configuration. This commit adds support for the llama family of models.

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
1 parent 37050b8 commit 2348eed
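
To illustrate the configuration-dependent pattern selection described in the commit message, here is a minimal standalone sketch (not part of the commit). It re-implements two of the new helpers outside of GPTModelPipe so it runs without Megatron's get_args(); the llama_args namespace below is hypothetical and simply mirrors the argument names used in the diff.

# Illustrative sketch only: mirrors the selection logic added by this commit,
# using a plain namespace instead of Megatron's get_args().
from types import SimpleNamespace

# Hypothetical llama-style configuration (names mirror the args used in the diff).
llama_args = SimpleNamespace(
    untie_embeddings_and_output_weights=True,  # separate lm_head weight
    add_position_embedding=False,              # rotary embeddings, no learned positions
    add_bias_linear=False,                     # no biases on linear layers
    normalization='rmsnorm',                   # no layernorm bias parameters
    swiglu=True,                               # gated MLP with fused dense_h_to_4h weight
)

def vocab_patterns(args):
    # Same branching as the new _get_vocab_param_patterns helper.
    if args.untie_embeddings_and_output_weights:
        return [r"\d+.word_embeddings.weight", r"\d+.lm_head.weight"]
    return [r"tied_modules.embed.word_embeddings.weight"]

def swiglu_patterns(args):
    # Same branching as the new _get_swiglu_col_parallel_param_patterns helper.
    if not args.swiglu:
        return []
    patterns = [r"\d+.mlp.dense_h_to_4h.weight"]
    if args.add_bias_linear:
        patterns.append(r"\d+.mlp.dense_h_to_4h.bias")
    return patterns

print(vocab_patterns(llama_args))   # ['\\d+.word_embeddings.weight', '\\d+.lm_head.weight']
print(swiglu_patterns(llama_args))  # ['\\d+.mlp.dense_h_to_4h.weight']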

File tree

1 file changed: +75, -22 lines changed


megatron/model/gpt_model.py

Lines changed: 75 additions & 22 deletions
@@ -30,6 +30,7 @@
         PIPELINE_REPLICATED_PARAMETER_PATTERNS,
         TP_REPLICATED_PARAMETER_PATTERNS,
         PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
+        PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
     )
     DS_UNIVERSAL_CHECKPOINT_INFO = True
 except ImportError:
@@ -338,36 +339,88 @@ def _logits_helper(embedding, lm_output):
                          activation_checkpoint_interval=interval,
                          partition_method='type:transformer')
 
+    @staticmethod
+    def _get_vocab_param_patterns():
+        args = get_args()
+        if args.untie_embeddings_and_output_weights:
+            patterns = [
+                r"\d+.word_embeddings.weight",
+                r"\d+.lm_head.weight"
+            ]
+        else:
+            patterns = [
+                r"tied_modules.embed.word_embeddings.weight"
+            ]
+        return patterns
+
+    def _get_pp_replicated_param_patterns(self):
+        args = get_args()
+        if args.untie_embeddings_and_output_weights:
+            return []
+        patterns = self._get_vocab_param_patterns()
+        if args.add_position_embedding:
+            patterns.append(r"tied_modules.embed.position_embeddings.weight")
+        return patterns
+
+    @staticmethod
+    def _get_tp_replicated_param_patterns():
+        args = get_args()
+        patterns = [
+            r"\d+.input_layernorm.weight",
+            r"\d+.post_attention_layernorm.weight",
+            r"\d+.weight",
+        ]
+        if args.add_position_embedding:
+            patterns.append(r"tied_modules.embed.position_embeddings.weight")
+        if args.add_bias_linear:
+            patterns.extend([
+                r"\d+.self_attention.dense.bias",
+                r"\d+.mlp.dense_4h_to_h.bias",
+            ])
+        if args.normalization == 'layernorm':
+            patterns.extend([
+                r"\d+.input_layernorm.bias",
+                r"\d+.post_attention_layernorm.bias",
+                r"\d+.bias",
+            ])
+        return patterns
+
+    @staticmethod
+    def _get_row_parallel_param_patterns():
+        return [
+            r"\d+.mlp.dense_4h_to_h.weight",
+            r"\d+.self_attention.dense.weight",
+        ]
+
+    @staticmethod
+    def _get_swiglu_col_parallel_param_patterns():
+        args = get_args()
+        if not args.swiglu:
+            return []
+        patterns = [
+            r"\d+.mlp.dense_h_to_4h.weight",
+        ]
+        if args.add_bias_linear:
+            patterns.append(r"\d+.mlp.dense_h_to_4h.bias")
+        return patterns
+
     def universal_checkpoint_info(self):
         info = dict()
         if DS_UNIVERSAL_CHECKPOINT_INFO:
             # Vocabulary parameters (embeddings) that require special handling due to padding.
-            info[VOCABULARY_PARAMETER_PATTERNS] = [
-                r"tied_modules.embed.word_embeddings.weight"
-            ]
+            info[VOCABULARY_PARAMETER_PATTERNS] = self._get_vocab_param_patterns()
 
             # Replicated (shared) parameters on the pipeline dimension
-            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [
-                r"tied_modules.embed.word_embeddings.weight",
-                r"tied_modules.embed.position_embeddings.weight"
-            ]
+            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = self._get_pp_replicated_param_patterns()
 
             # Parameter slices that should be averaged not concatenated.
-            info[TP_REPLICATED_PARAMETER_PATTERNS] = [
-                r"tied_modules.embed.position_embeddings.weight",
-                r"\d+.input_layernorm.weight",
-                r"\d+.input_layernorm.bias",
-                r"\d+.post_attention_layernorm.weight",
-                r"\d+.post_attention_layernorm.bias",
-                r"\d+.self_attention.dense.bias",
-                r"\d+.mlp.dense_4h_to_h.bias",
-                r"\d+.weight",
-                r"\d+.bias",
-            ]
+            info[TP_REPLICATED_PARAMETER_PATTERNS] = self._get_tp_replicated_param_patterns()
 
             # Parameters that are sliced on the row dimension
-            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [
-                r"\d+.mlp.dense_4h_to_h.weight",
-                r"\d+.self_attention.dense.weight",
-            ]
+            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = self._get_row_parallel_param_patterns()
+
+            # SWIGLU parameters are first sliced on dim=0 to tp slices
+            # Then, each tp slice is chunked into 2 to create the linear layers L1, L2 used for silu(L1(x)) * L2(x)
+            info[PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0] = self._get_swiglu_col_parallel_param_patterns()
         return info
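
A side note on the new PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 category: the comment above treats the fused SwiGLU dense_h_to_4h weight as two stacked sub-parameters L1 and L2. The following minimal sketch (hypothetical shapes, not taken from the commit or from DeepSpeed's implementation) shows why merging tensor-parallel slices of such a weight has to chunk each slice in two before concatenating along dim 0.

# Hypothetical shapes; illustrates the merge rule, not the actual DeepSpeed code.
import torch

ffn_hidden, hidden, tp = 8, 4, 2

# Each tensor-parallel rank holds a fused [2 * ffn_hidden // tp, hidden] slice:
# the top half is L1 and the bottom half is L2, used as silu(L1(x)) * L2(x).
tp_slices = [torch.randn(2 * ffn_hidden // tp, hidden) for _ in range(tp)]

# Concatenating the raw slices along dim 0 would interleave L1 and L2 blocks.
# Instead: chunk each slice into its two sub-params, concatenate each sub-param
# across ranks, then stack the merged L1 on top of the merged L2.
l1_parts, l2_parts = zip(*(torch.chunk(s, 2, dim=0) for s in tp_slices))
merged = torch.cat([torch.cat(l1_parts, dim=0), torch.cat(l2_parts, dim=0)], dim=0)
assert merged.shape == (2 * ffn_hidden, hidden)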

0 commit comments
