Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
- 5. Please use English, otherwise it will be closed.
Describe the bug
Precomputed features are supported in VLMs like Qwen 2.5 VL: they let image features that were computed ahead of time be passed to SGLang without going through the vision encoder again. I tried to adapt the official Qwen usage example to Llama 4, but it appears the Llama 4 processor path is outdated and does not support passing precomputed_features.
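For reference, the pattern that works with Qwen 2.5 VL looks roughly like this (a sketch from memory of the docs example; the .visual attribute and the grid_thw argument are Qwen-specific and may have moved in newer transformers versions; image, prompt, and llm are as in the reproduction below):

from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

qwen_path = "Qwen/Qwen2.5-VL-3B-Instruct"
processor = AutoProcessor.from_pretrained(qwen_path)
# Run the vision tower outside SGLang to produce the features.
vision = Qwen2_5_VLForConditionalGeneration.from_pretrained(qwen_path).visual.cuda()
processed = processor(images=[image], text=prompt, return_tensors="pt")
features = vision(
    processed["pixel_values"].cuda(),
    grid_thw=processed["image_grid_thw"].cuda(),
)
out = llm.generate(
    input_ids=processed["input_ids"][0].tolist(),
    image_data=[dict(modality="IMAGE", precomputed_features=features)],
)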
Reproduction
Following the docs, we first set the model path and chat template:
import nest_asyncio
nest_asyncio.apply() # Run this first.
model_path = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
chat_template = "llama-4"
We then build the prompt and download the image:
# Let's create a prompt.
from io import BytesIO
import requests
from PIL import Image
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
from sglang.srt.conversation import chat_templates

image = Image.open(
    BytesIO(
        requests.get(
            "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        ).content
    )
)
conv = chat_templates[chat_template].copy()
conv.append_message(conv.roles[0], f"What's shown here: {conv.image_token}?")
conv.append_message(conv.roles[1], "")
conv.image_data = [image]
print(conv.get_prompt())
Next, we initialize the engine and load the model weights (full argument dump from our run):
from sglang import Engine
llm = Engine(
    model_path='meta-llama/Llama-4-Scout-17B-16E-Instruct',
    tokenizer_path='meta-llama/Llama-4-Scout-17B-16E-Instruct',
    tokenizer_mode='auto',
    skip_tokenizer_init=False,
    load_format='auto',
    model_loader_extra_config='{}',
    trust_remote_code=True,
    dtype='auto',
    kv_cache_dtype='auto',
    quantization='fp8',
    quantization_param_path=None,
    context_length=20000,
    device='cuda',
    served_model_name='meta-llama/Llama-4-Scout-17B-16E-Instruct',
    chat_template='llama_4_vision',
    completion_template=None,
    is_embedding=False,
    enable_multimodal=True,
    revision=None,
    hybrid_kvcache_ratio=None,
    impl='auto',
    host='0.0.0.0',
    port=8000,
    mem_fraction_static=0.91,
    max_running_requests=None,
    max_total_tokens=None,
    chunked_prefill_size=32768,
    max_prefill_tokens=32768,
    schedule_policy='fcfs',
    schedule_conservativeness=1.0,
    cpu_offload_gb=0,
    page_size=1,
    tp_size=8,
    pp_size=1,
    max_micro_batch_size=None,
    stream_interval=1,
    stream_output=False,
    random_seed=941523765,
    constrained_json_whitespace_pattern=None,
    watchdog_timeout=300,
    dist_timeout=None,
    download_dir=None,
    base_gpu_id=0,
    gpu_id_step=1,
    sleep_on_idle=False,
    log_level='info',
    log_level_http=None,
    log_requests=False,
    log_requests_level=0,
    crash_dump_folder=None,
    show_time_cost=False,
    enable_metrics=False,
    bucket_time_to_first_token=None,
    bucket_e2e_request_latency=None,
    bucket_inter_token_latency=None,
    collect_tokens_histogram=False,
    decode_log_interval=40,
    enable_request_time_stats_logging=False,
    kv_events_config=None,
    api_key=None,
    file_storage_path='sglang_storage',
    enable_cache_report=False,
    reasoning_parser=None,
    tool_call_parser=None,
    dp_size=1,
    load_balance_method='round_robin',
    dist_init_addr=None,
    nnodes=1,
    node_rank=0,
    json_model_override_args='{}',
    preferred_sampling_params=None,
    lora_paths=None,
    max_loras_per_batch=8,
    lora_backend='triton',
    attention_backend=None,
    sampling_backend='flashinfer',
    grammar_backend='xgrammar',
    mm_attention_backend=None,
    speculative_algorithm=None,
    speculative_draft_model_path=None,
    speculative_num_steps=None,
    speculative_eagle_topk=None,
    speculative_num_draft_tokens=None,
    speculative_accept_threshold_single=1.0,
    speculative_accept_threshold_acc=1.0,
    speculative_token_map=None,
    ep_size=1,
    enable_ep_moe=False,
    enable_deepep_moe=False,
    enable_flashinfer_moe=False,
    deepep_mode='auto',
    ep_num_redundant_experts=0,
    ep_dispatch_algorithm='static',
    init_expert_location='trivial',
    enable_eplb=False,
    eplb_algorithm='auto',
    eplb_rebalance_num_iterations=1000,
    eplb_rebalance_layers_per_chunk=None,
    expert_distribution_recorder_mode=None,
    expert_distribution_recorder_buffer_size=1000,
    enable_expert_distribution_metrics=False,
    deepep_config=None,
    moe_dense_tp_size=None,
    enable_double_sparsity=False,
    ds_channel_config_path=None,
    ds_heavy_channel_num=32,
    ds_heavy_token_num=256,
    ds_heavy_channel_type='qk',
    ds_sparse_decode_threshold=4096,
    disable_radix_cache=True,
    cuda_graph_max_bs=2000,
    cuda_graph_bs=[2000],
    disable_cuda_graph=False,
    disable_cuda_graph_padding=False,
    enable_profile_cuda_graph=False,
    enable_nccl_nvls=False,
    enable_tokenizer_batch_encode=False,
    disable_outlines_disk_cache=False,
    disable_custom_all_reduce=False,
    enable_mscclpp=False,
    disable_overlap_schedule=False,
    disable_overlap_cg_plan=False,
    enable_mixed_chunk=False,
    enable_dp_attention=False,
    enable_dp_lm_head=False,
    enable_two_batch_overlap=False,
    enable_torch_compile=False,
    torch_compile_max_bs=32,
    torchao_config='',
    enable_nan_detection=False,
    enable_p2p_check=False,
    triton_attention_reduce_in_fp32=False,
    triton_attention_num_kv_splits=8,
    num_continuous_decode_steps=1,
    delete_ckpt_after_loading=False,
    enable_memory_saver=False,
    allow_auto_truncate=False,
    enable_custom_logit_processor=False,
    enable_hierarchical_cache=False,
    hicache_ratio=2.0,
    hicache_size=0,
    hicache_write_policy='write_through_selective',
    flashinfer_mla_disable_ragged=False,
    disable_shared_experts_fusion=False,
    disable_chunked_prefix_cache=False,
    disable_fast_image_processor=False,
    enable_return_hidden_states=False,
    warmups=None,
    debug_tensor_dump_output_folder=None,
    debug_tensor_dump_input_file=None,
    debug_tensor_dump_inject=False,
    debug_tensor_dump_prefill_only=False,
    disaggregation_mode='null',
    disaggregation_transfer_backend='mooncake',
    disaggregation_bootstrap_port=8998,
    disaggregation_decode_tp=None,
    disaggregation_decode_dp=None,
    disaggregation_prefill_pp=1,
    disaggregation_ib_device=None,
    num_reserved_decode_tokens=512,
    pdlb_url=None,
    custom_weight_loader=[],
    weight_loader_disable_mmap=False
)
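For readability, the settings that differ from the defaults in the dump above boil down to roughly this (a minimal sketch; which arguments are truly non-default is my reading of the dump, and everything omitted keeps its Engine default):

llm = Engine(
    model_path="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    trust_remote_code=True,
    quantization="fp8",
    context_length=20000,
    chat_template="llama_4_vision",
    enable_multimodal=True,
    tp_size=8,
    mem_fraction_static=0.91,
    disable_radix_cache=True,
    cuda_graph_max_bs=2000,
    cuda_graph_bs=[2000],
)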
Lastly, we precompute the image features with the HF processor and pass them to SGL:
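(Here processor and vision were created earlier along the following lines; Llama4ForConditionalGeneration and its vision_model attribute are my assumed Llama 4 analogue of the Qwen example, not taken from the original snippet:)

import torch
from transformers import AutoProcessor, Llama4ForConditionalGeneration

# Assumed setup for the snippet below (hypothetical, see note above):
# the HF processor plus the standalone vision tower.
processor = AutoProcessor.from_pretrained(model_path)
vision = Llama4ForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16
).vision_model.cuda()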
processed_prompt = processor(
    images=[image], text=conv.get_prompt(), return_tensors="pt"
)
# print(processed_prompt)
input_ids = processed_prompt["input_ids"][0].detach().cpu().tolist()
precomputed_features = vision(
    processed_prompt["pixel_values"].cuda(),
)
mm_item = dict(
    modality="IMAGE",
    precomputed_features=precomputed_features,
)
out = llm.generate(input_ids=input_ids, image_data=[mm_item])
print(out["text"])
The above outputs the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[9], line 14
6 precomputed_features = vision(
7 processed_prompt["pixel_values"].cuda(),
8 )
10 mm_item = dict(
11 modality="IMAGE",
12 precomputed_features=precomputed_features,
13 )
---> 14 out = llm.generate(input_ids=input_ids, image_data=[mm_item])
15 print(out["text"])
File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/entrypoints/engine.py:218, in Engine.generate(self, prompt, sampling_params, input_ids, image_data, return_logprob, logprob_start_len, top_logprobs_num, token_ids_logprob, lora_path, custom_logit_processor, return_hidden_states, stream, bootstrap_host, bootstrap_port, bootstrap_room, data_parallel_rank)
216 return generator_wrapper()
217 else:
---> 218 ret = loop.run_until_complete(generator.__anext__())
219 return ret
File /usr/local/lib/python3.10/dist-packages/nest_asyncio.py:98, in _patch_loop.<locals>.run_until_complete(self, future)
95 if not f.done():
96 raise RuntimeError(
97 'Event loop stopped before Future completed.')
---> 98 return f.result()
File /usr/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
199 self.__log_traceback = False
200 if self._exception is not None:
---> 201 raise self._exception.with_traceback(self._exception_tb)
202 return self._result
File /usr/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
228 try:
229 if exc is None:
230 # We use the `send` method directly, because coroutines
231 # don't have `__iter__` and `__next__` methods.
---> 232 result = coro.send(None)
233 else:
234 result = coro.throw(exc)
File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tokenizer_manager.py:460, in TokenizerManager.generate_request(self, obj, request)
458 async with self.model_update_lock.reader_lock:
459 if obj.is_single:
---> 460 tokenized_obj = await self._tokenize_one_request(obj)
461 state = self._send_one_request(obj, tokenized_obj, created_time)
462 async for response in self._wait_one_response(obj, state, request):
File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/managers/tokenizer_manager.py:514, in TokenizerManager._tokenize_one_request(self, obj)
512 if not isinstance(obj.audio_data, list):
513 obj.audio_data = [obj.audio_data]
---> 514 mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
515 image_data=obj.image_data,
516 audio_data=obj.audio_data,
517 input_text=input_text or input_ids,
518 request_obj=obj,
519 max_req_input_len=self.max_req_input_len,
520 )
521 if mm_inputs and "input_ids" in mm_inputs:
522 input_ids = mm_inputs["input_ids"]
File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/multimodal/processors/mllama4.py:57, in Mllama4ImageProcessor.process_mm_data_async(self, image_data, input_text, max_req_input_len, *args, **kwargs)
54 processor = self._processor
56 # Process the prompt and images
---> 57 processor_output = self.process_mm_data(
58 input_text=processed_data.input_text,
59 images=processed_data.images,
60 )
62 # Handle image resolutions and aspect ratios
63 if "pixel_values" not in processor_output: # no image processed
File /home/lik/benchmarking_toolkit/my_sglang/python/sglang/srt/multimodal/processors/base_processor.py:180, in BaseMultimodalProcessor.process_mm_data(self, input_text, images, videos, audios, **kwargs)
176 if hasattr(processor, "image_processor") and isinstance(
177 processor.image_processor, BaseImageProcessorFast
178 ):
179 kwargs["device"] = "cuda"
---> 180 result = processor.__call__(
181 text=[input_text],
182 padding=True,
183 return_tensors="pt",
184 **kwargs,
185 )
186 if "pixel_values" in result and isinstance(
187 result["pixel_values"], torch.Tensor
188 ):
189 result["pixel_values"] = result["pixel_values"].to("cpu")
File /usr/local/lib/python3.10/dist-packages/transformers/models/llama4/processing_llama4.py:191, in Llama4Processor.__call__(self, images, text, audio, videos, **kwargs)
189 image_inputs = {}
190 if images is not None:
---> 191 images = make_flat_list_of_images(images)
192 image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
193 image_height, image_width = image_inputs["pixel_values"][0].shape[-2:]
File /usr/local/lib/python3.10/dist-packages/transformers/image_utils.py:246, in make_flat_list_of_images(images)
243 if images.ndim == 4:
244 return list(images)
---> 246 raise ValueError(f"Could not make a flat list of images from {images}")
ValueError: Could not make a flat list of images from [{'modality': 'IMAGE', 'precomputed_features': BaseModelOutput(last_hidden_state=tensor([[[-1.0010e-01, -1.5332e-01, -6.6895e-02, ..., -1.0681e-02,
-1.0449e-01, -3.2227e-02],
It looks like the precomputed_features item is passed directly to the HuggingFace processor, even though the processor should be skipped when the features are already computed.
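For illustration, the kind of routing one would expect somewhere before the HF call (a hypothetical sketch, not actual SGLang code; the dict shape follows mm_item above):

from typing import Any

def split_precomputed(image_data: list[Any]) -> tuple[list[dict], list[Any]]:
    # Hypothetical helper: items that already carry precomputed_features
    # should bypass the HF image processor; raw images still need it.
    precomputed, raw = [], []
    for item in image_data:
        if isinstance(item, dict) and item.get("precomputed_features") is not None:
            precomputed.append(item)
        else:
            raw.append(item)
    return precomputed, raw

The Qwen 2.5 VL path appears to perform this kind of separation, while Mllama4ImageProcessor.process_mm_data_async passes the dicts straight to processor.__call__, which then fails in make_flat_list_of_images.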
Environment
Python: 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H200
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.8, V12.8.93
CUDA Driver Version: 570.133.20
PyTorch: 2.7.1+cu128
sglang: 0.4.9.post2
sgl_kernel: 0.2.5
flashinfer_python: 0.2.7.post1
triton: 3.3.1
transformers: 4.53.0
torchao: 0.9.0
numpy: 2.2.6
aiohttp: 3.12.14
fastapi: 0.116.1
hf_transfer: 0.1.9
huggingface_hub: 0.33.4
interegular: 0.3.3
modelscope: 1.28.0
orjson: 3.10.18
outlines: 0.1.11
packaging: 25.0
psutil: 7.0.0
pydantic: 2.11.7
python-multipart: 0.0.20
pyzmq: 27.0.0
uvicorn: 0.35.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.21
openai: Module Not Found
tiktoken: Module Not Found
anthropic: Module Not Found
litellm: Module Not Found
decord: 0.6.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 NIC4 NIC5 NIC6 NIC7 NIC8 NIC9 NIC10 NIC11 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18 PXB NODE NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18 NODE NODE NODE PXB NODE NODE SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18 NODE NODE NODE NODE PXB NODE SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18 NODE NODE NODE NODE NODE PXB SYS SYS SYS SYS SYS SYS 0-55,112-167 0 N/A
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18 SYS SYS SYS SYS SYS SYS PXB NODE NODE NODE NODE NODE 56-111,168-223 1 N/A
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18 SYS SYS SYS SYS SYS SYS NODE NODE NODE PXB NODE NODE 56-111,168-223 1 N/A
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18 SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE PXB NODE 56-111,168-223 1 N/A
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE PXB 56-111,168-223 1 N/A
NIC0 PXB NODE NODE NODE SYS SYS SYS SYS X NODE NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS
NIC1 NODE NODE NODE NODE SYS SYS SYS SYS NODE X PIX NODE NODE NODE SYS SYS SYS SYS SYS SYS
NIC2 NODE NODE NODE NODE SYS SYS SYS SYS NODE PIX X NODE NODE NODE SYS SYS SYS SYS SYS SYS
NIC3 NODE PXB NODE NODE SYS SYS SYS SYS NODE NODE NODE X NODE NODE SYS SYS SYS SYS SYS SYS
NIC4 NODE NODE PXB NODE SYS SYS SYS SYS NODE NODE NODE NODE X NODE SYS SYS SYS SYS SYS SYS
NIC5 NODE NODE NODE PXB SYS SYS SYS SYS NODE NODE NODE NODE NODE X SYS SYS SYS SYS SYS SYS
NIC6 SYS SYS SYS SYS PXB NODE NODE NODE SYS SYS SYS SYS SYS SYS X NODE NODE NODE NODE NODE
NIC7 SYS SYS SYS SYS NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS NODE X PIX NODE NODE NODE
NIC8 SYS SYS SYS SYS NODE NODE NODE NODE SYS SYS SYS SYS SYS SYS NODE PIX X NODE NODE NODE
NIC9 SYS SYS SYS SYS NODE PXB NODE NODE SYS SYS SYS SYS SYS SYS NODE NODE NODE X NODE NODE
NIC10 SYS SYS SYS SYS NODE NODE PXB NODE SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE X NODE
NIC11 SYS SYS SYS SYS NODE NODE NODE PXB SYS SYS SYS SYS SYS SYS NODE NODE NODE NODE NODE X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
NIC5: mlx5_5
NIC6: mlx5_6
NIC7: mlx5_7
NIC8: mlx5_8
NIC9: mlx5_9
NIC10: mlx5_10
NIC11: mlx5_11
ulimit soft: 1048576