# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import re
import gc
from contextlib import contextmanager

import torch
from torch import nn

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init

import transformers
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files


@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Temporarily replace `LlamaDecoderLayer` with the custom `TELlamaDecoderLayer`.
    """
    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls
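
# A minimal usage sketch (illustration only, not part of the module): while the
# context is active, any Llama model constructed from a config builds
# `te_decoder_cls` wherever HF would normally build `LlamaDecoderLayer`; the
# original class is restored on exit, even if construction raises.
#
#     with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
#         model = LlamaForCausalLM(LlamaConfig())   # decoder layers are TE layers
#     hf_model = LlamaForCausalLM(LlamaConfig())    # stock HF decoder layers again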


class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer`. It keeps the constructor and
    forward signatures close to HF's `LlamaDecoderLayer`, so it can stand in
    for that class in the code.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
    """
    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
        )
        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self,
                hidden_states,
                *args,
                attention_mask,
                **kwargs):
        """
        Custom forward that passes only the relevant arguments to the forward
        pass of the `TransformerLayer` and returns the output in the tuple
        format produced by HF's `LlamaDecoderLayer`.
        """
        return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)
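
# Shape sketch (illustration only; assumes a CUDA device is available, since the
# RoPE frequencies above are created with `.cuda()`, and enough GPU memory for a
# default-sized Llama layer):
#
#     cfg = LlamaConfig()
#     layer = TELlamaDecoderLayer(cfg).cuda()
#     x = torch.randn(2, 8, cfg.hidden_size, device="cuda")  # [batch, seq, hidden] ("bshd")
#     (out,) = layer(x, attention_mask=None)                 # 1-tuple, like LlamaDecoderLayer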


class TELlamaForCausalLM:
    """
    Causal LM created with `LlamaModel`. The underlying `LlamaDecoderLayer`
    class is monkey-patched with the `TELlamaDecoderLayer` class before
    initializing the causal LM with `LlamaForCausalLM`.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from the `from_pretrained` method in the HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
        """
        vanilla_model = cls(config).to(kwargs['torch_dtype'])
        is_local = os.path.isdir(pretrained_model_name_or_path)
        subfolder = ""
        variant = None
        if os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
        ):
            # Load from a sharded PyTorch checkpoint
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
            )
            is_sharded = True
        else:
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            archive_file,
        )

        # If the checkpoint is not sharded, it's a trivial sharding case
        if not is_sharded:
            assert not isinstance(resolved_archive_file, list)
            resolved_archive_file = [resolved_archive_file]

        error_msgs = []
        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            # Copy the decoder-layer weights from the HF shard directly into the
            # matching TE parameters (in place).
            replaced_layers = replace_params(state_dict, vanilla_model.state_dict())

            # Load the weights whose names still match the HF layout
            # (e.g. embeddings, final norm, lm_head).
            error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

            # Force mem release. Taken from huggingface code
            del state_dict
            gc.collect()

        return vanilla_model
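
# Typical call pattern (illustration only; the path is a placeholder for a local
# directory holding config.json plus sharded PyTorch weights with a
# WEIGHTS_INDEX_NAME (`pytorch_model.bin.index.json`) file, as required by the
# check above):
#
#     config = LlamaConfig.from_pretrained("<path/to/llama/checkpoint>")
#     model = TELlamaForCausalLM.from_pretrained_local(
#         "<path/to/llama/checkpoint>",
#         config=config,
#         torch_dtype=torch.bfloat16,
#     )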


def replace_params(hf_state_dict, te_state_dict):
    """
    Copy the per-decoder-layer weights from `hf_state_dict` into the matching
    entries of `te_state_dict` (in place) and return the set of layer prefixes
    that were updated.
    """
    # Collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r'model\.layers\.\d+\.'
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into a model with fewer layers, skip the copy
        # if the corresponding layer doesn't exist in the TE model
        if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]

        if layer_prefix + 'self_attention.proj.weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]

        if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]

        # SwiGLU: TE's fused fc1 holds [gate_proj; up_proj] stacked along dim 0
        if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0)

        if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]

    return all_layer_prefixes
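
# Usage note (a sketch, not part of the module): `replace_params` writes through
# `.data[:]`, so it mutates the TE model's parameters in place via the tensors
# returned by `state_dict()`; the return value is only the set of per-layer
# prefixes that were touched. The names below are hypothetical, for illustration:
#
#     prefixes = replace_params(hf_shard_state_dict, te_model.state_dict())
#     # e.g. {'model.layers.0.', 'model.layers.1.', ...}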