Describe the bug
Following the example that quantizes Gemma2 to FP8, but using the google/gemma-3-27b-it model, the LLM Compressor run completes as expected. However, loading the resulting checkpoint in vLLM fails with KeyError: 'vision_model.encoder.layers.0.mlp.fc1.weight_scale'
Expected behavior
To improve UX, I'd expect the provided examples to be self-consistent and to work without further modifications by end users.
Environment
Include all relevant environment information:
OS: Ubuntu 24.04
Python version: 3.12
LLM Compressor version or commit hash: both 0.4.1 and db91486
ML framework version(s): torch 2.6.0
Other Python package versions: llmcompressor==0.4.1, vllm==0.8.2, lm_eval==0.4.3, as well as huggingface_hub==0.30.1 and hf_transfer==0.1.9
Other relevant environment information: GCP a3-highgpu-1g, 1x NVIDIA H100, driver 570.86.15, CUDA 12.8
To Reproduce
Exact steps to reproduce the behavior:
#!/bin/bash
# Environment setup
apt update
apt install -y python3-virtualenv
virtualenv llm
source llm/bin/activate
pip install --upgrade llmcompressor==0.4.1 vllm==0.8.2 lm_eval==0.4.3
pip install --upgrade huggingface_hub[hf_transfer]
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_TOKEN=<HF Token>
export HF_HOME=/models
mkdir -p /models
_MODEL=google/gemma-3-27b-it
_QUANTIZED_MODEL=/root/gemma-3-27b-it-FP8-Dynamic
huggingface-cli download "${_MODEL}"
export CUDA_VISIBLE_DEVICES=0
# Quantize to FP8-Dynamic; gemma3.py follows the Gemma2 FP8 example (see the sketch below)
python3 gemma3.py
# Copy the processor files missing from the quantized checkpoint
# https://github.com/vllm-project/llm-compressor/issues/1305
for _FILE in chat_template.json preprocessor_config.json processor_config.json
do
  huggingface-cli download "${_MODEL}" "${_FILE}" --local-dir="${_QUANTIZED_MODEL}/"
done
# Sanity check and evaluate the original (unquantized) model
python3 - << EOF
from vllm import LLM
model = LLM("${_MODEL}")
model.generate("Hello my name is")
EOF
lm_eval \
  --model vllm \
  --model_args pretrained=${_MODEL},add_bos_token=True,max_model_len=4096 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
# Sanity check and evaluate the quantized model (this is where the KeyError occurs)
python3 - << EOF
from vllm import LLM
model = LLM("${_QUANTIZED_MODEL}")
model.generate("Hello my name is")
EOF
lm_eval \
  --model vllm \
  --model_args pretrained=${_QUANTIZED_MODEL},add_bos_token=True,max_model_len=4096 \
  --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
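gemma3.py itself is not included in the report; the following is a minimal sketch of what it presumably does, i.e. the Gemma2 FP8-Dynamic example with the model id swapped. The use of Gemma3ForConditionalGeneration, AutoProcessor, and the save path are my assumptions, not the reporter's exact script.

# Hypothetical reconstruction of gemma3.py (Gemma2 FP8 example adapted to Gemma3).
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "google/gemma-3-27b-it"
SAVE_DIR = "/root/gemma-3-27b-it-FP8-Dynamic"

model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# FP8 dynamic quantization of all Linear layers, as in the Gemma2 example.
# Note: this also quantizes the SigLIP vision encoder, which is what later
# trips up vLLM's weight loader.
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

oneshot(model=model, recipe=recipe)

model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)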
Errors
If applicable, add a full print-out of any errors or exceptions that are raised or include screenshots to help explain your problem.
INFO 04-01 18:19:01 [__init__.py:239] Automatically detected platform cuda.
INFO 04-01 18:19:07 [config.py:585] This model supports multiple tasks: {'generate', 'embed', 'classify', 'reward', 'score'}. Defaulting to 'generate'.
INFO 04-01 18:19:08 [config.py:1697] Chunked prefill is enabled with max_num_batched_tokens=16384.
INFO 04-01 18:19:09 [core.py:54] Initializing a V1 LLM engine (v0.8.2) with config: model='/root/gemma-3-27b-it-FP8-Dynamic', speculative_config=None, tokenizer='/root/gemma-3-27b-it-FP8-Dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=/root/gemma-3-27b-it-FP8-Dynamic, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 04-01 18:19:10 [utils.py:2321] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x75eeab66e3f0>
INFO 04-01 18:19:11 [parallel_state.py:954] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-01 18:19:11 [cuda.py:220] Using Flash Attention backend on V1 engine.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO 04-01 18:19:16 [gpu_model_runner.py:1174] Starting to load model /root/gemma-3-27b-it-FP8-Dynamic...
INFO 04-01 18:19:16 [config.py:3243] cudagraph sizes specified by model runner [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312, 320, 328, 336, 344, 352, 360, 368, 376, 384, 392, 400, 408, 416, 424, 432, 440, 448, 456, 464, 472, 480, 488, 496, 504, 512] is overridden by config [512, 384, 256, 128, 4, 2, 1, 392, 264, 136, 8, 400, 272, 144, 16, 408, 280, 152, 24, 416, 288, 160, 32, 424, 296, 168, 40, 432, 304, 176, 48, 440, 312, 184, 56, 448, 320, 192, 64, 456, 328, 200, 72, 464, 336, 208, 80, 472, 344, 216, 88, 120, 480, 352, 248, 224, 96, 488, 504, 360, 232, 104, 496, 368, 240, 112, 376]
WARNING 04-01 18:19:16 [topk_topp_sampler.py:63] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
Loading safetensors checkpoint shards: 0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 29% Completed | 2/7 [00:00<00:01, 2.65it/s]
ERROR 04-01 18:19:18 [core.py:343] EngineCore hit an exception: Traceback (most recent call last):
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 335, in run_engine_core
ERROR 04-01 18:19:18 [core.py:343] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 290, in __init__
ERROR 04-01 18:19:18 [core.py:343] super().__init__(vllm_config, executor_class, log_stats)
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 60, in __init__
ERROR 04-01 18:19:18 [core.py:343] self.model_executor = executor_class(vllm_config)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 52, in __init__
ERROR 04-01 18:19:18 [core.py:343] self._init_executor()
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 47, in _init_executor
ERROR 04-01 18:19:18 [core.py:343] self.collective_rpc("load_model")
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 04-01 18:19:18 [core.py:343] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/utils.py", line 2255, in run_method
ERROR 04-01 18:19:18 [core.py:343] return func(*args, **kwargs)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 136, in load_model
ERROR 04-01 18:19:18 [core.py:343] self.model_runner.load_model()
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1177, in load_model
ERROR 04-01 18:19:18 [core.py:343] self.model = get_model(vllm_config=self.vllm_config)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/model_loader/__init__.py", line 14, in get_model
ERROR 04-01 18:19:18 [core.py:343] return loader.load_model(vllm_config=vllm_config)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/model_loader/loader.py", line 444, in load_model
ERROR 04-01 18:19:18 [core.py:343] loaded_weights = model.load_weights(
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/gemma3_mm.py", line 793, in load_weights
ERROR 04-01 18:19:18 [core.py:343] return loader.load_weights(weights)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/utils.py", line 235, in load_weights
ERROR 04-01 18:19:18 [core.py:343] autoloaded_weights = set(self._load_module("", self.module, weights))
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/utils.py", line 196, in _load_module
ERROR 04-01 18:19:18 [core.py:343] yield from self._load_module(prefix,
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/utils.py", line 173, in _load_module
ERROR 04-01 18:19:18 [core.py:343] loaded_params = module_load_weights(weights)
ERROR 04-01 18:19:18 [core.py:343] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-01 18:19:18 [core.py:343] File "/root/llm/lib/python3.12/site-packages/vllm/model_executor/models/siglip.py", line 518, in load_weights
ERROR 04-01 18:19:18 [core.py:343] param = params_dict[name]
ERROR 04-01 18:19:18 [core.py:343] ~~~~~~~~~~~^^^^^^
ERROR 04-01 18:19:18 [core.py:343] KeyError: 'vision_model.encoder.layers.0.mlp.fc1.weight_scale'
ERROR 04-01 18:19:18 [core.py:343]
Loading safetensors checkpoint shards: 29% Completed | 2/7 [00:01<00:04, 1.20it/s]
CRITICAL 04-01 18:19:18 [core_client.py:269] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
Killed
Additional context
Add any other context about the problem here. Also include any relevant files.
Hi @m4r1k, this is because the Gemma 3 models are multi-modal, whereas Gemma 2 is a purely causal LM. Multi-modal models require a bit more configuration and setup than language-only ones. We typically exclude the vision encoder from quantization (e.g., see the Qwen 2.5 VL example). If you include …
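For reference, a sketch of the fix the reply suggests: keep FP8_DYNAMIC on the language model's Linear layers but exclude the vision stack via ignore patterns. The module names below (vision_tower, multi_modal_projector) are assumptions based on the Gemma3 layout in transformers, not something stated in the thread.

# Sketch only: exclude the vision encoder and projector from quantization.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=[
        "re:.*lm_head",
        "re:vision_tower.*",          # assumed SigLIP vision encoder module name
        "re:multi_modal_projector.*",  # assumed projector module name
    ],
)
oneshot(model=model, recipe=recipe)

With the vision tower left in bf16, no weight_scale parameters are produced for vision_model.* weights, so vLLM's SigLIP weight loader should no longer hit the KeyError.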