Enable loading pre-quantized INT4 weights in Llama4 #330
base: main
Conversation
Generate INT4 MP8 checkpoint:

```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --quantization_mode int4_mixed --world_size 8
```

Verify the generated INT4 MP8 checkpoint with int4_mixed on a single GPU (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --world_size 1 --quantization-mode int4_mixed
```

Generate FP8 MP8 checkpoint:

```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --quantization_mode fp8_mixed --world_size 8
```

Verify the generated FP8 MP8 checkpoint with fp8_mixed (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --world_size 8 --quantization-mode fp8_mixed
```

Verify the BF16 MP8 checkpoint (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8
```

Verify the BF16 MP8 checkpoint with fp8_mixed (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8 --quantization-mode fp8_mixed
```

Verify the BF16 MP8 checkpoint with int4_mixed on a single GPU (output):

```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 1 --quantization-mode int4_mixed
```
```
dtype = torch.get_default_dtype()
if int4_weight:
```
this feels like complexity that truly doesn't belong at this layer. can we please keep it outside into quantization code somehow?
we don't want `llama-models` to become torchao or vllm or whatever really. it is not a full-fledged, all-powerful inference engine.
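One way to read that suggestion is a sketch like the one below, where the model module keeps a single default-dtype code path and all int4-specific handling lives in quantization code. The `Experts` class here is a toy stand-in, and `convert_experts_to_int4` plus the `w_int4`/`w_scale` buffer names are hypothetical, not the PR's actual implementation:

```python
import torch
import torch.nn as nn


class Experts(nn.Module):
    """Toy stand-in for the real Experts module in models.llama4."""

    def __init__(self, num_experts: int, dim: int):
        super().__init__()
        # The model definition stays dtype-agnostic: no `if int4_weight:` branch.
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim))


def convert_experts_to_int4(model: nn.Module, int4_state: dict) -> nn.Module:
    """Hypothetical quantization-side helper: replace each Experts weight with
    pre-packed int4 data plus per-block scales loaded from the quantized ckpt."""
    for prefix, module in model.named_modules():
        if isinstance(module, Experts):
            del module.w  # drop the default-dtype parameter
            module.register_buffer("w_int4", int4_state[f"{prefix}.w_int4"])
            module.register_buffer("w_scale", int4_state[f"{prefix}.w_scale"])
    return model
```

Under a split like this, generation.py would only build the Transformer and, when int4_mixed is requested, hand the model to the quantization code; the bf16 path stays untouched.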
```
)
model_args.quantization_args = QuantizationArgs()
model_args.quantization_args.int4_weight = True
print("Loaded scale checkpoint")
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
```
if you move the `model.load_state_dict()` call to `convert_to_quantized_model()`, then you can do the following:

- change the structure of the Transformer from the outside in this code path (whatever you are doing with Experts)
- move all this scale ckpt path complexity into quantization land

nobody reading generation.py should know about quantization unless they want to dig into it.
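Concretely, the suggested split might look roughly like the sketch below; the signature of `convert_to_quantized_model` shown here, the `load_int4_scales` helper, and the scale-file path are placeholders rather than the repo's actual API:

```python
import torch
from torch import nn


def load_int4_scales(ckpt_dir: str) -> dict:
    # Placeholder: however the scale checkpoint is actually stored/sharded,
    # reading it belongs here, not in generation.py.
    return torch.load(f"{ckpt_dir}/int4_scales.pt", map_location="cpu")


def convert_to_quantized_model(
    model: nn.Module, ckpt_dir: str, state_dict: dict, quantization_mode: str
) -> nn.Module:
    """Owns everything quantization-specific: scale-ckpt paths, restructuring
    the Experts, and the final load_state_dict."""
    if quantization_mode == "int4_mixed":
        state_dict = {**state_dict, **load_int4_scales(ckpt_dir)}
        # ...restructure the Experts for the packed int4 layout here...
    model.load_state_dict(state_dict, strict=False)
    return model


# generation.py then stays quantization-agnostic:
#   model = Transformer(model_args)
#   model = convert_to_quantized_model(model, ckpt_dir, state_dict, quantization_mode)
```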