
[Gemma3] - FP8 Dynamic KeyError: 'vision_model.encoder.layers.0.mlp.fc1.weight_scale' #1306


Open
m4r1k opened this issue Apr 1, 2025 · 3 comments
Assignees
Labels
bug Something isn't working

Comments

@m4r1k

m4r1k commented Apr 1, 2025

Describe the bug
Following the sample for quantizing Gemma2 to FP8, but using the Gemma3-27B-IT model, LLM Compressor finishes as expected; however, running vLLM on the resulting checkpoint fails with KeyError: 'vision_model.encoder.layers.0.mlp.fc1.weight_scale'

Expected behavior
To improve UX, I'd expect the provided samples to be self-consistent and work without further modifications by end users.

Environment
Include all relevant environment information:

  1. OS [e.g. Ubuntu 20.04]: Ubuntu 24.04
  2. Python version [e.g. 3.7]: 3.12
  3. LLM Compressor version or commit hash [e.g. 0.1.0, f7245c8]: both 0.4.1 and db91486
  4. ML framework version(s) [e.g. torch 2.3.1]: torch 2.6.0
  5. Other Python package versions [e.g. vLLM, compressed-tensors, numpy, ONNX]: pip install --upgrade llmcompressor==0.4.1 vllm==0.8.2 lm_eval==0.4.3 as well as huggingface_hub==0.30.1 and hf_transfer==0.1.9
  6. Other relevant environment information [e.g. hardware, CUDA version]: GCP a3-highgpu-1g, 1x H100, 570.86.15, and CUDA 12.8

To Reproduce
Exact steps to reproduce the behavior:

#!/bin/bash

apt update
apt install -y python3-virtualenv
virtualenv llm
source llm/bin/activate

pip install --upgrade llmcompressor==0.4.1 vllm==0.8.2 lm_eval==0.4.3
pip install --upgrade "huggingface_hub[hf_transfer]"

export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_TOKEN=<HF Token>
export HF_HOME=/models

mkdir -p /models

_MODEL=google/gemma-3-27b-it
_QUANTIZED_MODEL=/root/gemma-3-27b-it-FP8-Dynamic

huggingface-cli download "${_MODEL}"

export CUDA_VISIBLE_DEVICES=0
python3 gemma3.py

# https://github.com/vllm-project/llm-compressor/issues/1305
for _FILE in chat_template.json preprocessor_config.json processor_config.json
do
	huggingface-cli download ${_MODEL} ${_FILE} --local-dir="${_QUANTIZED_MODEL}/"
done

python3 - << EOF
from vllm import LLM
model = LLM("${_MODEL}")
model.generate("Hello my name is")
EOF

lm_eval \
  --model vllm \
  --model_args pretrained=${_MODEL},add_bos_token=True,max_model_len=4096 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250

python3 - << EOF
from vllm import LLM
model = LLM("${_QUANTIZED_MODEL}")
model.generate("Hello my name is")
EOF

lm_eval \
  --model vllm \
  --model_args pretrained=${_QUANTIZED_MODEL},add_bos_token=True,max_model_len=4096 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
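
The gemma3.py invoked above is not included in the issue. A minimal sketch of what it presumably contains, assuming it follows the Gemma2 FP8-Dynamic example with only the model ID and output path changed (the use of Gemma3ForConditionalGeneration and the recipe details are assumptions, not the author's actual script):

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "google/gemma-3-27b-it"
SAVE_DIR = "/root/gemma-3-27b-it-FP8-Dynamic"

# Load the full multimodal checkpoint and its processor.
model = Gemma3ForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# FP8 dynamic quantization needs no calibration data. As in the Gemma2
# example, only lm_head is ignored, so the vision tower also receives
# weight_scale tensors, which is what vLLM later trips over.
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
oneshot(model=model, recipe=recipe)

# Save the compressed checkpoint that vLLM loads below.
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)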

Errors
If applicable, add a full print-out of any errors or exceptions that are raised or include screenshots to help explain your problem.

INFO 04-01 18:19:01 [__init__.py:239] Automatically detected platform cuda.
INFO 04-01 18:19:07 [config.py:585] This model supports multiple tasks: {'generate', 'embed', 'classify', 'reward', 'score'}. Defaulting to 'generate'.
INFO 04-01 18:19:08 [config.py:1697] Chunked prefill is enabled with max_num_batched_tokens=16384.
INFO 04-01 18:19:09 [core.py:54] Initializing a V1 LLM engine (v0.8.2) with config: model='/root/gemma-3-27b-it-FP8-Dynamic', speculative_config=None, tokenizer='/root/gemma-3-27b-it-FP8-Dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=/root/gemma-3-27b-it-FP8-Dynamic, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 04-01 18:19:10 [utils.py:2321] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x75eeab66e3f0>
INFO 04-01 18:19:11 [parallel_state.py:954] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-01 18:19:11 [cuda.py:220] Using Flash Attention backend on V1 engine.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO 04-01 18:19:16 [gpu_model_runner.py:1174] Starting to load model /root/gemma-3-27b-it-FP8-Dynamic...
INFO 04-01 18:19:16 [config.py:3243] cudagraph sizes specified by model runner [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312, 320, 328, 336, 344, 352, 360, 368, 376, 384, 392, 400, 408, 416, 424, 432, 440, 448, 456, 464, 472, 480, 488, 496, 504, 512] is overridden by config [512, 384, 256, 128, 4, 2, 1, 392, 264, 136, 8, 400, 272, 144, 16, 408, 280, 152, 24, 416, 288, 160, 32, 424, 296, 168, 40, 432, 304, 176, 48, 440, 312, 184, 56, 448, 320, 192, 64, 456, 328, 200, 72, 464, 336, 208, 80, 472, 344, 216, 88, 120, 480, 352, 248, 224, 96, 488, 504, 360, 232, 104, 496, 368, 240, 112, 376]
WARNING 04-01 18:19:16 [topk_topp_sampler.py:63] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:00<00:01,  2.65it/s]
ERROR 04-01 18:19:18 [core.py:343] EngineCore hit an exception: Traceback (most recent call last):
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 335, in run_engine_core
ERROR 04-01 18:19:18 [core.py:343]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 04-01 18:19:18 [core.py:343]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 290, in __init__
ERROR 04-01 18:19:18 [core.py:343]     super().__init__(vllm_config, executor_class, log_stats)
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 60, in __init__
ERROR 04-01 18:19:18 [core.py:343]     self.model_executor = executor_class(vllm_config)
ERROR 04-01 18:19:18 [core.py:343]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 52, in __init__
ERROR 04-01 18:19:18 [core.py:343]     self._init_executor()
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 47, in _init_executor
ERROR 04-01 18:19:18 [core.py:343]     self.collective_rpc("load_model")
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 04-01 18:19:18 [core.py:343]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 04-01 18:19:18 [core.py:343]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/utils.py", line 2255, in run_method
ERROR 04-01 18:19:18 [core.py:343]     return func(*args, **kwargs)
ERROR 04-01 18:19:18 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 136, in load_model
ERROR 04-01 18:19:18 [core.py:343]     self.model_runner.load_model()
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1177, in load_model
ERROR 04-01 18:19:18 [core.py:343]     self.model = get_model(vllm_config=self.vllm_config)
ERROR 04-01 18:19:18 [core.py:343]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/model_loader/__init__.py", line 14, in get_model
ERROR 04-01 18:19:18 [core.py:343]     return loader.load_model(vllm_config=vllm_config)
ERROR 04-01 18:19:18 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/model_loader/loader.py", line 444, in load_model
ERROR 04-01 18:19:18 [core.py:343]     loaded_weights = model.load_weights(
ERROR 04-01 18:19:18 [core.py:343]                      ^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/gemma3_mm.py", line 793, in load_weights
ERROR 04-01 18:19:18 [core.py:343]     return loader.load_weights(weights)
ERROR 04-01 18:19:18 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/utils.py", line 235, in load_weights
ERROR 04-01 18:19:18 [core.py:343]     autoloaded_weights = set(self._load_module("", self.module, weights))
ERROR 04-01 18:19:18 [core.py:343]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/utils.py", line 196, in _load_module
ERROR 04-01 18:19:18 [core.py:343]     yield from self._load_module(prefix,
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/utils.py", line 173, in _load_module
ERROR 04-01 18:19:18 [core.py:343]     loaded_params = module_load_weights(weights)
ERROR 04-01 18:19:18 [core.py:343]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343]   File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/siglip.py", line 518, in load_weights
ERROR 04-01 18:19:18 [core.py:343]     param = params_dict[name]
ERROR 04-01 18:19:18 [core.py:343]             ~~~~~~~~~~~^^^^^^
ERROR 04-01 18:19:18 [core.py:343] KeyError: 'vision_model.encoder.layers.0.mlp.fc1.weight_scale'
ERROR 04-01 18:19:18 [core.py:343]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:01<00:04,  1.20it/s]

CRITICAL 04-01 18:19:18 [core_client.py:269] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
Killed

Additional context
Add any other context about the problem here. Also include any relevant files.

@m4r1k m4r1k added the bug Something isn't working label Apr 1, 2025
@brian-dellabetta
Collaborator

Hi @m4r1k, this is because the Gemma 3 models are multi-modal, whereas Gemma 2 is purely a causal LM. Multi-modal models require a bit more configuration and setup compared to language-only ones. We typically exclude the vision encoder from quantization (e.g., the Qwen 2.5 VL example here). If you include

ignore=["lm_head", "re:vision_model.*"],

in your modifier, does it succeed?

@brian-dellabetta brian-dellabetta self-assigned this Apr 2, 2025
@biskett

biskett commented Apr 10, 2025

In my case, it works well with the parameter below.

ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"]

@brian-dellabetta
Collaborator

@biskett excellent! I did check to make sure we didn't need any custom tracing logic. You can run something like

llmcompressor.trace \
    --model_id google/gemma-3-27b-it \
    --model_class Gemma3ForCausalLM \
    --ignore "lm_head" "re:vision_tower.*" "re:multi_modal_projector.*" \
    --modality text

for text and

llmcompressor.trace \
    --model_id google/gemma-3-27b-it \
    --model_class Gemma3ForConditionalGeneration \
    --ignore "lm_head" "re:vision_tower.*" "re:multi_modal_projector.*" \
    --modality vision

for vision, and if they succeed we don't need to add custom traceable classes.
