[V1][Kernel] Flashinfer HND KV cache layout #19280

Open

wants to merge 6 commits into base: main
5 changes: 3 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,7 +3,6 @@
"""
KV cache helper for store.
"""

import torch

import vllm.envs as envs
@@ -94,9 +93,11 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,


def get_kv_connector_cache_layout():
    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
    if vllm_config.model_config is None:
    if vllm_config.model_config is None or kv_config is None:
        logger.warning("Unable to detect current VLLM config. " \
                       "Defaulting to NHD kv cache layout.")
    else:
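
For context on the NOTE in this diff: NHD and HND differ only in whether the
token (block_size) or head (num_kv_heads) dimension is outermost within a
cache page. A standalone sketch, not part of the PR (sizes are made up),
showing why HND makes per-head transfers touch contiguous memory:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# NHD: (num_blocks, 2, block_size, num_kv_heads, head_size)
nhd = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)
# HND: (num_blocks, 2, num_kv_heads, block_size, head_size)
hnd = torch.zeros(num_blocks, 2, num_kv_heads, block_size, head_size)

# Copying all tokens of one KV head out of one block (the granularity a
# KV-transfer connector may want) is strided in NHD but one contiguous
# (block_size, head_size) slab in HND.
print(nhd[0, 0, :, 3, :].is_contiguous())  # False
print(hnd[0, 0, 3, :, :].is_contiguous())  # True
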
5 changes: 2 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -74,16 +74,15 @@ def get_kv_cache_shape(

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # NOTE When running disaggregated PD with NIXL, HND layout is used for
        # faster transfer. `stride_order` indicates the permutation that gets
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_connector_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError("Unknown cache layout format %s.", cache_layout)
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order


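
To make the stride order concrete, here is a standalone sketch (illustrative
shapes, not code from this PR) of how a buffer allocated in HND order can
still be indexed through the logical shape returned by get_kv_cache_shape():

import torch

# Logical shape: (num_blocks, 2, block_size, num_kv_heads, head_size)
shape = (4, 2, 16, 8, 128)
stride_order = (0, 1, 3, 2, 4)  # HND

# Allocate in the permuted (physical) order, then permute the view back so
# indexing still follows the logical shape.
physical = torch.zeros([shape[i] for i in stride_order])
inverse = [stride_order.index(i) for i in range(len(stride_order))]
logical_view = physical.permute(*inverse)

print(logical_view.shape)     # torch.Size([4, 2, 16, 8, 128])
print(logical_view.stride())  # (32768, 16384, 128, 2048, 1): heads vary slower
                              # than tokens, i.e. HND memory layout
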
44 changes: 39 additions & 5 deletions vllm/v1/attention/backends/flashinfer.py
@@ -3,6 +3,8 @@
"""Attention layer with FlashInfer."""
from __future__ import annotations

import functools
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

@@ -16,6 +18,8 @@
AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -28,10 +32,25 @@
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
                                            "").upper()

Collaborator review comment: this should be defined like all other env
variables in vLLM, in envs.py, as VLLM_FLASHINFER_KV_CACHE_LAYOUT.

I also think we should have the env variable be VLLM_KV_CACHE_LAYOUT rather
than having a specific one for each attention backend type.

logger = init_logger(__name__)


@functools.lru_cache
def get_flashinfer_kv_cache_layout():
    # Override with format specified by the user.
    cache_layout = FLASHINFER_KV_CACHE_LAYOUT
    if not cache_layout:
        cache_layout = get_kv_connector_cache_layout()
    else:
        logger.info("`FLASHINFER_KV_CACHE_LAYOUT` environment variable " \
                    "detected. Setting KV cache layout to %s.", cache_layout)

    return cache_layout
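
A minimal sketch of the reviewer's suggestion above: register a single,
backend-agnostic override in vllm/envs.py instead of reading os.environ
inside the backend. The variable name VLLM_KV_CACHE_LAYOUT and the registry
shape below are assumptions for illustration, not part of this diff:

import os
from typing import Any, Callable

# Hypothetical envs.py-style entry; an empty string means "no override",
# i.e. fall back to the KV-connector-derived layout (NHD by default).
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_KV_CACHE_LAYOUT":
    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", "").upper(),
}


def kv_cache_layout_override() -> str:
    """Hypothetical accessor; returns "" when no override is set."""
    return environment_variables["VLLM_KV_CACHE_LAYOUT"]()
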


class FlashInferBackend(AttentionBackend):

    accept_output_buffer: bool = True
@@ -65,6 +84,19 @@ def get_kv_cache_shape(
    ) -> tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_flashinfer_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order


@dataclass
class PerLayerParameters:
@@ -289,7 +321,7 @@ def _get_workspace_buffer(self):
    def _get_prefill_wrapper(self):
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
                self._get_workspace_buffer(), get_flashinfer_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
@@ -302,14 +334,15 @@ def _get_decode_wrapper(self):
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                get_flashinfer_kv_cache_layout(),
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    def _get_cascade_wrapper(self):
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
                2, self._get_workspace_buffer(), "NHD")
                2, self._get_workspace_buffer(),
                get_flashinfer_kv_cache_layout())
        return self._cascade_wrapper

    def _plan(self, attn_metadata: FlashInferMetadata):
@@ -607,6 +640,7 @@ def forward(
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens

        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
@@ -621,7 +655,7 @@
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
                kv_cache,
                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
@@ -637,7 +671,7 @@
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
                kv_cache,
                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
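
To illustrate the kv_cache.permute(*stride_order) calls above with a
standalone sketch (made-up sizes, not vLLM code): with the HND stride order,
permuting the logically NHD-shaped cache recovers the contiguous
(num_blocks, 2, num_kv_heads, block_size, head_size) view that the FlashInfer
wrappers created with the "HND" layout string expect.

import torch

stride_order = (0, 1, 3, 2, 4)  # HND, as returned by get_kv_cache_stride_order()

# Memory is laid out HND, but the cache is indexed through the logical
# NHD-style shape (num_blocks, 2, block_size, num_kv_heads, head_size).
physical = torch.zeros(4, 2, 8, 16, 128)    # (blocks, 2, heads, tokens, head_size)
kv_cache = physical.permute(0, 1, 3, 2, 4)  # logical view seen by forward()

hnd_view = kv_cache.permute(*stride_order)  # what gets handed to the wrapper
print(hnd_view.shape)            # torch.Size([4, 2, 8, 16, 128])
print(hnd_view.is_contiguous())  # True

With the NHD order (0, 1, 2, 3, 4) the permute is the identity, so the call is
a no-op for the default layout.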