-
Notifications
You must be signed in to change notification settings - Fork 621
Open
Labels
Description
Hi, I'm trying to use FlashInfer's MLA C++ API in my own project, but I've run into build errors.
I used `src/bench_batch_decode.cu` as a reference for calling the `flashinfer::BatchDecodeWithPagedKVCacheWrapperMLA` interface.
Here’s how I implemented it:
1. Include the header: `#include <flashinfer_ops.cuh>`
2. Call `flashinfer::BatchDecodeWithPagedKVCacheWrapperMLA`.
In my `CMakeLists.txt`:
include(FetchContent)
FetchContent_Declare(
flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG main
GIT_SUBMODULES ""
)
FetchContent_Populate(flashinfer)
However, compilation fails with this error:
In file included from /data/A/build/_deps/flashinfer-src/src/flashinfer_ops.cuh:27,
from /data/Akernels/flashinfer_kernels.cu:5:
/data/A/build/_deps/flashinfer-src/src/utils.h:33:10: fatal error: generated/dispatch.inc: No such file or directory
33 | #include "generated/dispatch.inc"
When I commented out `#include "generated/dispatch.inc"` in `flashinfer/src/utils.h`, compilation failed with a different error instead:
/data/A/build/_deps/flashinfer-src/src/flashinfer_ops.cuh(297): error: identifier "USE_FP16_QK_REDUCTION" is undefined
switch (use_fp16_qk_reduction) { _DISPATCH_CASES_use_fp16_qk_reduction(USE_FP16_QK_REDUCTION, {switch (head_dim) { _DISPATCH_CASES_head_dim(HEAD_DIM, {switch (pos_encoding_mode) { _DISPATCH_CASES_pos_encoding_mode(POS_ENCODING_MODE, { using Params = SinglePrefillParams<DTypeIn, DTypeIn, DTypeO>; using AttentionVariant = DefaultAttention< true, false, false, false>; Params params(q, k, v, custom_mask, o, lse, nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, -1, 0.f, sm_scale, rope_scale, rope_theta); return SinglePrefillWithKVCacheDispatched<HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MaskMode::kCustom, AttentionVariant>(params, tmp, stream); }) default: std::ostringstream oss; oss << __PRETTY_FUNCTION__ << " failed to dispatch " "positional encoding mode" " " << int(pos_encoding_mode); throw Error(__FUNCTION__, "/data/A/build/_deps/flashinfer-src/src/flashinfer_ops.cuh", 300, oss.str()); }}) default: std::ostringstream oss; oss << __PRETTY_FUNCTION__ << " failed to dispatch " "head_dim" " " << int(head_dim); throw Error(__FUNCTION__, "/data/A/build/_deps/flashinfer-src/src/flashinfer_ops.cuh", 299, oss.str()); }}) default: std::ostringstream oss; oss << __PRETTY_FUNCTION__ << " failed to dispatch " "use_fp16_qk_reduction" " " << int(use_fp16_qk_reduction); throw Error(__FUNCTION__, "/data/A/build/_deps/flashinfer-src/src/flashinfer_ops.cuh", 297, oss.str()); }
Question: What is the correct way to integrate FlashInfer's MLA decode interface into a C++ project? Are there specific build steps or configuration flags needed to resolve these errors?