
[Bug] precomputed_feature does not work with Llama 4 vision #8065

Open
@AlienKevin

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

Precomputed features are supported in VLMs like Qwen 2.5 VL: they allow precomputed image features to be passed to SGL directly, without going through the vision encoder. I tried to adapt the official Qwen usage example to Llama 4, but found that the Llama 4 processor might be outdated and does not support inputting precomputed_features.
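For reference, the Qwen-style flow I adapted from looks roughly like this (paraphrased from memory of the docs example; the model path, the .visual call, and the exact mm_item fields are my recollection and may differ slightly):

from io import BytesIO

import requests
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from sglang import Engine

qwen_path = "Qwen/Qwen2.5-VL-7B-Instruct"
qwen_processor = AutoProcessor.from_pretrained(qwen_path)
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    qwen_path, torch_dtype="auto", device_map="cuda"
)
qwen_llm = Engine(model_path=qwen_path, enable_multimodal=True)

image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)
messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": "What's shown here?"}],
    }
]
text = qwen_processor.apply_chat_template(messages, add_generation_prompt=True)

processed = qwen_processor(images=[image], text=text, return_tensors="pt")
# Run only the vision tower; Qwen's tower takes the flattened patches plus the grid sizes.
features = qwen_model.visual(
    processed["pixel_values"].to("cuda", qwen_model.dtype),
    grid_thw=processed["image_grid_thw"].to("cuda"),
)

# Hand the features to SGL instead of the raw image.
mm_item = dict(modality="IMAGE", precomputed_features=features)
out = qwen_llm.generate(
    input_ids=processed["input_ids"][0].tolist(), image_data=[mm_item]
)
print(out["text"])

This Qwen path works; the Llama 4 adaptation below does not.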

Reproduction

Following the docs, we first set the model path and chat template

import nest_asyncio

nest_asyncio.apply()  # Run this first.

model_path = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
chat_template = "llama-4"

We then load the prompt image and build the conversation

# Let's create a prompt.

from io import BytesIO
import requests
from PIL import Image

from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.conversation import chat_templates

image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)

conv = chat_templates[chat_template].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")
conv.image_data = [image]

print(conv.get_prompt())

Next, we launch the SGLang Engine and load the model weights

from sglang import Engine

llm = Engine(
    model_path='meta-llama/Llama-4-Scout-17B-16E-Instruct',
    tokenizer_path='meta-llama/Llama-4-Scout-17B-16E-Instruct',
    tokenizer_mode='auto',
    skip_tokenizer_init=False,
    load_format='auto',
    model_loader_extra_config='{}',
    trust_remote_code=True,
    dtype='auto',
    kv_cache_dtype='auto',
    quantization='fp8',
    quantization_param_path=None,
    context_length=20000,
    device='cuda',
    served_model_name='meta-llama/Llama-4-Scout-17B-16E-Instruct',
    chat_template='llama_4_vision',
    completion_template=None,
    is_embedding=False,
    enable_multimodal=True,
    revision=None,
    hybrid_kvcache_ratio=None,
    impl='auto',
    host='0.0.0.0',
    port=8000,
    mem_fraction_static=0.91,
    max_running_requests=None,
    max_total_tokens=None,
    chunked_prefill_size=32768,
    max_prefill_tokens=32768,
    schedule_policy='fcfs',
    schedule_conservativeness=1.0,
    cpu_offload_gb=0,
    page_size=1,
    tp_size=8,
    pp_size=1,
    max_micro_batch_size=None,
    stream_interval=1,
    stream_output=False,
    random_seed=941523765,
    constrained_json_whitespace_pattern=None,
    watchdog_timeout=300,
    dist_timeout=None,
    download_dir=None,
    base_gpu_id=0,
    gpu_id_step=1,
    sleep_on_idle=False,
    log_level='info',
    log_level_http=None,
    log_requests=False,
    log_requests_level=0,
    crash_dump_folder=None,
    show_time_cost=False,
    enable_metrics=False,
    bucket_time_to_first_token=None,
    bucket_e2e_request_latency=None,
    bucket_inter_token_latency=None,
    collect_tokens_histogram=False,
    decode_log_interval=40,
    enable_request_time_stats_logging=False,
    kv_events_config=None,
    api_key=None,
    file_storage_path='sglang_storage',
    enable_cache_report=False,
    reasoning_parser=None,
    tool_call_parser=None,
    dp_size=1,
    load_balance_method='round_robin',
    dist_init_addr=None,
    nnodes=1,
    node_rank=0,
    json_model_override_args='{}',
    preferred_sampling_params=None,
    lora_paths=None,
    max_loras_per_batch=8,
    lora_backend='triton',
    attention_backend=None,
    sampling_backend='flashinfer',
    grammar_backend='xgrammar',
    mm_attention_backend=None,
    speculative_algorithm=None,
    speculative_draft_model_path=None,
    speculative_num_steps=None,
    speculative_eagle_topk=None,
    speculative_num_draft_tokens=None,
    speculative_accept_threshold_single=1.0,
    speculative_accept_threshold_acc=1.0,
    speculative_token_map=None,
    ep_size=1,
    enable_ep_moe=False,
    enable_deepep_moe=False,
    enable_flashinfer_moe=False,
    deepep_mode='auto',
    ep_num_redundant_experts=0,
    ep_dispatch_algorithm='static',
    init_expert_location='trivial',
    enable_eplb=False,
    eplb_algorithm='auto',
    eplb_rebalance_num_iterations=1000,
    eplb_rebalance_layers_per_chunk=None,
    expert_distribution_recorder_mode=None,
    expert_distribution_recorder_buffer_size=1000,
    enable_expert_distribution_metrics=False,
    deepep_config=None,
    moe_dense_tp_size=None,
    enable_double_sparsity=False,
    ds_channel_config_path=None,
    ds_heavy_channel_num=32,
    ds_heavy_token_num=256,
    ds_heavy_channel_type='qk',
    ds_sparse_decode_threshold=4096,
    disable_radix_cache=True,
    cuda_graph_max_bs=2000,
    cuda_graph_bs=[2000],
    disable_cuda_graph=False,
    disable_cuda_graph_padding=False,
    enable_profile_cuda_graph=False,
    enable_nccl_nvls=False,
    enable_tokenizer_batch_encode=False,
    disable_outlines_disk_cache=False,
    disable_custom_all_reduce=False,
    enable_mscclpp=False,
    disable_overlap_schedule=False,
    disable_overlap_cg_plan=False,
    enable_mixed_chunk=False,
    enable_dp_attention=False,
    enable_dp_lm_head=False,
    enable_two_batch_overlap=False,
    enable_torch_compile=False,
    torch_compile_max_bs=32,
    torchao_config='',
    enable_nan_detection=False,
    enable_p2p_check=False,
    triton_attention_reduce_in_fp32=False,
    triton_attention_num_kv_splits=8,
    num_continuous_decode_steps=1,
    delete_ckpt_after_loading=False,
    enable_memory_saver=False,
    allow_auto_truncate=False,
    enable_custom_logit_processor=False,
    enable_hierarchical_cache=False,
    hicache_ratio=2.0,
    hicache_size=0,
    hicache_write_policy='write_through_selective',
    flashinfer_mla_disable_ragged=False,
    disable_shared_experts_fusion=False,
    disable_chunked_prefix_cache=False,
    disable_fast_image_processor=False,
    enable_return_hidden_states=False,
    warmups=None,
    debug_tensor_dump_output_folder=None,
    debug_tensor_dump_input_file=None,
    debug_tensor_dump_inject=False,
    debug_tensor_dump_prefill_only=False,
    disaggregation_mode='null',
    disaggregation_transfer_backend='mooncake',
    disaggregation_bootstrap_port=8998,
    disaggregation_decode_tp=None,
    disaggregation_decode_dp=None,
    disaggregation_prefill_pp=1,
    disaggregation_ib_device=None,
    num_reserved_decode_tokens=512,
    pdlb_url=None,
    custom_weight_loader=[],
    weight_loader_disable_mmap=False
)

Lastly, we precompute the image features with the HF processor and vision tower and pass them to SGL.
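Here `processor` and `vision` come from the HF Llama 4 implementation, roughly as follows (the `vision_model` attribute name is based on my reading of the HF modeling code and may differ):

import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration

# Load the HF processor and the full HF model, keeping only its vision tower.
processor = AutoProcessor.from_pretrained(model_path)
hf_model = Llama4ForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
)
vision = hf_model.vision_model.cuda().eval()  # forward() returns a BaseModelOutput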

processed_prompt = processor(
    images=[image], text=conv.get_prompt(), return_tensors="pt"
)
# print(processed_prompt)
input_ids = processed_prompt["input_ids"][0].detach().cpu().tolist()
precomputed_features = vision(
    processed_prompt["pixel_values"].cuda(),
)

mm_item = dict(
    modality="IMAGE",
    precomputed_features=precomputed_features,
)
out = llm.generate(input_ids=input_ids, image_data=[mm_item])
print(out["text"])

The above outputs the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[9], line 14
      6 precomputed_features = vision(
      7     processed_prompt["pixel_values"].cuda(),
      8 )
     10 mm_item = dict(
     11     modality="IMAGE",
     12     precomputed_features=precomputed_features,
     13 )
---> 14 out = llm.generate(input_ids=input_ids, image_data=[mm_item])
     15 print(out["text"])

File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/entrypoints/engine.py:218, in Engine.generate(self, prompt, sampling_params, input_ids, image_data, return_logprob, logprob_start_len, top_logprobs_num, token_ids_logprob, lora_path, custom_logit_processor, return_hidden_states, stream, bootstrap_host, bootstrap_port, bootstrap_room, data_parallel_rank)
    216     return generator_wrapper()
    217 else:
--> 218     ret = loop.run_until_complete(generator.__anext__())
    219     return ret

File /usr/local/lib/python3.10/dist-packages/nest_asyncio.py:98, in _patch_loop.<locals>.run_until_complete(self, future)
     95 if not f.done():
     96     raise RuntimeError(
     97         'Event loop stopped before Future completed.')
---> 98 return f.result()

File /usr/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
    199 self.__log_traceback = False
    200 if self._exception is not None:
--> 201     raise self._exception.with_traceback(self._exception_tb)
    202 return self._result

File /usr/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
    228 try:
    229     if exc is None:
    230         # We use the `send` method directly, because coroutines
    231         # don't have `__iter__` and `__next__` methods.
--> 232         result = coro.send(None)
    233     else:
    234         result = coro.throw(exc)

File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tokenizer_manager.py:460, in TokenizerManager.generate_request(self, obj, request)
    458 async with self.model_update_lock.reader_lock:
    459     if obj.is_single:
--> 460         tokenized_obj = await self._tokenize_one_request(obj)
    461         state = self._send_one_request(obj, tokenized_obj, created_time)
    462         async for response in self._wait_one_response(obj, state, request):

File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tokenizer_manager.py:514, in TokenizerManager._tokenize_one_request(self, obj)
    512 if not isinstance(obj.audio_data, list):
    513     obj.audio_data = [obj.audio_data]
--> 514 mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
    515     image_data=obj.image_data,
    516     audio_data=obj.audio_data,
    517     input_text=input_text or input_ids,
    518     request_obj=obj,
    519     max_req_input_len=self.max_req_input_len,
    520 )
    521 if mm_inputs and "input_ids" in mm_inputs:
    522     input_ids = mm_inputs["input_ids"]

File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/multimodal/processors/mllama4.py:57, in Mllama4ImageProcessor.process_mm_data_async(self, image_data, input_text, max_req_input_len, *args, **kwargs)
     54 processor = self._processor
     56 # Process the prompt and images
---> 57 processor_output = self.process_mm_data(
     58     input_text=processed_data.input_text,
     59     images=processed_data.images,
     60 )
     62 # Handle image resolutions and aspect ratios
     63 if "pixel_values" not in processor_output:  # no image processed

File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/multimodal/processors/base_processor.py:180, in BaseMultimodalProcessor.process_mm_data(self, input_text, images, videos, audios, **kwargs)
    176 if hasattr(processor, "image_processor") and isinstance(
    177     processor.image_processor, BaseImageProcessorFast
    178 ):
    179     kwargs["device"] = "cuda"
--> 180 result = processor.__call__(
    181     text=[input_text],
    182     padding=True,
    183     return_tensors="pt",
    184     **kwargs,
    185 )
    186 if "pixel_values" in result and isinstance(
    187     result["pixel_values"], torch.Tensor
    188 ):
    189     result["pixel_values"] = result["pixel_values"].to("cpu")

File /usr/local/lib/python3.10/dist-packages/transformers/models/llama4/processing_llama4.py:191, in Llama4Processor.__call__(self, images, text, audio, videos, **kwargs)
    189 image_inputs = {}
    190 if images is not None:
--> 191     images = make_flat_list_of_images(images)
    192     image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
    193     image_height, image_width = image_inputs["pixel_values"][0].shape[-2:]

File /usr/local/lib/python3.10/dist-packages/transformers/image_utils.py:246, in make_flat_list_of_images(images)
    243     if images.ndim == 4:
    244         return list(images)
--> 246 raise ValueError(f"Could not make a flat list of images from {images}")

ValueError: Could not make a flat list of images from [{'modality': 'IMAGE', 'precomputed_features': BaseModelOutput(last_hidden_state=tensor([[[-1.0010e-01, -1.5332e-01, -6.6895e-02,  ..., -1.0681e-02,
          -1.0449e-01, -3.2227e-02],

It looks like the precomputed_features dict is passed straight through to the HuggingFace image processor, even though the image processor should be skipped when the features are already provided.
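Since this works for Qwen 2.5 VL, its processor presumably filters out items that already carry features before calling the HF processor. Below is a rough sketch of the kind of guard I would expect in Mllama4ImageProcessor.process_mm_data_async; it is purely illustrative, and _split_precomputed is a hypothetical helper, not existing SGLang code:

def _split_precomputed(images):
    # Separate items that already carry precomputed features from raw images.
    precomputed, raw = [], []
    for item in images:
        if isinstance(item, dict) and item.get("precomputed_features") is not None:
            precomputed.append(item)
        else:
            raw.append(item)
    return precomputed, raw

precomputed_items, raw_images = _split_precomputed(processed_data.images)
processor_output = {}
if raw_images:
    # Only raw images should ever reach the HF Llama4Processor.
    processor_output = self.process_mm_data(
        input_text=processed_data.input_text,
        images=raw_images,
    )
# precomputed_items would then be forwarded as multimodal inputs without
# re-running the image processor, which is what the Qwen 2.5 VL path appears to do.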

Environment

Python: 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H200
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.8, V12.8.93
CUDA Driver Version: 570.133.20
PyTorch: 2.7.1+cu128
sglang: 0.4.9.post2
sgl_kernel: 0.2.5
flashinfer_python: 0.2.7.post1
triton: 3.3.1
transformers: 4.53.0
torchao: 0.9.0
numpy: 2.2.6
aiohttp: 3.12.14
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.33.4
interegular: 0.3.3
modelscope: 1.28.0
orjson: 3.10.18
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.11.7
python-multipart: 0.0.20
pyzmq: 27.0.0
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.21
openai: Module Not Found
tiktoken: Module Not Found
anthropic: Module Not Found
litellm: Module Not Found
decord: 0.6.0
NVIDIA Topology: 
	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	NIC0	NIC1	NIC2	NIC3	NIC4	NIC5	NIC6	NIC7	NIC8	NIC9	NIC10	NIC11	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NV18	PXB	NODE	NODE	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS	0-55,112-167	0	N/A
GPU1	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NODE	NODE	NODE	PXB	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS	0-55,112-167	0	N/A
GPU2	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NODE	NODE	NODE	NODE	PXB	NODE	SYS	SYS	SYS	SYS	SYS	SYS	0-55,112-167	0	N/A
GPU3	NV18	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NODE	NODE	NODE	NODE	NODE	PXB	SYS	SYS	SYS	SYS	SYS	SYS	0-55,112-167	0	N/A
GPU4	NV18	NV18	NV18	NV18	 X 	NV18	NV18	NV18	SYS	SYS	SYS	SYS	SYS	SYS	PXB	NODE	NODE	NODE	NODE	NODE	56-111,168-223	1	N/A
GPU5	NV18	NV18	NV18	NV18	NV18	 X 	NV18	NV18	SYS	SYS	SYS	SYS	SYS	SYS	NODE	NODE	NODE	PXB	NODE	NODE	56-111,168-223	1	N/A
GPU6	NV18	NV18	NV18	NV18	NV18	NV18	 X 	NV18	SYS	SYS	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	PXB	NODE	56-111,168-223	1	N/A
GPU7	NV18	NV18	NV18	NV18	NV18	NV18	NV18	 X 	SYS	SYS	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	NODE	PXB	56-111,168-223	1	N/A
NIC0	PXB	NODE	NODE	NODE	SYS	SYS	SYS	SYS	 X 	NODE	NODE	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS
NIC1	NODE	NODE	NODE	NODE	SYS	SYS	SYS	SYS	NODE	 X 	PIX	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS
NIC2	NODE	NODE	NODE	NODE	SYS	SYS	SYS	SYS	NODE	PIX	 X 	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS
NIC3	NODE	PXB	NODE	NODE	SYS	SYS	SYS	SYS	NODE	NODE	NODE	 X 	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS
NIC4	NODE	NODE	PXB	NODE	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	 X 	NODE	SYS	SYS	SYS	SYS	SYS	SYS
NIC5	NODE	NODE	NODE	PXB	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	NODE	 X 	SYS	SYS	SYS	SYS	SYS	SYS
NIC6	SYS	SYS	SYS	SYS	PXB	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS	 X 	NODE	NODE	NODE	NODE	NODE
NIC7	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS	NODE	 X 	PIX	NODE	NODE	NODE
NIC8	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS	NODE	PIX	 X 	NODE	NODE	NODE
NIC9	SYS	SYS	SYS	SYS	NODE	PXB	NODE	NODE	SYS	SYS	SYS	SYS	SYS	SYS	NODE	NODE	NODE	 X 	NODE	NODE
NIC10	SYS	SYS	SYS	SYS	NODE	NODE	PXB	NODE	SYS	SYS	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	 X 	NODE
NIC11	SYS	SYS	SYS	SYS	NODE	NODE	NODE	PXB	SYS	SYS	SYS	SYS	SYS	SYS	NODE	NODE	NODE	NODE	NODE	 X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8
  NIC9: mlx5_9
  NIC10: mlx5_10
  NIC11: mlx5_11


ulimit soft: 1048576
