forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_llama_lib.py
1204 lines (1053 loc) · 40.1 KB
/
export_llama_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# Example script for exporting Llama2 to flatbuffer
import argparse
import copy
import json
import logging
import re
import shlex
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import Callable, List, Optional, Union
import pkg_resources
import torch
from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
from executorch.devtools.backend_debug import print_delegation_info
from executorch.devtools.etrecord import generate_etrecord
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.export.partitioner_lib import (
get_coreml_partitioner,
get_mps_partitioner,
get_qnn_partitioner,
get_vulkan_partitioner,
get_xnnpack_partitioner,
)
from executorch.extension.llm.export.quantizer_lib import (
get_coreml_quantizer,
get_pt2e_quantization_params,
get_pt2e_quantizers,
get_qnn_quantizer,
get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace
from ..model_factory import EagerModelFactory
from .source_transformation.apply_spin_quant_r1_r2 import (
fuse_layer_norms,
get_model_with_r1_r2,
)
from .source_transformation.attention import replace_attention_to_attention_sha
from .source_transformation.quantize import (
get_quant_embedding_transform,
get_quant_weight_transform,
)
from .source_transformation.quantized_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
replace_kv_cache_with_quantized_kv_cache,
)
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
from .source_transformation.sdpa import (
replace_causal_mask,
replace_kv_cache_with_coreml_kv_cache,
replace_kv_cache_with_simple_kv_cache,
replace_sdpa_with_coreml_sdpa,
replace_sdpa_with_custom_op,
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_simple_sdpa,
)
from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb
# NOTE(review): hard-coded to True; the trailing comment suggests it was once
# derived from the FBCODE_PLATFORM environment variable. canonical_path() uses
# it to decide whether "par:" paths are resolved through pkg_resources.
IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
# Log line layout: level, timestamp, source file:line, then the message.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# Module import configures root logging at INFO level.
logging.basicConfig(level=logging.INFO, format=FORMAT)
# Package used to resolve "par:" resources (see get_resource_path and
# canonical_path); overridable through set_pkg_name().
pkg_name = __name__
# Value returned by verbose_export(); None until set_verbosity() is called.
verbosity_setting = None
# All models that leverage the transformer architecture defined in llama_transformer.py.
EXECUTORCH_DEFINED_MODELS = [
    "stories110m",
    "llama2",
    "llama3",
    "llama3_1",
    "llama3_2",
    "static_llama",
    "qwen2_5",
    "phi-4-mini",
    "smollm",
]
# Models that use TorchTune model definitions instead of llama_transformer.py.
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
class WeightType(Enum):
    """Checkpoint weight convention.

    FAIRSEQ2 checkpoints get different BOS/EOS ids in the exported model
    metadata; LLAMA is the default convention.
    """

    LLAMA = "LLAMA"
    FAIRSEQ2 = "FAIRSEQ2"
def set_pkg_name(name: str) -> None:
    """Override the package name used to resolve ``par:`` resources.

    Args:
        name: Importable package name handed to ``pkg_resources`` lookups.
    """
    global pkg_name
    pkg_name = name
def get_resource_path(resource_name) -> str:
    """Return the filesystem path of ``resource_name`` inside the configured package.

    The package is the module-level ``pkg_name`` (see ``set_pkg_name``).
    """
    resolved = pkg_resources.resource_filename(pkg_name, resource_name)
    return resolved
def set_verbosity(val):
    """Store ``val`` as the global verbosity returned by ``verbose_export()``."""
    global verbosity_setting
    verbosity_setting = val
def verbose_export():
    """Return the verbosity value most recently stored via ``set_verbosity``."""
    current = verbosity_setting
    return current
