diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index b9bed06d791..4412b8bb02c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,7 +3,6 @@
 """
 KV cache helper for store.
 """
-
 import torch
 
 import vllm.envs as envs
@@ -94,9 +93,11 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
 
 
 def get_kv_connector_cache_layout():
+    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
+    # used for faster transfer.
     vllm_config = get_current_vllm_config()
     kv_config = vllm_config.kv_transfer_config
-    if vllm_config.model_config is None:
+    if vllm_config.model_config is None or kv_config is None:
         logger.warning("Unable to detect current VLLM config. " \
             "Defaulting to NHD kv cache layout.")
     else:
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 91a7c43cd8d..61e7e334f29 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -74,8 +74,7 @@ def get_kv_cache_shape(
 
     @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
-        # NOTE When running disaggregated PD with NIXL, HND layout is used for
-        # faster transfer. `stride_order` indicates the permutation that gets
+        # `stride_order` indicates the permutation that gets
         # us from `get_kv_cache_shape` to the actual memory layout we want.
         cache_layout = get_kv_connector_cache_layout()
         if cache_layout == "NHD":
@@ -83,7 +82,7 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
         elif cache_layout == "HND":
             stride_order = (0, 1, 3, 2, 4)
         else:
-            raise ValueError("Unknown cache layout format %s.", cache_layout)
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index f1b61c152a9..102416481f4 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -3,6 +3,8 @@
 """Attention layer with FlashInfer."""
 from __future__ import annotations
 
+import functools
+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -16,6 +18,8 @@
                                               AttentionType)
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -28,10 +32,25 @@
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
+                                            "").upper()
 
 logger = init_logger(__name__)
 
 
+@functools.lru_cache
+def get_flashinfer_kv_cache_layout():
+    # Override with format specified by the user.
+    cache_layout = FLASHINFER_KV_CACHE_LAYOUT
+    if not cache_layout:
+        cache_layout = get_kv_connector_cache_layout()
+    else:
+        logger.info("`FLASHINFER_KV_CACHE_LAYOUT` environment variable " \
+            "detected. Setting KV cache layout to %s.", cache_layout)
+
+    return cache_layout
+
+
 class FlashInferBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
@@ -65,6 +84,19 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         return (num_blocks, 2, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets us from
+        # `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_flashinfer_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
 
 @dataclass
 class PerLayerParameters:
@@ -289,7 +321,7 @@ def _get_workspace_buffer(self):
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), get_flashinfer_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self):
@@ -302,14 +334,15 @@ def _get_decode_wrapper(self):
                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                get_flashinfer_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
             self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-                2, self._get_workspace_buffer(), "NHD")
+                2, self._get_workspace_buffer(),
+                get_flashinfer_kv_cache_layout())
         return self._cascade_wrapper
 
     def _plan(self, attn_metadata: FlashInferMetadata):
@@ -607,6 +640,7 @@ def forward(
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
@@ -621,7 +655,7 @@ def forward(
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -637,7 +671,7 @@ def forward(
             assert decode_wrapper._sm_scale == self.scale
             decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
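
Illustrative note (not part of the patch): `get_kv_cache_stride_order` describes the permutation from the logical shape returned by `get_kv_cache_shape`, (num_blocks, 2, block_size, num_kv_heads, head_size), to the physical memory layout, and forward() applies that same permutation via kv_cache.permute(*stride_order) before handing the cache to FlashInfer. A minimal standalone sketch of that relationship, using only PyTorch and arbitrary example sizes (not vLLM code):

import torch

# Logical shape as returned by FlashInferBackend.get_kv_cache_shape()
# (example sizes are arbitrary).
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
logical_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)

# HND stride order: swap the block_size and num_kv_heads axes.
# This particular permutation is its own inverse.
stride_order = (0, 1, 3, 2, 4)
hnd_shape = tuple(logical_shape[i] for i in stride_order)

# Allocate contiguously in HND order, then present the logical shape,
# mimicking a cache buffer allocated with this stride order.
kv_cache = torch.empty(hnd_shape).permute(*stride_order)
assert kv_cache.shape == logical_shape

# What forward() hands to FlashInfer: permute back to the HND view,
# which is contiguous in memory.
hnd_view = kv_cache.permute(*stride_order)
assert hnd_view.shape == hnd_shape
assert hnd_view.is_contiguous()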