forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert_checkpoint.py
749 lines (665 loc) · 31.8 KB
/
convert_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
import argparse
import json
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Optional, Tuple
import safetensors
import torch
from transformers import AutoModelForCausalLM, FalconConfig, FalconForCausalLM
import tensorrt_llm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.llama.utils import ( # TODO: move the utils to common dir shared by models
iterate_shard_files, load_state_dict, retrieved_layer_index_from_name)
from tensorrt_llm.quantization import QuantAlgo
def parse_arguments():
    """Parse the command-line options of the Falcon checkpoint converter.

    Besides parsing ``sys.argv``, this also applies ``--log_level`` to the
    TensorRT-LLM logger before returning.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default=None)
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment of a multi-part help text must end with its own separator.
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
    )
    parser.add_argument(
        '--use_embedding_sharing',
        action="store_true",
        default=False,
        help=
        'Try to reduce the engine size by sharing the embedding lookup table between two layers. '
        'Note: the flag might not take effect when the criteria are not met.')
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument('--load_by_shard',
                        action='store_true',
                        help='Load a pretrained model shard-by-shard.')
    parser.add_argument('--output_dir',
                        type=str,
                        default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint')
    parser.add_argument(
        '--workers',
        type=int,
        default=1,
        help='The number of workers for converting checkpoint in parallel')
    parser.add_argument('--log_level', type=str, default='info')
    args = parser.parse_args()

    tensorrt_llm.logger.set_level(args.log_level)
    return args
def load_falcon_config(model_dir: str) -> FalconConfig:
    """Load a ``FalconConfig`` and normalize legacy checkpoint layouts.

    Checkpoints produced by ``modeling_RW.py`` (model types ``RefinedWeb`` and
    ``RefinedWebModel``) use different attribute names than
    ``transformers.FalconConfig`` expects, so the relevant values are copied
    over by hand and the model type is rewritten to ``'falcon'``.

    Args:
        model_dir: Location passed to ``FalconConfig.from_pretrained``.

    Returns:
        The loaded config; legacy configs are patched in place.
    """
    config = FalconConfig.from_pretrained(model_dir)
    config.architectures = ["FalconForCausalLM"]

    # Falcon-7B config may not have num_kv_heads or n_head_kv.
    # Although Falcon-180B uses GQA (num_kv_heads=8), its config
    # has multi_query=True, hence the extra new_decoder_architecture check.
    multi_query_only = getattr(config, 'multi_query', False) and \
        not getattr(config, 'new_decoder_architecture', False)
    if multi_query_only:
        config.num_kv_heads = 1

    legacy_type = config.model_type
    if legacy_type not in ['RefinedWebModel', 'RefinedWeb']:
        # Already in the layout transformers understands.
        return config

    # Both legacy layouts share the same names for depth and head count.
    config.num_hidden_layers = config.n_layer
    config.num_attention_heads = config.n_head
    if legacy_type == 'RefinedWeb':
        # Case 1. Falcon-40B / Falcon-40B-instruct
        # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
        config.num_kv_heads = config.n_head_kv
        config.new_decoder_architecture = True
    elif legacy_type == 'RefinedWebModel':
        # Case 2. Falcon-7B / Falcon-7B-instruct
        # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
        config.num_kv_heads = 1 if config.multi_query else config.n_head
        config.new_decoder_architecture = False
    else:
        raise ValueError("Shouldn't reach here.")

    config.model_type = 'falcon'
    return config
def split(weight: torch.Tensor,
          tp_size: int,
          rank: int = 0,
          dim: int = 0) -> torch.Tensor:
    """Return the chunk of ``weight`` that belongs to tensor-parallel ``rank``.

    Args:
        weight: Tensor to partition.
        tp_size: Number of chunks to cut the tensor into.
        rank: Index of the chunk to return.
        dim: Dimension to cut along; ignored for 1-D tensors, which are
            always cut along their only dimension.

    Returns:
        ``weight`` itself (not a copy) when ``tp_size`` is 1, otherwise a
        cloned chunk that does not share storage with ``weight``.
    """
    if tp_size == 1:
        return weight
    chunk_dim = 0 if weight.ndim == 1 else dim
    return torch.chunk(weight, tp_size, dim=chunk_dim)[rank].clone()
def reorder_qkv_weight_or_bias(weight: torch.Tensor,
                               head_dim: int,
                               num_heads: int,
                               num_kv_heads: Optional[int] = None,
                               tp_size: int = 1,
                               is_bias: bool = False) -> torch.Tensor:
    """Rearrange a fused HF QKV weight (or bias) into the TRT-LLM layout.

    The fused HF tensor stores heads group by group, each group holding its
    query heads followed by one key head and one value head, whereas TRT-LLM
    wants all query heads first, then the key heads, then the value heads:

        HF     : [q1, k1, v1, ..., qh, kh, vh]
        TRT-LLM: [q1, ..., qh, k1, ..., kh, v1, ..., vh]

    The result carries a leading tensor-parallel axis so that a caller can
    pick its shard with ``result[rank]``:

        weight: (T, Qh * D + 2 * KVh * D, H)
        bias  : (T, Qh * D + 2 * KVh * D, 1)

    where T=tp_size, Qh/KVh are the per-rank query / key-value head counts,
    D=head_dim and H=num_heads * head_dim. When there are fewer K/V heads
    than ranks, K/V heads are duplicated so that every rank owns one.

    Args:
        weight: Fused QKV tensor whose first dimension is
            ``num_kv_heads * (num_heads // num_kv_heads + 2) * head_dim``.
        head_dim: Size of one attention head.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
        tp_size: Tensor-parallel size.
        is_bias: True when ``weight`` is a bias vector.

    Returns:
        The reordered tensor of shape ``(tp_size, qkv_out, -1)``.

    Raises:
        AssertionError: If the head counts or the tensor shape are
            inconsistent with each other.
    """
    # Query types and expected kv heads.
    #  - Conventional MHA: num_heads = num_kv_heads
    #  - Multi-Query Attention: num_kv_heads = 1
    #  - Grouped-Query Attention: num_heads % num_kv_heads = 0
    if num_kv_heads is None:
        num_kv_heads = num_heads
    assert num_heads % num_kv_heads == 0, \
        f'num_heads({num_heads}) must be divisible by '\
        f'num_kv_heads({num_kv_heads})).'

    # Each group is made of N query heads + 1 key head + 1 value head.
    q_per_group = num_heads // num_kv_heads
    num_group_heads = q_per_group + 2
    assert weight.shape[0] == num_kv_heads * num_group_heads * head_dim, \
        f'{weight.shape[0]} != {num_kv_heads} * {num_group_heads} * {head_dim}'

    qkv_in = 1 if is_bias else num_heads * head_dim

    # View as (group, head-in-group, head_dim, in) and peel off Q, K and V.
    grouped = weight.reshape(num_kv_heads, num_group_heads, head_dim, qkv_in)
    q_w = grouped[:, :q_per_group, ...]  # (nKV, num_heads // nKV, D, in)
    k_w = grouped[:, q_per_group:q_per_group + 1, ...]  # (nKV, 1, D, in)
    v_w = grouped[:, q_per_group + 1:, ...]  # (nKV, 1, D, in)

    if num_kv_heads < num_heads and num_kv_heads < tp_size:
        # Duplicate K/V heads so that each rank gets at least one of them.
        # For instance, with num_heads=8, num_kv_heads=2, tp_size=4:
        #   Orig: [q0 q1 q2 q3 k0 v0 q4 q5 q6 q7 k1 v1]
        #   >>>> [[q0 q1 k0 v0], [q2 q3 k0 v0], [q4 q5 k1 v1], [q6 q7 k1 v1]]
        assert tp_size % num_kv_heads == 0
        num_dups = tp_size // num_kv_heads

        # k_w and v_w have the same shape.
        dup_shape = (num_kv_heads, num_dups) + k_w.shape[2:]
        k_w = torch.broadcast_to(k_w, size=dup_shape)
        v_w = torch.broadcast_to(v_w, size=dup_shape)

        # From here on every rank owns exactly one K/V head.
        num_kv_heads = tp_size

    kv_per_rank = num_kv_heads // tp_size
    per_rank_parts = [
        q_w.reshape(tp_size, num_heads // tp_size, head_dim, qkv_in),
        k_w.reshape(tp_size, kv_per_rank, head_dim, qkv_in),
        v_w.reshape(tp_size, kv_per_rank, head_dim, qkv_in),
    ]
    reordered = torch.concat(per_rank_parts, dim=1)

    qkv_out = (num_heads + 2 * num_kv_heads) // tp_size * head_dim
    return reordered.reshape((tp_size, qkv_out, -1))
def split_qkv_weight(weight: torch.Tensor,
                     hidden_size: int,
                     num_heads: int,
                     tp_size: int,
                     rank: int,
                     is_bias: bool,
                     num_kv_heads: Optional[int] = None) -> torch.Tensor:
    """Return the QKV shard of tensor-parallel ``rank``.

    The fused tensor is first reordered by ``reorder_qkv_weight_or_bias``
    (with ``head_dim = hidden_size // num_heads``) and the slice for ``rank``
    is then taken from its leading axis. Bias shards are flattened to 1-D.

    Returns:
        A cloned tensor that does not share storage with ``weight``.
    """
    per_rank = reorder_qkv_weight_or_bias(weight,
                                          head_dim=hidden_size // num_heads,
                                          num_heads=num_heads,
                                          num_kv_heads=num_kv_heads,
                                          tp_size=tp_size,
                                          is_bias=is_bias)

    shard = per_rank[rank, ...]
    if is_bias:
        shard = shard.ravel()

    # Clone on the way out: a slice shares the memory buffer of the tensor it
    # was cut from, so handing it back as-is would keep the whole loaded
    # "qkv" buffer referenced and GC could not release it until the process
    # ends.
    return shard.clone()
def split_matrix(weight: torch.Tensor, tp_size: int, rank: int,
                 dim: int) -> torch.Tensor:
    """Return the tensor-parallel shard of ``weight`` along ``dim``.

    Thin wrapper around ``split`` that makes the split dimension mandatory.
    """
    shard = split(weight, tp_size, rank, dim=dim)
    return shard
def get_weight(params: Dict[str, torch.Tensor], prefix: str,
               dtype: torch.dtype) -> torch.Tensor:
    """Fetch ``<prefix>.weight`` from ``params`` as a detached CPU tensor.

    Returns:
        The tensor converted to ``dtype``, or ``None`` when ``params`` has no
        ``<prefix>.weight`` entry.
    """
    key = f'{prefix}.weight'
    if key in params:
        return params[key].to(dtype).detach().cpu()
    return None
def get_bias(params: Dict[str, torch.Tensor], prefix: str,
             dtype: torch.dtype) -> torch.Tensor:
    """Fetch ``<prefix>.bias`` from ``params`` as a detached CPU tensor.

    Returns:
        The tensor converted to ``dtype``, or ``None`` when ``params`` has no
        ``<prefix>.bias`` entry.
    """
    key = f'{prefix}.bias'
    if key in params:
        return params[key].to(dtype).detach().cpu()
    return None
def get_weight_and_bias(params: Dict[str, torch.Tensor], prefix: str,
                        dtype: torch.dtype) -> Tuple[torch.Tensor]:
    """Fetch ``<prefix>.weight`` and ``<prefix>.bias`` in one call.

    Returns:
        A ``(weight, bias)`` pair; either element is ``None`` when the
        corresponding entry is missing from ``params``.
    """
    weight = get_weight(params, prefix, dtype)
    bias = get_bias(params, prefix, dtype)
    return weight, bias
def get_tllm_linear_weight(
    weight: torch.Tensor,
    prefix: str,
    bias: Optional[torch.Tensor] = None,
    use_weight_only: bool = False,
    plugin_weight_only_quant_type: torch.dtype = torch.int8
) -> Dict[str, torch.Tensor]:
    """Build the TRT-LLM checkpoint entries of one linear layer.

    Args:
        weight: Weight matrix of the layer.
        prefix: TRT-LLM name of the layer; entry names are derived from it.
        bias: Optional bias, stored untouched under ``<prefix>.bias``.
        use_weight_only: When set, the transposed weight is quantized with
            the ``trtllm`` torch op and the per-channel scales are stored
            next to it under ``<prefix>.per_channel_scale``.
        plugin_weight_only_quant_type: Target dtype handed to the
            quantization op.

    Returns:
        A dict mapping entry names to tensors.
    """
    weight_key = f'{prefix}.weight'
    if not use_weight_only:
        results = {weight_key: weight}
    else:
        transposed = weight.t().contiguous()
        quantized, scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                transposed, plugin_weight_only_quant_type)
        results = {
            weight_key: quantized,
            f'{prefix}.per_channel_scale': scales,
        }

    if bias is not None:
        results[f'{prefix}.bias'] = bias
    return results
def get_tllm_param(
    param: torch.Tensor,
    name: str,
    use_weight_only: bool = False,
    plugin_weight_only_quant_type: torch.dtype = torch.int8
) -> Dict[str, torch.Tensor]:
    """Build the TRT-LLM checkpoint entries of one named parameter.

    Args:
        param: The parameter tensor.
        name: Full TRT-LLM name of the parameter.
        use_weight_only: When set and ``name`` ends with ``.weight``, the
            transposed tensor is quantized with the ``trtllm`` torch op and
            the per-channel scales are stored as a sibling entry.
        plugin_weight_only_quant_type: Target dtype handed to the
            quantization op.

    Returns:
        ``{name: param}`` for anything that is not quantized; otherwise the
        quantized weight under ``name`` plus its scales under the same name
        with the trailing ``weight`` replaced by ``per_channel_scale``.
    """
    if not (use_weight_only and name.endswith('.weight')):
        return {name: param}

    transposed = param.t().contiguous()
    quantized, scales = \
        torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
            transposed, plugin_weight_only_quant_type)

    # Swap only the trailing "weight" component: str.replace would rewrite
    # every occurrence and mangle a name whose module path itself contains
    # "weight".
    scale_name = name[:-len('weight')] + 'per_channel_scale'
    return {name: quantized, scale_name: scales}
def convert_hf_falcon(hf_model: FalconForCausalLM,
                      hf_config: FalconConfig,
                      mapping: Mapping,
                      dtype: str = 'float32',
                      use_parallel_embedding: bool = False,
                      sharding_dim: int = 0,
                      share_embedding_table: bool = False,
                      use_weight_only: bool = False,
                      plugin_weight_only_quant_type: torch.dtype = torch.int8):
    """Convert an in-memory HF Falcon model into TRT-LLM weights for one rank.

    Args:
        hf_model: Loaded HF model whose named parameters are converted.
        hf_config: Config providing head counts, sizes and architecture flags.
        mapping: Parallel mapping; selects the pipeline stage's layers and
            the tensor-parallel shard.
        dtype: Name of the torch dtype the weights are converted to.
        use_parallel_embedding: Shard the embedding table across TP ranks.
        sharding_dim: Dimension along which the embedding table is sharded.
        share_embedding_table: When set, no separate ``lm_head.weight`` is
            emitted.
        use_weight_only: Quantize the linear-layer weights.
        plugin_weight_only_quant_type: Target dtype of that quantization.

    Returns:
        A dict mapping TRT-LLM parameter names to tensors.
    """
    weights = {}
    started = time.time()

    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    num_attention_heads = hf_config.num_attention_heads
    hidden_size = hf_config.hidden_size
    vocab_size = hf_config.vocab_size
    num_kv_heads = getattr(hf_config, 'num_kv_heads', num_attention_heads)
    num_hidden_layers = hf_config.num_hidden_layers
    parallel_attention = hf_config.parallel_attn
    new_decoder_architecture = hf_config.new_decoder_architecture

    def add_linear(tllm_name, linear_w, linear_b):
        # Register one linear layer, quantizing its weight when requested.
        weights.update(
            get_tllm_linear_weight(linear_w, tllm_name, linear_b,
                                   use_weight_only,
                                   plugin_weight_only_quant_type))

    layers_range = mapping.pp_layers(num_hidden_layers)
    for layer_idx in layers_range:
        hf_prefix = f'transformer.h.{layer_idx}'
        # TRT-LLM numbers the layers of each pipeline stage from zero.
        tllm_prefix = f'transformer.layers.{layer_idx-layers_range[0]}'

        # Fused QKV projection: reordered and split per head.
        qkv_weight, qkv_bias = get_weight_and_bias(
            model_params, f'{hf_prefix}.self_attention.query_key_value', dtype)
        qkv_w = split_qkv_weight(qkv_weight,
                                 hidden_size,
                                 num_attention_heads,
                                 mapping.tp_size,
                                 mapping.tp_rank,
                                 is_bias=False,
                                 num_kv_heads=num_kv_heads)
        qkv_b = None
        if qkv_bias is not None:
            qkv_b = split_qkv_weight(qkv_bias,
                                     hidden_size,
                                     num_attention_heads,
                                     mapping.tp_size,
                                     mapping.tp_rank,
                                     is_bias=True,
                                     num_kv_heads=num_kv_heads)
        add_linear(f'{tllm_prefix}.attention.qkv', qkv_w, qkv_b)

        # Attention output projection: weight split along dim 1, bias kept
        # whole.
        attn_dense_weight, attn_dense_bias = get_weight_and_bias(
            model_params, f'{hf_prefix}.self_attention.dense', dtype)
        attn_dense_w = split_matrix(attn_dense_weight,
                                    mapping.tp_size,
                                    mapping.tp_rank,
                                    dim=1)
        add_linear(f'{tllm_prefix}.attention.dense', attn_dense_w,
                   attn_dense_bias)

        # MLP up-projection: weight and bias both split along dim 0.
        mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
            model_params, f'{hf_prefix}.mlp.dense_h_to_4h', dtype)
        mlp_fc_w = split_matrix(mlp_fc_weight,
                                mapping.tp_size,
                                mapping.tp_rank,
                                dim=0)
        mlp_fc_b = None
        if mlp_fc_bias is not None:
            mlp_fc_b = split_matrix(mlp_fc_bias,
                                    mapping.tp_size,
                                    mapping.tp_rank,
                                    dim=0)
        add_linear(f'{tllm_prefix}.mlp.fc', mlp_fc_w, mlp_fc_b)

        # MLP down-projection: weight split along dim 1, bias kept whole.
        mlp_proj_weight, mlp_proj_bias = get_weight_and_bias(
            model_params, f'{hf_prefix}.mlp.dense_4h_to_h', dtype)
        mlp_proj_w = split_matrix(mlp_proj_weight,
                                  mapping.tp_size,
                                  mapping.tp_rank,
                                  dim=1)
        add_linear(f'{tllm_prefix}.mlp.proj', mlp_proj_w, mlp_proj_bias)

        if new_decoder_architecture:
            # Separate layer norms in front of attention and MLP.
            input_ln_weight, input_ln_bias = get_weight_and_bias(
                model_params, f'{hf_prefix}.ln_attn', dtype)
            weights[f'{tllm_prefix}.input_layernorm.weight'] = input_ln_weight
            if input_ln_bias is not None:
                weights[f'{tllm_prefix}.input_layernorm.bias'] = input_ln_bias

            mlp_ln_weight, mlp_ln_bias = get_weight_and_bias(
                model_params, f'{hf_prefix}.ln_mlp', dtype)
            weights[f'{tllm_prefix}.mlp_layernorm.weight'] = mlp_ln_weight
            if mlp_ln_bias is not None:
                weights[f'{tllm_prefix}.mlp_layernorm.bias'] = mlp_ln_bias
            continue

        input_ln_weight, input_ln_bias = get_weight_and_bias(
            model_params, f'{hf_prefix}.input_layernorm', dtype)
        weights[f'{tllm_prefix}.input_layernorm.weight'] = input_ln_weight
        if input_ln_bias is not None:
            weights[f'{tllm_prefix}.input_layernorm.bias'] = input_ln_bias

        if parallel_attention:
            continue
        # Only the non-parallel variant looks up a post-attention norm.
        post_ln_weight, post_ln_bias = get_weight_and_bias(
            model_params, f'{hf_prefix}.post_attention_layernorm', dtype)
        if post_ln_weight is not None:
            weights[f'{tllm_prefix}.post_layernorm.weight'] = post_ln_weight
        if post_ln_bias is not None:
            weights[f'{tllm_prefix}.post_layernorm.bias'] = post_ln_bias

    embed_w = get_weight(model_params, 'transformer.word_embeddings', dtype)
    if mapping.is_first_pp_rank():
        if use_parallel_embedding:
            sharded_extent = vocab_size if sharding_dim == 0 else hidden_size
            assert sharded_extent % mapping.tp_size == 0
            weights['transformer.vocab_embedding.weight'] = split_matrix(
                embed_w, mapping.tp_size, mapping.tp_rank, sharding_dim)
        else:
            weights['transformer.vocab_embedding.weight'] = embed_w

    if mapping.is_last_pp_rank():
        if not share_embedding_table:
            # The LM head is derived from a copy of the embedding table.
            weights['lm_head.weight'] = split_matrix(embed_w.clone(),
                                                     mapping.tp_size,
                                                     mapping.tp_rank,
                                                     dim=0)
        ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f',
                                             dtype)
        weights['transformer.ln_f.weight'] = ln_f_w
        if ln_f_b is not None:
            weights['transformer.ln_f.bias'] = ln_f_b

    elapsed = time.strftime('%H:%M:%S', time.gmtime(time.time() - started))
    print(f'Weights loaded. Total time: {elapsed}')
    return weights
def load_from_hf_falcon_checkpoint(
        hf_model_dir: str,
        hf_config: FalconConfig,
        mapping: Mapping,
        dtype: str = 'float32',
        use_parallel_embedding: bool = False,
        sharding_dim: int = 0,
        share_embedding_table: bool = False,
        use_weight_only: bool = False,
        plugin_weight_only_quant_type: torch.dtype = torch.int8):
    """Convert an HF Falcon checkpoint shard-by-shard for one rank.

    Unlike ``convert_hf_falcon`` this never needs the whole model loaded: it
    walks the shard files of ``hf_model_dir`` and converts each parameter as
    it is encountered.

    Args:
        hf_model_dir: Directory holding the HF checkpoint shards.
        hf_config: Config providing head counts and sizes.
        mapping: Parallel mapping; selects the pipeline stage's layers and
            the tensor-parallel shard.
        dtype: Name of the torch dtype the weights are loaded as.
        use_parallel_embedding: Shard the embedding table across TP ranks.
        sharding_dim: Dimension along which the embedding table is sharded.
        share_embedding_table: When set, no separate ``lm_head.weight`` is
            emitted.
        use_weight_only: Quantize the linear-layer weights.
        plugin_weight_only_quant_type: Target dtype of that quantization.

    Returns:
        A dict mapping TRT-LLM parameter names to tensors.
    """
    weights = {}
    tik = time.time()

    dtype = getattr(torch, dtype)
    num_attention_heads = hf_config.num_attention_heads
    hidden_size = hf_config.hidden_size
    vocab_size = hf_config.vocab_size
    num_kv_heads = getattr(hf_config, 'num_kv_heads', num_attention_heads)
    num_hidden_layers = hf_config.num_hidden_layers

    layers_range = mapping.pp_layers(num_hidden_layers)

    def add_param(tllm_name, tensor):
        # Register one parameter; only names ending in ".weight" are
        # quantized when weight-only quantization is requested.
        weights.update(
            get_tllm_param(tensor, tllm_name, use_weight_only,
                           plugin_weight_only_quant_type))

    for model_file in iterate_shard_files(hf_model_dir, mapping.tp_rank):
        state_dict = load_state_dict(model_file, dtype)
        for name, param in state_dict.items():
            layer_idx = retrieved_layer_index_from_name(name)
            if layer_idx is not None:
                # Per-layer parameter: skip layers of other pipeline stages.
                if layer_idx not in layers_range:
                    continue
                # TRT-LLM numbers the layers of each stage from zero.
                prefix = f'transformer.layers.{layer_idx-layers_range[0]}'
                is_weight = name.endswith('weight')
                kind = 'weight' if is_weight else 'bias'

                if 'self_attention.query_key_value' in name:
                    qkv = split_qkv_weight(param,
                                           hidden_size,
                                           num_attention_heads,
                                           mapping.tp_size,
                                           mapping.tp_rank,
                                           is_bias=not is_weight,
                                           num_kv_heads=num_kv_heads)
                    add_param(f'{prefix}.attention.qkv.{kind}', qkv)
                elif 'self_attention.dense' in name:
                    # Weight split along dim 1; the bias is kept whole.
                    dense = param
                    if is_weight:
                        dense = split_matrix(param,
                                             mapping.tp_size,
                                             mapping.tp_rank,
                                             dim=1)
                    add_param(f'{prefix}.attention.dense.{kind}', dense)
                elif 'mlp.dense_h_to_4h' in name:
                    # Weight and bias are both split along dim 0.
                    mlp_fc = split_matrix(param,
                                          mapping.tp_size,
                                          mapping.tp_rank,
                                          dim=0)
                    add_param(f'{prefix}.mlp.fc.{kind}', mlp_fc)
                elif 'mlp.dense_4h_to_h' in name:
                    # Weight split along dim 1; the bias is kept whole.
                    mlp_proj = param
                    if is_weight:
                        mlp_proj = split_matrix(param,
                                                mapping.tp_size,
                                                mapping.tp_rank,
                                                dim=1)
                    add_param(f'{prefix}.mlp.proj.{kind}', mlp_proj)
                elif 'ln_attn' in name or 'input_layernorm' in name:
                    weights[f'{prefix}.input_layernorm.{kind}'] = param
                elif 'ln_mlp' in name:
                    weights[f'{prefix}.mlp_layernorm.{kind}'] = param
                elif 'post_attention_layernorm' in name:
                    weights[f'{prefix}.post_layernorm.{kind}'] = param
            elif 'word_embeddings' in name:
                if mapping.is_first_pp_rank():
                    if not use_parallel_embedding:
                        weights['transformer.vocab_embedding.weight'] = param
                    else:
                        if sharding_dim == 0:
                            assert vocab_size % mapping.tp_size == 0
                        else:
                            assert hidden_size % mapping.tp_size == 0
                        weights[
                            'transformer.vocab_embedding.weight'] = split_matrix(
                                param, mapping.tp_size, mapping.tp_rank,
                                sharding_dim)
                if mapping.is_last_pp_rank() and not share_embedding_table:
                    lm_head_w = split_matrix(param,
                                             mapping.tp_size,
                                             mapping.tp_rank,
                                             dim=0)
                    if lm_head_w is param:
                        # split() hands back its input untouched when
                        # tp_size == 1, which would make lm_head and the
                        # embedding entry the very same tensor. safetensors
                        # refuses to save tensors that share memory, so give
                        # lm_head its own copy (convert_hf_falcon clones for
                        # the same reason).
                        lm_head_w = lm_head_w.clone()
                    weights['lm_head.weight'] = lm_head_w
            elif 'ln_f' in name:
                if mapping.is_last_pp_rank():
                    if name.endswith('weight'):
                        weights['transformer.ln_f.weight'] = param
                    else:
                        weights['transformer.ln_f.bias'] = param
        # Drop the shard before loading the next one to bound peak memory.
        del state_dict

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Weights loaded. Total time: {t}')
    return weights
if __name__ == '__main__':
    # Script entry point: writes config.json plus one rank<N>.safetensors
    # file per rank into --output_dir.
    #
    # TODO(qijun): Currently, the convert script depends on a torch op:
    # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix,
    # which is included in tensorrt_llm Python package. Otherwise, the convert
    # script does not need to import tensorrt_llm. Will remove it after reimplementing
    # the op with PyTorch.
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    world_size = args.tp_size * args.pp_size

    tik = time.time()

    # exist_ok avoids the check-then-create race of testing for the
    # directory first.
    os.makedirs(args.output_dir, exist_ok=True)

    # Translate the weight-only CLI options into the quantization algorithm
    # recorded in config.json and the dtype handed to the quantization op.
    quant_algo = None
    plugin_weight_only_quant_type = None
    if args.use_weight_only and args.weight_only_precision == 'int8':
        plugin_weight_only_quant_type = torch.int8
        quant_algo = QuantAlgo.W8A16
    elif args.use_weight_only and args.weight_only_precision == 'int4':
        plugin_weight_only_quant_type = torch.quint4x2
        quant_algo = QuantAlgo.W4A16

    hf_config = load_falcon_config(args.model_dir)
    config = {
        'architecture': hf_config.architectures[0],
        'dtype': args.dtype,
        'num_hidden_layers': hf_config.num_hidden_layers,
        'num_attention_heads': hf_config.num_attention_heads,
        'num_key_value_heads': hf_config.num_kv_heads,
        'hidden_size': hf_config.hidden_size,
        'norm_epsilon': hf_config.layer_norm_epsilon,
        'vocab_size': hf_config.vocab_size,
        'position_embedding_type':
        'alibi_with_scale' if hf_config.alibi else 'rope_gpt_neox',
        'max_position_embeddings': hf_config.max_position_embeddings,
        'hidden_act': 'gelu',
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'share_embedding_table': args.use_embedding_sharing,
        'quantization': {
            'quant_algo': quant_algo,
        },
        'mapping': {
            'world_size': world_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
        },
        'bias': hf_config.bias,
        'parallel_attention': hf_config.parallel_attn,
        'new_decoder_architecture': hf_config.new_decoder_architecture,
    }

    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)

    def covert_and_save(rank):
        """Convert the weights of one rank and save them as safetensors."""
        mapping = Mapping(world_size=world_size,
                          rank=rank,
                          tp_size=args.tp_size,
                          pp_size=args.pp_size)

        if args.load_by_shard:
            weights = load_from_hf_falcon_checkpoint(
                args.model_dir,
                hf_config,
                mapping,
                dtype=args.dtype,
                use_parallel_embedding=args.use_parallel_embedding,
                sharding_dim=args.embedding_sharding_dim,
                share_embedding_table=args.use_embedding_sharing,
                use_weight_only=args.use_weight_only,
                plugin_weight_only_quant_type=plugin_weight_only_quant_type)
        else:
            # The full HF model is loaded anew for every rank and released
            # right after its weights have been converted.
            hf_model = AutoModelForCausalLM.from_pretrained(
                args.model_dir, trust_remote_code=True, torch_dtype="auto")
            weights = convert_hf_falcon(
                hf_model,
                hf_config,
                mapping,
                dtype=args.dtype,
                use_parallel_embedding=args.use_parallel_embedding,
                sharding_dim=args.embedding_sharding_dim,
                share_embedding_table=args.use_embedding_sharing,
                use_weight_only=args.use_weight_only,
                plugin_weight_only_quant_type=plugin_weight_only_quant_type)
            del hf_model

        safetensors.torch.save_file(
            weights, os.path.join(args.output_dir, f'rank{rank}.safetensors'))

    if args.workers == 1:
        for rank in range(world_size):
            covert_and_save(rank)
    else:
        with ThreadPoolExecutor(max_workers=args.workers) as p:
            futures = [
                p.submit(covert_and_save, rank) for rank in range(world_size)
            ]
            # Let every rank finish (printing each traceback) before
            # reporting failure, so one bad rank does not hide the others.
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            # Raise explicitly: an assert would be stripped under
            # ``python -O`` and a failed conversion would exit successfully.
            if exceptions:
                raise RuntimeError(
                    "Checkpoint conversion failed, please check error log.")

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Total time of converting checkpoints: {t}')