XAttention for XE1 platform #33307
base: master
Conversation
src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp (outdated review threads, resolved)
CacheHint::Cached,
CacheHint::Cached>(q_gather, gather_offsets, gather_pred);
rQ[ri].format<uint>() = gathered;
rQ[ri].format<half>() = cm_mul<half>(rQ[ri].format<half>(), (half)scale_factor);
Why not use gathered directly in cm_mul to avoid one register copy?
Directly multiplying gathered increases load-use stalls (sync.nop +23). Splitting the load and scale phases preserves the load/mul instruction counts and avoids the extra scoreboard waits, so I kept this form. I will update it.
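For readers following the discussion, a minimal sketch of the trade-off, assuming the register shapes from the diff above; the fused form is shown only as the rejected alternative:

```cpp
// Illustrative only. The fused form consumes `gathered` straight from the load,
// so the multiply waits on the gather (extra sync.nop scoreboard stalls); the
// split form kept in the PR copies first, letting the scale be scheduled away
// from the load.
//
// Fused (rejected):
//   rQ[ri].format<half>() = cm_mul<half>(gathered.format<half>(), (half)scale_factor);
//
// Split (as in the PR):
rQ[ri].format<uint>() = gathered;
rQ[ri].format<half>() = cm_mul<half>(rQ[ri].format<half>(), (half)scale_factor);
```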
#endif

// This condition only works for head_size <= 128
Add an assert for this limitation?
src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp (outdated review thread, resolved)
Force-pushed from 5f84164 to be5632c.
#endif

// This condition only works for head_size <= 128
Please add a guard for the noted limitation (head_size <= 128).
By the way, is this a limitation of the XE1 non-LSC path only? SLM capacity is a critical constraint.
Also consider static_assert(head_size % REG_N == 0) to catch misconfigurations.
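A minimal sketch of the requested guards, assuming head_size and REG_N are compile-time constants visible at this point in the header (exact placement is an assumption):

```cpp
// Sketch only: compile-time guards for the limitations discussed above.
static_assert(head_size <= 128, "this path only supports head_size <= 128");
static_assert(head_size % REG_N == 0, "head_size must be a multiple of REG_N");
```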
src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp (outdated review threads, resolved)
half* prefetch_k_pos = (half*)k_cache_base + prefetch_block_id * blk_stride + ((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ) * head_size;
cm_ptr_prefetch<REG_K/2, DataSize::U32, CacheHint::Cached, CacheHint::Cached>((const unsigned int *const)prefetch_k_pos, 0);
Although the stateless prefetch/load works, I would also recommend the stateful API for the K cache: 1. the stateful API is more robust against out-of-bounds accesses; 2. it keeps the memory-access APIs consistent across the Q and K/V caches, since this kernel currently mixes several kinds of access APIs.
It is not mandatory, of course.
src/plugins/intel_gpu/src/graph/impls/cm/include/cm_pa_common.hpp (outdated review threads, resolved)
}
}
}
if (q_tokens_left == 0) return;
Why did you remove this early-exit line?
The early-exit was moved earlier: we now return right after clamping q_tokens_left (if (q_tokens_left == 0) return;).
This makes the later guard redundant, so it was removed. The writeback path is never reached when q_tokens_left == 0.
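For context, a sketch of the relocated check, matching the clamp quoted later in this thread; the surrounding kernel code is assumed:

```cpp
// Clamp the remaining Q tokens for this thread, then leave immediately if there
// is nothing to do -- the old guard before the writeback becomes redundant.
if (q_tokens_left < 0) q_tokens_left = 0;
if (q_tokens_left > q_step) q_tokens_left = q_step;
if (q_tokens_left == 0) return;
```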
auto P2 = P.format<half, num_P_tiles, REG_M * REG_K>();
matrix<half, REG_K/2, REG_N*2*VALUE_TILE_NUM> Vmat;
#pragma unroll
for(int k = 0, ri=0; k < head_size; k += REG_N * VALUE_TILE_NUM, ri += num_P_tiles * VALUE_TILE_NUM) {
Is there any performance impact when USE_LSC==1? We probably need a further check here.
Based on the existing tests there should be no performance degradation; we can run further tests to confirm.
#endif
matrix<half, head_size/REG_K, REG_K*REG_N> rQ;
matrix <float, head_size/REG_N*num_P_tiles, REG_M*REG_N> rO;
matrix <float, head_size/REG_M, REG_M*REG_N> rO;
For ARL-H the shape of rO is [head_size/REG_N, REG_M*REG_N]; this declaration uses head_size/REG_M, so it requires REG_M == REG_N. Is there any check/assert to guarantee this?
Add some assert.
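A possible shape check along the lines requested, assuming REG_M and REG_N are compile-time constants (exact placement is a guess):

```cpp
// Sketch only: make the implicit ARL-H assumption explicit.
static_assert(REG_M == REG_N,
              "rO is declared with head_size/REG_M rows; this layout assumes REG_M == REG_N");
```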
if (q_tokens_left < 0) q_tokens_left = 0;
if (q_tokens_left > q_step) q_tokens_left = q_step;
if (q_tokens_left == 0) return;
Can we move this check earlier? Probably not: this thread still needs to contribute to the work group, e.g. K/V cache prefetch and dequantization (for i8).
#ifdef CM_HAS_LSC_UNTYPED_2D
#define USE_LSC 1
#else
#define USE_LSC 0
#endif
Sounds like we can use CM_HAS_LSC_UNTYPED_2D directly?
Yes, I will update it.
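A minimal sketch of the suggested simplification; the comments inside the branches are placeholders, not the kernel's actual paths:

```cpp
// Sketch only: drop the USE_LSC alias and branch on the compiler macro directly.
#ifdef CM_HAS_LSC_UNTYPED_2D
    // path using LSC untyped 2D block load/store messages
#else
    // fallback for platforms without LSC untyped 2D support (the XE1 target of this PR)
#endif
```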
CM_INLINE constexpr auto reduce2d(matrix_ref<T, N, M> src) {
constexpr int group_size = M / group_count;
if constexpr (N > stop) {
if constexpr (N > stop && group_size > 1) {
Would you please explain why we need the additional "&& group_size > 1" check? @fish-jiang
Please check line 977: reduce2d only needs to be called once, and there is no need for the hard-coded values. It also lets us use a different BLOCK_SG_M size without other kernel code changes.
// constexpr int BLOCK_WG_K = 64; // same in sg // because unroll 4 times along K ??
constexpr int SUM_N = BLOCK_SG_N / (BLOCK_SIZE/STRIDE);

// // #ifndef BLOCK_SG_M
// #define BLOCK_SG_M 32
// #define BLOCK_SG_N 16
// #define SG_M 4
// #define SG_N 8
// #define HEAD_SIZE 128
// #define KV_BLOCK_SIZE 256
// #define STRIDE 16
// // #endif
Please remove these commented-out lines.
#endif
// 0~2 M[:]xK[0:16] 2~4 K[16:32] --> 32 * 2 regs
matrix<half, 2, BLOCK_REG_B> b0, b1;
matrix<half, REG_N, BLOCK_REG_B> b0, b1; // ping-pong B
Is there any problem with allocating REG_N blocks of BLOCK_REG_B registers here? @fish-jiang
There is no hard-coded size here; it lets us adjust the BLOCK_SG_N size.
#define CUR_TYPE CUR_TYPE_(SOFTMAX_TYPE)

template <int M, int N>
CM_INLINE void cm_load_2d(matrix_ref<SOFTMAX_TYPE, M, N> out,
There is a similar cm_load_2d function in estimate.hpp too. We need a refactor, maybe in a separate PR, to unify them.
} else {
scale_val = 255.0 / (max_val - min_val);
zp_val = (0.0 - min_val) * scale_val;
scale_val = half(255.0) / (max_val - min_val);
Please keep the computation of scale_val in float precision. There is a known issue caused by calculating it in half precision; please check #33485.
BTW, why does this PR change this file for ARL-H?
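A sketch of the requested fix, assuming max_val/min_val are half values as in the diff; the temporary names are illustrative:

```cpp
// Sketch only: do the scale/zero-point math in float and narrow at the end,
// to avoid the half-precision issue referenced in #33485.
float scale_f = 255.0f / ((float)max_val - (float)min_val);
float zp_f    = (0.0f - (float)min_val) * scale_f;
scale_val = scale_f;  // narrow to the destination type only after the float math
zp_val    = zp_f;
```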
ceciliapeng2011 left a comment
Some kernels are becoming increasingly hard to read and maintain as more options extend their functionality (KV-cache compression types, XE1/XE2 architectures, etc.). We definitely need a refactor to improve this, probably in a separate PR:
- pa_single_token.cm
- estimate.hpp
- find_blocks.hpp
- cm_pa_common.hpp
template <typename T1, typename T2>
CM_INLINE void Transpose_8x8(matrix_ref<T1, 8, 8> in, matrix_ref<T2, 8, 8> out) {
    matrix<T2, 8, 8> temp;
    temp.row(0) = in.template select<2, 1, 4, 2>(0, 0);
    temp.row(1) = in.template select<2, 1, 4, 2>(2, 0);
    temp.row(2) = in.template select<2, 1, 4, 2>(4, 0);
    temp.row(3) = in.template select<2, 1, 4, 2>(6, 0);
    temp.row(4) = in.template select<2, 1, 4, 2>(0, 1);
    temp.row(5) = in.template select<2, 1, 4, 2>(2, 1);
    temp.row(6) = in.template select<2, 1, 4, 2>(4, 1);
    temp.row(7) = in.template select<2, 1, 4, 2>(6, 1);

    out.row(0) = temp.template select<4, 1, 2, 4>(0, 0);
    out.row(2) = temp.template select<4, 1, 2, 4>(0, 1);
    out.row(4) = temp.template select<4, 1, 2, 4>(0, 2);
    out.row(6) = temp.template select<4, 1, 2, 4>(0, 3);
    out.row(1) = temp.template select<4, 1, 2, 4>(4, 0);
    out.row(3) = temp.template select<4, 1, 2, 4>(4, 1);
    out.row(5) = temp.template select<4, 1, 2, 4>(4, 2);
    out.row(7) = temp.template select<4, 1, 2, 4>(4, 3);
}
We can simply include <cm_attention_common.hpp> to get Transpose_8x8.
BTW, there was once an optimization to Transpose_16x16 that improved performance a lot; a similar approach may be applicable to Transpose_8x8 too.
const uint seq_idx = get_cm_global_id_2nd(0);
const uint kv_head_num_idx = get_cm_global_id_2nd(1) / Q_head_chunks_per_kv_head;
const uint head_num_idx = get_cm_global_id_2nd(1) * Q_head_chunk_size;
//# KV_PARTITION_SIZE --> EU thread
const uint wg_thread_id = cm_global_id(2);
const uint wg_thread_id = get_cm_global_id_2nd(2);
Why change this part to use an invented function?
On A770, cm_global_id() is not accessible/unsupported in our CM environment, so using it breaks compilation/runtime.
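For readers unfamiliar with the helper, a hedged guess at what a fallback like get_cm_global_id_2nd computes on hardware where cm_global_id() is unavailable; the actual implementation in this PR may differ:

```cpp
// Sketch only: reconstruct the global id from group/local CM intrinsics.
CM_INLINE uint get_cm_global_id_2nd(uint dim) {
    return cm_group_id(dim) * cm_local_size(dim) + cm_local_id(dim);
}
```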
auto batch = cm_get_global_id_2nd(0);
auto head = cm_get_global_id_2nd(1);
auto offset = cm_group_id(2) * REDUCE_SPLIT_SIZE;
Again, why invent this?