universal-ckp: support llama model (bigscience-workshop#287)
Modify the universal checkpoint parameter patterns based on the specific model
configuration. This commit adds support for the llama family of models.

Signed-off-by: Moshe Island <[email protected]>
Co-authored-by: Moshe Island <[email protected]>
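
For illustration only (not part of the commit): assuming a typical llama-style configuration with untie_embeddings_and_output_weights=True, swiglu=True, normalization='rmsnorm', add_bias_linear=False and add_position_embedding=False, the helpers added below would resolve to roughly the following pattern sets (keys are shown as strings here for readability; the real code uses the imported constants):

# Hypothetical result for a llama-style configuration; derived by hand from the
# helper logic in this commit, under the flag assumptions stated above.
expected_info = {
    "VOCABULARY_PARAMETER_PATTERNS": [              # untied: embedding and lm_head
        r"\d+.word_embeddings.weight",
        r"\d+.lm_head.weight",
    ],
    "PIPELINE_REPLICATED_PARAMETER_PATTERNS": [],   # untied embeddings: nothing shared across stages
    "TP_REPLICATED_PARAMETER_PATTERNS": [           # RMSNorm weights only, no biases
        r"\d+.input_layernorm.weight",
        r"\d+.post_attention_layernorm.weight",
        r"\d+.weight",
    ],
    "PARAMETER_WITH_ROW_PARALLELISM_PATTERNS": [
        r"\d+.mlp.dense_4h_to_h.weight",
        r"\d+.self_attention.dense.weight",
    ],
    "PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0": [      # fused SwiGLU column-parallel weight
        r"\d+.mlp.dense_h_to_4h.weight",
    ],
}
print(expected_info)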
mosheisland and misland-habana authored Nov 16, 2023
1 parent 37050b8 commit 2348eed
Showing 1 changed file with 75 additions and 22 deletions: megatron/model/gpt_model.py
@@ -30,6 +30,7 @@
         PIPELINE_REPLICATED_PARAMETER_PATTERNS,
         TP_REPLICATED_PARAMETER_PATTERNS,
         PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
+        PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
     )
     DS_UNIVERSAL_CHECKPOINT_INFO = True
 except ImportError:
@@ -338,36 +339,88 @@ def _logits_helper(embedding, lm_output):
                          activation_checkpoint_interval=interval,
                          partition_method='type:transformer')

+    @staticmethod
+    def _get_vocab_param_patterns():
+        args = get_args()
+        if args.untie_embeddings_and_output_weights:
+            patterns = [
+                r"\d+.word_embeddings.weight",
+                r"\d+.lm_head.weight"
+            ]
+        else:
+            patterns = [
+                r"tied_modules.embed.word_embeddings.weight"
+            ]
+        return patterns
+
+    def _get_pp_replicated_param_patterns(self):
+        args = get_args()
+        if args.untie_embeddings_and_output_weights:
+            return []
+        patterns = self._get_vocab_param_patterns()
+        if args.add_position_embedding:
+            patterns.append(r"tied_modules.embed.position_embeddings.weight")
+        return patterns
+
+    @staticmethod
+    def _get_tp_replicated_param_patterns():
+        args = get_args()
+        patterns = [
+            r"\d+.input_layernorm.weight",
+            r"\d+.post_attention_layernorm.weight",
+            r"\d+.weight",
+        ]
+        if args.add_position_embedding:
+            patterns.append(r"tied_modules.embed.position_embeddings.weight")
+        if args.add_bias_linear:
+            patterns.extend([
+                r"\d+.self_attention.dense.bias",
+                r"\d+.mlp.dense_4h_to_h.bias",
+            ])
+        if args.normalization == 'layernorm':
+            patterns.extend([
+                r"\d+.input_layernorm.bias",
+                r"\d+.post_attention_layernorm.bias",
+                r"\d+.bias",
+            ])
+        return patterns
+
+    @staticmethod
+    def _get_row_parallel_param_patterns():
+        return [
+            r"\d+.mlp.dense_4h_to_h.weight",
+            r"\d+.self_attention.dense.weight",
+        ]
+
+    @staticmethod
+    def _get_swiglu_col_parallel_param_patterns():
+        args = get_args()
+        if not args.swiglu:
+            return []
+        patterns = [
+            r"\d+.mlp.dense_h_to_4h.weight",
+        ]
+        if args.add_bias_linear:
+            patterns.append(r"\d+.mlp.dense_h_to_4h.bias")
+        return patterns
+

     def universal_checkpoint_info(self):
         info = dict()
         if DS_UNIVERSAL_CHECKPOINT_INFO:
             # Vocabulary parameters (embeddings) that require special handling due to padding.
-            info[VOCABULARY_PARAMETER_PATTERNS] = [
-                r"tied_modules.embed.word_embeddings.weight"
-            ]
+            info[VOCABULARY_PARAMETER_PATTERNS] = self._get_vocab_param_patterns()

             # Replicated (shared) parameters on the pipeline dimension
-            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = [
-                r"tied_modules.embed.word_embeddings.weight",
-                r"tied_modules.embed.position_embeddings.weight"
-            ]
+            info[PIPELINE_REPLICATED_PARAMETER_PATTERNS] = self._get_pp_replicated_param_patterns()

             # Parameter slices that should be averaged not concatenated.
-            info[TP_REPLICATED_PARAMETER_PATTERNS] = [
-                r"tied_modules.embed.position_embeddings.weight",
-                r"\d+.input_layernorm.weight",
-                r"\d+.input_layernorm.bias",
-                r"\d+.post_attention_layernorm.weight",
-                r"\d+.post_attention_layernorm.bias",
-                r"\d+.self_attention.dense.bias",
-                r"\d+.mlp.dense_4h_to_h.bias",
-                r"\d+.weight",
-                r"\d+.bias",
-            ]
+            info[TP_REPLICATED_PARAMETER_PATTERNS] = self._get_tp_replicated_param_patterns()

             # Parameters that are sliced on the row dimension
-            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = [
-                r"\d+.mlp.dense_4h_to_h.weight",
-                r"\d+.self_attention.dense.weight",
-            ]
+            info[PARAMETER_WITH_ROW_PARALLELISM_PATTERNS] = self._get_row_parallel_param_patterns()

+            # SWIGLU parameters are first sliced on dim=0 into tp slices.
+            # Then, each tp slice is chunked into 2 to create the linear layers L1, L2 used for silu(L1(x)) * L2(x).
+            info[PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0] = self._get_swiglu_col_parallel_param_patterns()
         return info
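
To make the PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 handling concrete, here is a minimal sketch (an illustration under assumed shapes, not DeepSpeed's merge code): with SwiGLU, each tensor-parallel slice of dense_h_to_4h holds both an L1 shard and an L2 shard stacked on dim 0, so merging slices must first chunk every slice in two and concatenate the L1 and L2 shards separately.

# Sketch only: the shapes, the merge routine, and the L1/L2 naming are illustrative assumptions.
import torch
import torch.nn.functional as F

hidden, ffn, tp = 8, 16, 2
# Each TP rank holds [L1_shard; L2_shard], concatenated on dim 0.
slices = [torch.randn(2 * ffn // tp, hidden) for _ in range(tp)]

# Naive concatenation of the raw slices would interleave L1 and L2 shards incorrectly.
wrong = torch.cat(slices, dim=0)

# Correct merge: chunk each slice into its two sub-params, then concatenate per sub-param.
l1 = torch.cat([s.chunk(2, dim=0)[0] for s in slices], dim=0)
l2 = torch.cat([s.chunk(2, dim=0)[1] for s in slices], dim=0)
merged = torch.cat([l1, l2], dim=0)  # full dense_h_to_4h weight, shape (2*ffn, hidden)

# The fused weight implements silu(L1(x)) * L2(x):
x = torch.randn(4, hidden)
out = F.silu(x @ l1.t()) * (x @ l2.t())
print(merged.shape, out.shape)  # torch.Size([32, 8]) torch.Size([4, 16])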

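As a side note on how these regex patterns line up with parameter names (an assumption about the naming used by DeepSpeed pipeline modules, where transformer layers are addressed by a numeric index and tied modules live under tied_modules.*), a tiny matching sketch:

# Sketch: matching the commit's vocabulary patterns against plausible parameter names.
# The sample names and the use of re.fullmatch are illustrative assumptions.
import re

vocab_patterns = [r"\d+.word_embeddings.weight", r"\d+.lm_head.weight"]
names = ["1.word_embeddings.weight", "24.lm_head.weight", "5.mlp.dense_h_to_4h.weight"]

for name in names:
    hit = any(re.fullmatch(p, name) for p in vocab_patterns)
    print(name, "->", "vocabulary parameter" if hit else "no match")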