def build_model(
    modelname: str = "llama3",
    extra_opts: str = "",
    *,
    par_local_output: bool = False,
    resource_pkg_name: str = __name__,
) -> str:
    """Export ``modelname`` from the packaged ``par:`` checkpoint and params.

    Synthesizes a command line, parses it with ``build_args_parser`` and runs
    ``export_llama``.

    Args:
        modelname: Value passed to ``--model``.
        extra_opts: Extra command-line options, split with ``shlex``.
        par_local_output: Currently ignored; output always goes to ".".
        resource_pkg_name: Currently ignored.

    Returns:
        The saved .pte filename reported by ``export_llama``.
    """
    # NOTE(review): par_local_output was meant to select "par:." as output dir,
    # but that path is disabled in the original code; output is always ".".
    output_dir_path = "."
    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
    cli_parser = build_args_parser()
    parsed = cli_parser.parse_args(shlex.split(argString))
    return export_llama(parsed)
def build_args_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Llama export flow.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace is consumed by
        ``export_llama`` / ``_export_llama``.
    """
    # Default checkpoint and params paths live next to this file.
    ckpt_dir = Path(__file__).absolute().parent.as_posix()
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output-dir", default=".", help="output directory")
    parser.add_argument(
        "--model",
        default="llama3",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
        # Built from the model lists so the help text cannot drift from them.
        help="The Llama model to export. "
        f"{', '.join(EXECUTORCH_DEFINED_MODELS)} use the same underlying "
        "LlamaTransformer architecture defined in ExecuTorch. "
        f"{', '.join(TORCHTUNE_DEFINED_MODELS)} use TorchTune model definitions.",
    )
    parser.add_argument(
        "-E",
        "--embedding-quantize",
        default=None,
        type=str,
        help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
    )
    parser.add_argument(
        "--pt2e_quantize",
        default=None,
        choices=[
            "xnnpack_dynamic",
            "xnnpack_dynamic_qc4",
            "qnn_8a8w",
            "qnn_16a16w",
            "qnn_16a4w",
            "coreml_c4w",
            "coreml_8a_c8w",
            "coreml_8a_c4w",
            "coreml_baseline_8a_c8w",
            "coreml_baseline_8a_c4w",
            "vulkan_8w",
        ],
        # `choices` accepts exactly one value, so the help must not advertise
        # comma-separated lists.
        help="Use PT2E quantization. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight).",
    )
    parser.add_argument(
        "-qmode",
        "--quantization_mode",
        type=_qmode_type,
        default=None,
        help="type of quantization",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        default=f"{ckpt_dir}/params/demo_rand_params.pth",
        help="checkpoint path",
    )
    parser.add_argument(
        "--checkpoint_dir",
        default=None,
        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
    )
    parser.add_argument(
        "--use_qnn_sha",
        action="store_true",
        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
    )
    parser.add_argument(
        "--calibration_tasks",
        nargs="+",
        type=str,
        default=None,
        help="Tasks for GPTQ calibration from lm_eval",
    )
    parser.add_argument(
        "--calibration_limit",
        type=int,
        default=None,
        help="number of samples used for calibration from lm_eval",
    )
    parser.add_argument(
        "--calibration_seq_length",
        type=int,
        default=None,
        help="Sequence length for GPTQ calibration from lm_eval",
    )
    parser.add_argument(
        "--calibration_data",
        type=str,
        default="Once upon a time",
        help="Calibration prompts from users",
    )
    parser.add_argument(
        "-t",
        "--tokenizer_path",
        default=None,
        help="tokenizer path (Note: .model not .bin)",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to export a model using kv cache",
    )
    parser.add_argument(
        "--quantize_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to export a model using int8 per token quantized kv cache",
    )
    parser.add_argument(
        "--num_sharding",
        type=int,
        default=0,
        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
    )
    parser.add_argument(
        "--use_sdpa_with_kv_cache",
        default=False,
        action="store_true",
        help="Whether to use sdpa_with_kv_cache update op when using kv cache",
    )
    parser.add_argument(
        "--disable_dynamic_shape",
        dest="enable_dynamic_shape",
        default=True,  # Enable this by default
        action="store_false",
        help="Disable dynamic shape along seq dim (enabled by default). Dynamic shape is used for faster prefill.",
    )
    parser.add_argument(
        "-p",
        "--params",
        default=f"{ckpt_dir}/params/demo_config.json",
        help="config.json",
    )
    parser.add_argument(
        "--optimized_rotation_path",
        default=None,
        required=False,
        help="[QNN backend] Optimized rotation checkpoint path. Just apply R1/R2 here. "
        "You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
    )
    parser.add_argument(
        "-m",
        "--metadata",
        default=None,
        help='metadata string in json format. Example {"key": 1, "key2": "value2"}',
    )
    parser.add_argument(
        "-s",
        "--so_library",
        default=None,
        required=False,
        help="shared library for quantized operators",
    )
    parser.add_argument(
        "--profile_memory",
        required=False,
        action="store_true",
        help="Generate chrome trace of activation memory for intermediate tensors.",
    )
    parser.add_argument(
        "-prof",
        "--profile_path",
        default=None,
        help="Use cProfile to profile model export. Results saved to profile_path as a html file.",
    )
    parser.add_argument(
        "-G",
        "--group_size",
        type=int,
        default=None,
        help="group_size for weight quantization",
    )
    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16", "bf16"],
        help="Override the dtype of the model (default: fp32). "
        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
    )
    parser.add_argument(
        "-n",
        "--output_name",
        default=None,
        help="Override the output filename of the saved pte model file.",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="maximum length sequence to evaluate",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=128,
        help="maximum length of context for model to remember",
    )
    parser.add_argument("-2", "--fairseq2", action="store_true")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "-X",
        "--xnnpack",
        action="store_true",
        help="Delegate DQLinear ops to the xnnpack backend",
    )
    parser.add_argument(
        "--xnnpack-extended-ops",
        action="store_true",
        help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
    )
    parser.add_argument("-V", "--vulkan", action="store_true")
    parser.add_argument("--mps", action="store_true")
    parser.add_argument("--coreml", action="store_true")
    parser.add_argument(
        "--coreml-enable-state",
        action="store_true",
        help="This option is only for coreml, and is only supported for MacOS15+/iOS18+",
    )
    parser.add_argument(
        "--coreml-preserve-sdpa",
        action="store_true",
        help="This option is only for coreml: Preserve sdpa in torch edge program to use coreml iOS18.sdpa op",
    )
    parser.add_argument(
        "--coreml-quantize",
        default=None,
        choices=["b4w", "c4w"],
        help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
    )
    parser.add_argument(
        "--coreml-ios",
        type=int,
        default=15,
        choices=(15, 16, 17, 18),
        help="This option is only for coreml: The minimum iOS version to deploy",
    )
    parser.add_argument(
        "--coreml-compute-units",
        type=str,
        default="cpu_only",
        choices=("cpu_only", "cpu_and_gpu", "cpu_and_ne", "all"),
        help="This option is only for coreml: the compute units to use when running the model",
    )
    parser.add_argument(
        "--qnn",
        action="store_true",
        help="Delegate llama2 to qnn backend (Qualcomm), please use it with --use_kv_cache",
    )
    parser.add_argument(
        "--expand_rope_table",
        default=False,
        action="store_true",
        help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.",
    )
    parser.add_argument(
        "--generate_etrecord",
        action="store_true",
        required=False,
        default=False,
        help="Generate the ETRecord debug artifact.",
    )
    parser.add_argument(
        "--generate_full_logits",
        action="store_true",
        required=False,
        default=False,
        help="Generate logits for all inputs.",
    )
    parser.add_argument(
        "--soc_model",
        help="[QNN backend] SoC model of current device. e.g. 'SM8650' for Snapdragon 8 Gen 3.",
        type=str,
        required=False,
        default="SM8650",
    )
    parser.add_argument(
        "-sq",
        "--use_spin_quant",
        type=str,
        default=None,
        choices=["cuda", "native"],
        help="Use SpinQuant for better quantization performance. Only support cuda and native.",
    )
    parser.add_argument(
        "-qat",
        "--use_qat",
        default=False,
        action="store_true",
        help="Whether the checkpoint is pre-quantized with QAT or not.",
    )
    parser.add_argument(
        "-lora",
        "--use_lora",
        type=int,
        default=0,
        help="Whether the checkpoint contains LoRA adaptors or not. 0: no LoRA adaptors; "
        "otherwise, it means the rank of LoRA adaptors. Currently it only works if QAT is enabled.",
    )
    parser.add_argument(
        "--preq_mode",
        type=str,
        default=None,
        choices=["8da4w", "8da4w_output_8da8w"],
        help="Quantization mode used for pre-quantized checkpoint. Only support 8da4w and 8da4w_output_8da8w right now.",
    )
    parser.add_argument(
        "--preq_group_size",
        type=int,
        default=32,
        help="group_size for pre-quantized checkpoint weight quantization",
    )
    parser.add_argument(
        "--preq_embedding_quantize",
        default="8,0",
        type=str,
        help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
    )
    parser.add_argument(
        "--use_attention_sink",
        default=None,
        type=str,
        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
    )
    parser.add_argument(
        "--output_prune_map",
        default=None,
        help="path to the output pruning token mapping file (token_map.json)",
    )
    parser.add_argument(
        "--input_prune_map",
        default=None,
        help="path to the input pruning token mapping file (token_map.json)",
    )
    parser.add_argument(
        "--export_only",
        default=False,
        action="store_true",
        help="If true, stops right after torch.export() and saves the exported model.",
    )
    return parser
def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
    """Resolve a path that may carry a ``par:`` prefix.

    Paths without the prefix are returned unchanged (as ``str``). For
    ``par:``-prefixed paths, the prefix is stripped; when ``IS_FBCODE`` is true
    the remainder is resolved inside ``pkg_name`` via ``pkg_resources``.

    Args:
        path: Plain or ``par:``-prefixed path.
        dir: Accepted for call-site readability but not used.

    Returns:
        The resolved path string.
    """
    as_text = str(path)
    if verbose_export():
        print(f"creating canonical path for {as_text}")

    prefix = "par:"
    if not as_text.startswith(prefix):
        return as_text

    relative = as_text[len(prefix):]
    if not IS_FBCODE:
        print("not FBCODE")
        return relative

    return_val = pkg_resources.resource_filename(pkg_name, relative)
    if verbose_export():
        print(f"canonical name is: {return_val}")
    return return_val
def export_llama(args) -> str:
    """Export a model as configured by ``args`` and return the saved .pte filename.

    When ``args.profile_path`` is set, the export runs inside
    ``CProfilerFlameGraph(args.profile_path)``. If that profiler cannot be
    imported, an install hint is printed and "" is returned without exporting.

    Returns:
        The filename reported by ``builder.get_saved_pte_filename()``, or ""
        when profiling was requested but the profiler is unavailable.
    """
    if args.profile_path is not None:
        # Guard only the profiler import: an ImportError raised by the export
        # itself (e.g. an unavailable backend module) must propagate instead of
        # being misreported as a missing snakeviz install.
        try:
            from executorch.util.python_profiler import CProfilerFlameGraph
        except ImportError:
            print(
                "Please run `pip install snakeviz` to install required dependencies for cProfiler flamegraph."
            )
            return ""
        with CProfilerFlameGraph(args.profile_path):
            builder = _export_llama(args)
    else:
        builder = _export_llama(args)

    filename = builder.get_saved_pte_filename()
    assert filename is not None, "Fail to get file name from builder"
    return filename
def _prepare_for_llama_export(args) -> LLMEdgeManager:
    """
    Load the eager model described by ``args`` and wrap it in an LLMEdgeManager.

    Resolves checkpoint/params/output paths (handling ``par:`` prefixes), picks
    the weight type and dtype override, sets the output directory and applies
    the source transforms. The returned manager has not yet been exported to
    edge or quantized.
    """
    # Resolve every path input first; canonical_path handles "par:" prefixes.
    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
    checkpoint_dir = (
        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
    )
    params_path = canonical_path(args.params)
    output_dir_path = canonical_path(args.output_dir, dir=True)

    weight_type = WeightType.LLAMA
    if args.fairseq2:
        weight_type = WeightType.FAIRSEQ2

    # NOTE(review): build_args_parser defaults --dtype-override to "fp32", so for
    # namespaces produced by that parser the fp16 fallback below is unreachable.
    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
    elif args.quantization_mode in ("8da4w", "8da4w-gptq"):
        dtype_override = DType["fp16"]
    else:
        dtype_override = None

    load_kwargs = dict(
        checkpoint=checkpoint_path,
        checkpoint_dir=checkpoint_dir,
        params_path=params_path,
        use_kv_cache=args.use_kv_cache,
        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
        generate_full_logits=args.generate_full_logits,
        weight_type=weight_type,
        enable_dynamic_shape=args.enable_dynamic_shape,
        calibration_tasks=args.calibration_tasks,
        calibration_limit=args.calibration_limit,
        calibration_seq_length=args.calibration_seq_length,
        calibration_data=args.calibration_data,
        tokenizer_path=args.tokenizer_path,
        verbose=args.verbose,
        max_seq_len=args.max_seq_length,
        max_context_len=args.max_context_length,
        input_prune_map_path=args.input_prune_map,
        output_prune_map_path=args.output_prune_map,
        metadata_str=args.metadata,
        dtype_override=dtype_override,
        args=args,
    )
    manager = _load_llama_model(args.model, **load_kwargs)
    manager = manager.set_output_dir(output_dir_path)
    transforms = _get_source_transforms(args.model, dtype_override, args)
    return manager.source_transform(transforms)
def get_quantizer_and_quant_params(args):
    """Collect the PT2E quantizers requested on the command line.

    Starts from the quantizers returned by ``get_pt2e_quantizers``. When
    ``--pt2e_quantize`` is set, it then adds the QNN, CoreML or Vulkan quantizer
    for whichever of those backends is enabled. Each backend branch asserts
    that no quantizer has been selected before it.

    Returns:
        ``(pt2e_quant_params, quantizers, quant_dtype)``; ``quant_dtype`` is only
        assigned on the QNN path and is ``None`` otherwise.
    """
    pt2e_mode = args.pt2e_quantize
    pt2e_quant_params = get_pt2e_quantization_params(
        pt2e_mode, args.quantization_mode
    )
    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
    quant_dtype = None

    if args.qnn and pt2e_mode:
        assert not quantizers, "Should not enable both xnnpack and qnn"
        qnn_quantizer, quant_dtype = get_qnn_quantizer(
            pt2e_mode, args.quantization_mode
        )
        quantizers.append(qnn_quantizer)

    if args.coreml and pt2e_mode:
        assert not quantizers, "Should not enable both xnnpack / qnn and coreml"
        quantizers.append(get_coreml_quantizer(pt2e_mode))

    if args.vulkan and pt2e_mode:
        assert not quantizers, "Should not enable both vulkan and other quantizers"
        quantizers.append(get_vulkan_quantizer(pt2e_mode))

    logging.info(f"Applying quantizers: {quantizers}")
    return pt2e_quant_params, quantizers, quant_dtype
def _qmode_type(value):
choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
if value in choices:
return value
for pattern in patterns:
matches = re.findall(pattern, value)
if len(matches) == 1:
return value
raise argparse.ArgumentTypeError(
f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
)
def _validate_args(args):
"""
TODO: Combine all the backends under --backend args
"""
if args.max_context_length < args.max_seq_length:
raise ValueError(
f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
)
if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
raise ValueError(
"Dynamic shape is not supported with coreml, MPS or qnn backends."
" Please use --disable_dynamic_shape."
)
if args.num_sharding > 0 and not args.qnn:
raise ValueError("Model shard is only supported with qnn backend now.")
if (
args.quantization_mode is not None
and args.quantization_mode.startswith("torchao:")
) or (
args.embedding_quantize is not None
and args.embedding_quantize.startswith("torchao:")
):
if args.enable_dynamic_shape:
raise ValueError(
"Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape."
"If you need this feature, please file an issue."
)
def _to_edge_and_lower_llama_xnnpack(
    builder_exported,
    modelname,
    additional_passes,
    pt2e_quant_params,
    quantizers,
    quant_dtype,
    args,
) -> LLMEdgeManager:  # noqa: C901
    """Quantize, lower to XNNPACK and convert the exported model to ExecuTorch.

    Uses ``to_edge_transform_and_lower`` with a dynamic-quant-only XNNPACK
    partitioner, plus a general XNNPACK partitioner when
    ``--xnnpack-extended-ops`` is set. ``pt2e_quant_params`` and
    ``quant_dtype`` are accepted for signature parity but unused.

    Raises:
        NotImplementedError: if ETRecord generation is requested.
    """
    # Order matters: the dynamic-quant partitioner must run before the
    # general one when both are enabled.
    partitioners = [get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)]
    # NOTE(review): the prefixed modelname is never returned, so callers do
    # not see it.
    modelname = f"xnnpack_dq_{modelname}"

    if args.xnnpack_extended_ops:
        partitioners.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
        )
        modelname = f"xnnpack_{modelname}"

    logging.info("Lowering model using following partitioner(s): ")
    for part in partitioners:
        logging.info(f"--> {part.__class__.__name__}")

    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
    if args.generate_etrecord:
        raise NotImplementedError(
            "export_llama does not support XNNPack and generating ETRecord at the moment."
        )

    quantized = builder_exported.pt2e_quantize(quantizers)
    lowered = quantized.to_edge_transform_and_lower(partitioners)

    if args.verbose:
        print_delegation_info(lowered.edge_manager.exported_program().graph_module)

    return lowered.to_executorch(passes=additional_passes)
def _to_edge_and_lower_llama(  # noqa: C901
    builder_exported,
    modelname,
    additional_passes,
    pt2e_quant_params,
    quantizers,
    quant_dtype,
    args,
):
    """Quantize, export to edge, partition for the selected backends and emit ExecuTorch.

    Backends are added in this order when enabled: Vulkan (followed by an
    XNNPACK partitioner for undelegated ops), MPS, CoreML, QNN. With
    ``--generate_etrecord``, a deep copy of the edge program is taken before
    ``to_backend`` and written to ``etrecord.bin``.

    ``pt2e_quant_params`` and ``quant_dtype`` are accepted but not used here.
    NOTE(review): ``modelname`` receives backend prefixes below but only
    ``builder`` is returned, so the prefixes do not reach the caller.

    Returns:
        The builder after ``to_executorch``.

    Raises:
        ValueError: if ETRecord generation is requested but the edge manager
            is missing.
    """
    builder_exported_to_edge = builder_exported.pt2e_quantize(
        quantizers
    ).export_to_edge()
    # to_backend
    partitioners = []
    if args.vulkan:
        partitioners.append(
            get_vulkan_partitioner(
                args.dtype_override,
                args.enable_dynamic_shape,
            )
        )
        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
        partitioners.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
        )
        modelname = f"vulkan_{modelname}"
        # Need to remove asserts from the graph to prevent graph breaks
        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
    if args.mps:
        partitioners.append(get_mps_partitioner(args.use_kv_cache))
        modelname = f"mps_{modelname}"
    if args.coreml:
        coreml_partitioner = get_coreml_partitioner(
            args.coreml_ios,
            args.embedding_quantize,
            args.pt2e_quantize,
            args.coreml_quantize,
            args.coreml_compute_units,
        )
        partitioners.append(coreml_partitioner)
        modelname = f"coreml_{modelname}"
    if args.qnn:
        # QNN-specific imports are local so other backends do not need the
        # Qualcomm packages installed.
        from executorch.extension.llm.custom_ops import model_sharding

        partitioners.append(
            get_qnn_partitioner(
                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
            )
        )
        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

        _transform(builder_exported_to_edge.edge_manager.exported_program())
        if args.num_sharding > 0:
            # Split the graph into args.num_sharding shards; the layer count
            # comes from the "get_n_layers" metadata entry.
            model_sharding.split_graph(
                builder_exported_to_edge.edge_manager.exported_program(),
                builder_exported_to_edge.metadata["get_n_layers"],
                shares=args.num_sharding,
            )
        from functools import partial

        # pyre-ignore
        from executorch.backends.qualcomm.quantizer.custom_annotation import (
            get_custom_quant_ios_dtype,
        )

        # Cache shape is read from the first layer's attention module. With
        # --use_qnn_sha the head dimension is folded out (no n_kv_heads axis).
        atten = builder_exported_to_edge.model.layers[0].attention
        if args.use_qnn_sha:
            cache_shape = torch.Size(
                (atten.max_batch_size, atten.max_context_len, atten.head_dim)
            )
        else:
            cache_shape = torch.Size(
                (
                    atten.max_batch_size,
                    atten.max_context_len,
                    atten.n_kv_heads,
                    atten.head_dim,
                )
            )
        tag_quant_io(
            builder_exported_to_edge.edge_manager.exported_program().graph_module,
            partial(get_custom_quant_ios_dtype, cache_shape),
        )
    logging.info("Lowering model using following partitioner(s): ")
    for partitioner in partitioners:
        logging.info(f"--> {partitioner.__class__.__name__}")
    if args.generate_etrecord:
        if not builder_exported_to_edge.edge_manager:
            raise ValueError("Unable to generate etrecord due to missing edge manager.")
        logging.info("Generating etrecord")
        # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
        edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
        builder = builder_exported_to_edge.to_backend(partitioners)
        if args.verbose:
            print_delegation_info(builder.edge_manager.exported_program().graph_module)
        if args.num_sharding > 0 and args.qnn:
            from executorch.backends.qualcomm.utils.utils import canonicalize_program

            canonicalize_program(builder.edge_manager.exported_program())
        builder = builder.to_executorch(
            passes=additional_passes,
        )
        # Generate ETRecord
        if edge_manager_copy:
            generate_etrecord(
                et_record="etrecord.bin",
                edge_dialect_program=edge_manager_copy,
                executorch_program=builder.export_program,
            )
            logging.info("Generated etrecord.bin")
    else:
        # Same lowering sequence as above, without the ETRecord copy/output.
        builder = builder_exported_to_edge.to_backend(partitioners)
        if args.verbose:
            print_delegation_info(builder.edge_manager.exported_program().graph_module)
        if args.num_sharding > 0 and args.qnn:
            from executorch.backends.qualcomm.utils.utils import canonicalize_program

            canonicalize_program(builder.edge_manager.exported_program())
        builder = builder.to_executorch(passes=additional_passes)
    return builder
def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
    """Run the full export pipeline described by ``args`` and save a .pte file.

    Steps: validate arguments, pick quantizers, export the eager model, lower
    it (XNNPACK path or generic backend path), optionally write a memory
    trace, then save to ``<output_dir>/<modelname>.pte``. ``--output_name``
    overrides the name; a value ending in ".pte" is used as the full path.
    With ``--export_only`` the process exits after export.

    Returns:
        The builder after ``save_to_pte``.
    """
    _validate_args(args)

    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    additional_passes = []
    if args.model in TORCHTUNE_DEFINED_MODELS:
        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

    # export_to_edge
    builder_exported = _prepare_for_llama_export(args).export()
    builder_exported.run_canonical_optimizations()
    modelname = builder_exported.modelname

    if args.export_only:
        # The builtin `exit()` is installed by the `site` module for
        # interactive use and may be absent (python -S, frozen apps);
        # sys.exit() is the programmatic equivalent.
        import sys

        sys.exit()

    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
        args.xnnpack = True

    lower = (
        _to_edge_and_lower_llama_xnnpack
        if args.xnnpack
        else _to_edge_and_lower_llama
    )
    # NOTE(review): backend prefixes the lowering helpers apply to modelname
    # are not returned, so the saved filename does not include them.
    builder = lower(
        builder_exported,
        modelname,
        additional_passes,
        pt2e_quant_params,
        quantizers,
        quant_dtype,
        args,
    )

    if args.profile_memory:
        generate_memory_trace(builder.export_program, "memory_profile.json")

    if builder.dtype == DType.fp16:
        modelname = f"{modelname}_h"

    if args.output_name:
        modelname = args.output_name
        if modelname.endswith(".pte"):
            # A ".pte" suffix means the user supplied the full output path.
            output_file = modelname
            modelname = modelname[:-4]
        else:
            output_file = f"{builder.output_dir}/{modelname}.pte"
        print(f"modelname: {modelname}")
        print(f"output_file: {output_file}")
    else:
        output_file = f"{builder.output_dir}/{modelname}.pte"

    builder.save_to_pte(output_file)
    return builder
def _load_llama_model_metadata(
    weight_type: WeightType,
    use_kv_cache: bool,
    use_sdpa_with_kv_cache: bool,
    enable_dynamic_shape: bool,
    max_seq_len: int,
    max_context_len: int,
    n_layers: int,
    vocab_size: int,
    metadata_str: Optional[str] = None,
):
    """Build the metadata dictionary attached to the exported model.

    BOS/EOS ids depend on ``weight_type``: FAIRSEQ2 uses 3 for both, otherwise
    BOS=1 and EOS=[2]. Entries parsed from ``metadata_str`` override the
    defaults.

    Args:
        metadata_str: Optional JSON object string. If it is invalid JSON, or
            valid JSON that is not an object, an error is logged and it is
            ignored.

    Returns:
        dict mapping metadata keys to values.
    """
    is_fairseq2 = weight_type == WeightType.FAIRSEQ2
    metadata = {
        "get_bos_id": 3 if is_fairseq2 else 1,
        "get_eos_ids": [3] if is_fairseq2 else [2],
        "get_max_seq_len": max_seq_len,
        "get_max_context_len": max_context_len,
        "get_n_layers": n_layers,
        "get_vocab_size": vocab_size,
        "use_kv_cache": use_kv_cache,
        "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
        "enable_dynamic_shape": enable_dynamic_shape,
    }
    if metadata_str:
        try:
            extra = json.loads(metadata_str)
        except JSONDecodeError:
            logging.error("Invalid metadata, should be a valid JSON string")
        else:
            # Valid JSON that is not an object (e.g. a list) previously crashed
            # on `.items()`; treat it like other invalid metadata instead.
            if isinstance(extra, dict):
                metadata.update(extra)
            else:
                logging.error(
                    "Invalid metadata, should be a JSON object of key/value pairs"
                )
    return metadata
def _load_llama_model(
modelname: str = "llama3",
*,
checkpoint: Optional[str] = None,
checkpoint_dir: Optional[str] = None,
params_path: str,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
generate_full_logits: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
calibration_tasks: Optional[List[str]] = None,
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
calibration_data: Optional[str] = None,
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
max_context_len: int = 128,
input_prune_map_path: Optional[str] = None,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
dtype_override: Optional[DType] = None,
args,
) -> "LLMEdgeManager":
"""
A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
can help further lower the model to ExecuTorch.
Returns:
An instance of LLMEdgeManager which contains the eager mode model.
"""
assert (
checkpoint or checkpoint_dir
) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
logging.info(
f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
)
if modelname in EXECUTORCH_DEFINED_MODELS:
module_name = "llama"
model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
elif modelname in TORCHTUNE_DEFINED_MODELS: