
Commit 5da878d

sudhakarsingh27 authored and ptrendx committed
Create a small tutorial on how to accelerate HF Llama models with Transformer-Engine (#615)
1 parent 8c9abbb commit 5da878d

14 files changed (+1059 −0 lines)

docs/examples/te_llama/media/llama_for_causal_lm.svg (+1)
docs/examples/te_llama/media/llama_zoom.svg (+1)
docs/examples/te_llama/media/llamadecoderlayer.svg (+1)
docs/examples/te_llama/media/model_change.svg (+1)
docs/examples/te_llama/media/swiglu.svg (+1)
docs/examples/te_llama/media/swiglu_te.svg (+1)
docs/examples/te_llama/media/tellamadecoderlayer.svg (+1)
docs/examples/te_llama/media/transformer_vs_llama.svg (+1)
docs/examples/te_llama/media/weight_swap.svg (+1)

docs/examples/te_llama/te_llama.py (+172)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import re
import gc
from contextlib import contextmanager

import torch
from torch import nn

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init

import transformers
from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files


@contextmanager
def replace_decoder(te_decoder_cls):
    """
    Replace `LlamaDecoderLayer` with the custom `TELlamaDecoderLayer`.
    """
    original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    try:
        yield
    finally:
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = original_llama_decoder_cls


class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer` that mirrors HF's
    `LlamaDecoderLayer`, so it can be swapped in for the original in the code.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
    """
    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
        )
        te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self,
                hidden_states,
                *args,
                attention_mask,
                **kwargs):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of HF's `LlamaDecoderLayer`.
        """
        return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)


class TELlamaForCausalLM:
    """
    Causal LM created with HF's `LlamaForCausalLM`. The underlying
    `LlamaDecoderLayer` class is monkey-patched with the `TELlamaDecoderLayer`
    class before the causal LM is initialized.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
            llama_for_causal_lm = LlamaForCausalLM(config)
        return llama_for_causal_lm

    @classmethod
    def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs):
        """
        Custom method adapted from the `from_pretrained` method in the HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
        """
        vanilla_model = cls(config).to(kwargs['torch_dtype'])
        is_local = os.path.isdir(pretrained_model_name_or_path)
        subfolder = ""
        variant = None
        if os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
        ):
            # Load from a sharded PyTorch checkpoint
            archive_file = os.path.join(
                pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
            )
            is_sharded = True
        else:
            raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")

        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            pretrained_model_name_or_path,
            archive_file,
        )

        # If the checkpoint is not sharded, it's a trivial sharding case
        if not is_sharded:
            assert not isinstance(resolved_archive_file, list)
            resolved_archive_file = [resolved_archive_file]

        error_msgs = []
        for shard_file in resolved_archive_file:
            state_dict = load_state_dict(shard_file)
            # Copy decoder-layer weights into their TE counterparts via the name mapping below
            replaced_layers = replace_params(state_dict, vanilla_model.state_dict())

            # Load the remaining weights (embeddings, final norm, lm_head) through the regular HF path
            error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

            # Force mem release. Taken from huggingface code
            del state_dict
            gc.collect()

        return vanilla_model


def replace_params(hf_state_dict, te_state_dict):
    # Collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r'model.layers.\d+.'
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into a model with fewer layers, skip the
        # copy if the corresponding layer doesn't exist in the TE model
        if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]

        if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]

        if layer_prefix + 'self_attention.proj.weight' in te_state_dict:
            te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]

        if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]

        # TE's LayerNormMLP combines the SwiGLU gate_proj and up_proj into a single fc1 weight
        if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0)

        if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict:
            te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]

    return all_layer_prefixes
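
Below is a minimal usage sketch (not part of this commit) of how the new class might be driven. It assumes a local, sharded HF Llama checkpoint at a hypothetical path and a CUDA device with FP8 support; `fp8_autocast` and `DelayedScaling` are Transformer Engine's standard FP8 entry points, and the dummy batch keeps the sequence length a multiple of 16 to satisfy FP8 GEMM shape constraints.

import torch
from transformers import AutoConfig

from transformer_engine.pytorch import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Format

from te_llama import TELlamaForCausalLM

# Hypothetical path to a local, sharded HF Llama checkpoint directory
ckpt_dir = "/path/to/llama-checkpoint"

# Build the model with TELlamaDecoderLayer blocks and load the HF weights into them
config = AutoConfig.from_pretrained(ckpt_dir)
model = TELlamaForCausalLM.from_pretrained_local(
    ckpt_dir, config=config, torch_dtype=torch.bfloat16
).cuda()

# Dummy batch of token ids (sequence length divisible by 16 for FP8 GEMMs)
input_ids = torch.randint(0, config.vocab_size, (1, 128), device="cuda")

# Run the TE TransformerLayers in FP8; the rest of the model stays in bf16
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()

Since `TELlamaForCausalLM.__new__` returns a regular `LlamaForCausalLM` (just with its decoder layers swapped), the resulting object plugs into existing HF training loops and utilities unchanged.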
