
Enable loading pre-quantized INT4 weights in Llama4 #330


Open
jiawenliu64 wants to merge 2 commits into main

Conversation

jiawenliu64 (Contributor)


Generate INT4 MP8 checkpoint:
```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --quantization_mode int4_mixed --world_size 8
```
Verify generated INT4 MP8 checkpoint with int4_mixed on single GPU (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 --world_size 1 --quantization-mode int4_mixed
```
Generate FP8 MP8 checkpoint:
```
torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize --ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct --output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --quantization_mode fp8_mixed --world_size 8
```
Verify generated FP8 MP8 checkpoint with fp8_mixed (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct-FP8 --world_size 8 --quantization-mode fp8_mixed
```

Verify BF16 MP8 checkpoint (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8
```
Verify BF16 MP8 checkpoint with fp8_mixed (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=8 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 8 --quantization-mode fp8_mixed
```
Verify BF16 MP8 checkpoint with int4_mixed on single GPU (output):
```
PYTHONPATH=$(git rev-parse --show-toplevel) torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion ../checkpoints/Llama-4-Scout-17B-16E-Instruct --world_size 1 --quantization-mode int4_mixed
```
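
For convenience, the INT4 generate-and-verify pair above can be chained in one script. The sketch below is illustrative only: it shells out to the two commands copied verbatim from this test plan and assumes the same working directory and checkpoint paths.
```python
# Sketch: run the INT4 quantize step, then verify the resulting MP8 checkpoint
# on a single GPU. Commands are taken verbatim from the test plan above.
import subprocess

QUANTIZE = (
    "torchrun --nproc-per-node=8 -m models.llama4.scripts.quantize "
    "--ckpt_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct "
    "--output_dir ../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 "
    "--quantization_mode int4_mixed --world_size 8"
)
VERIFY = (
    "PYTHONPATH=$(git rev-parse --show-toplevel) "
    "torchrun --nproc_per_node=1 -m models.llama4.scripts.chat_completion "
    "../checkpoints/Llama-4-Scout-17B-16E-Instruct-INT4 "
    "--world_size 1 --quantization-mode int4_mixed"
)

for cmd in (QUANTIZE, VERIFY):
    # shell=True so the $(git rev-parse ...) substitution in VERIFY is expanded.
    subprocess.run(cmd, shell=True, check=True)
```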
@jiawenliu64 requested a review from jianyuh April 24, 2025 06:13
@facebook-github-bot added the CLA Signed label Apr 24, 2025
@jiawenliu64 changed the title from "Enable loading precompiled INT4 weights in Llama4" to "Enable loading pre-quantized INT4 weights in Llama4" Apr 24, 2025
```
dtype = torch.get_default_dtype()
if int4_weight:
```
Contributor

this feels like complexity that truly doesn't belong at this layer. can we please keep it outside into quantization code somehow?

Contributor

we don't want llama-models to become torchao or vllm or whatever really. it is not a full fledged all powerful inference engine.
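
One way to read the suggestion above: the layer keeps a single code path and asks the quantization code which dtype to materialize weights in. A minimal, hypothetical sketch; the helper name and the mode-to-dtype mapping are illustrative, not the repo's actual API:
```python
# Hypothetical helper living in quantization code, so model/layer code never
# branches on int4_weight itself.
import torch

def weight_dtype_for(quantization_mode: str = "") -> torch.dtype:
    # INT4 weights are packed two per byte, so they are typically stored in an
    # int8 container tensor; everything else keeps the default dtype.
    if quantization_mode == "int4_mixed":
        return torch.int8
    return torch.get_default_dtype()

# The layer would then just do: dtype = weight_dtype_for(mode)
```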

```
)
model_args.quantization_args = QuantizationArgs()
model_args.quantization_args.int4_weight = True
print("Loaded scale checkpoint")
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
```
Contributor

if you move the model.load_state_dict() to convert_to_quantized_model() then you can do the following:

- change the structure of the Transformer from the outside in this code path (whatever you are doing with Experts)
- move all this scale ckpt paths complexity into quantization land

nobody reading generation.py should know about quantization unless they want to dig into it.
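
A minimal sketch of the structure this comment seems to be asking for, with generation.py building a plain Transformer and delegating everything quantization-specific to a convert_to_quantized_model() entry point. The import paths, the signature, and the scale-checkpoint/Experts handling named in the comments are assumptions for illustration, not the repo's actual API:
```python
# Hypothetical layout: generation.py stays quantization-agnostic.
import torch

# Assumed import paths; the real modules in llama-models may differ.
from models.llama4.model import Transformer                        # hypothetical path
from models.llama4.quantization import convert_to_quantized_model  # hypothetical path

def build_model(model_args, state_dict, quantization_mode=None, ckpt_dir=None):
    torch.set_default_tensor_type(torch.BFloat16Tensor)
    model = Transformer(model_args)  # no int4/fp8 flags threaded through here

    if quantization_mode is None:
        model.load_state_dict(state_dict, strict=False)
        return model

    # Everything quantization-specific happens behind this call (assumed
    # signature): restructuring Experts, reading scale checkpoints from
    # ckpt_dir, setting QuantizationArgs, and loading the state dict itself.
    return convert_to_quantized_model(
        model,
        state_dict,
        quantization_mode=quantization_mode,
        ckpt_dir=ckpt_dir,
    )
```
With that split, the int4_weight and scale-checkpoint branches discussed above would live entirely in the quantization module.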
