
[Model] Add T5 model (2/2) #11901

Open · wants to merge 17 commits into base: main
2 changes: 2 additions & 0 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -114,6 +114,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_seq_len,
alibi_slopes,
None, # TODO add custom bias
kv_cache_dtype,
k_scale,
v_scale,
@@ -134,6 +135,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_seq_len,
alibi_slopes,
None,  # TODO add custom bias
kv_cache_dtype,
k_scale,
v_scale,
32 changes: 24 additions & 8 deletions csrc/attention/attention_kernels.cuh
@@ -104,6 +104,8 @@ __device__ void paged_attention_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len]
const int padded_max_seq_len, // Avoid recomputing from seq_lens.
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -154,6 +156,14 @@ __device__ void paged_attention_kernel(
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

// NOTE (NickLucche): `padded_max_seq_len` bias values for the current
// sequence and head.
const float* attn_bias_vec =
attn_bias == nullptr
? nullptr
: attn_bias + seq_idx * num_heads * padded_max_seq_len +
head_idx * padded_max_seq_len;

// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
@@ -293,8 +303,10 @@
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given, then add custom bias if given.
// TODO: should ALiBi and the custom bias be mutually exclusive?
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
@@ -512,6 +524,8 @@ __global__ void paged_attention_v1_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ attn_bias,
const int padded_max_seq_len, // Avoid recomputing from seq_lens.
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -520,9 +534,9 @@ __global__ void paged_attention_v1_kernel(
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len,
q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}

@@ -548,17 +562,19 @@ __global__ void paged_attention_v2_kernel(
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ attn_bias,
const int padded_max_seq_len, // Avoid recomputing from seq_lens.
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias,
padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale,
v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs).
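To make the new bias path concrete: the kernel expects `attn_bias` laid out as [num_seqs, num_heads, padded_max_seq_len] in float32, and for each key token it adds `attn_bias_vec[token_idx]` to the scaled QK logit alongside the optional ALiBi term. Below is a minimal non-paged PyTorch reference of that arithmetic for one decode query per sequence; the function and tensor names are illustrative, not part of the diff.

import torch

def ref_decode_logits(q, k, scale, alibi_slopes=None, attn_bias_row=None):
    # q: [num_heads, head_size], k: [num_heads, seq_len, head_size]
    seq_len = k.shape[1]
    logits = scale * torch.einsum("hd,htd->ht", q, k)        # [num_heads, seq_len]
    if alibi_slopes is not None:                              # [num_heads]
        positions = torch.arange(seq_len, device=k.device) - seq_len + 1
        logits = logits + alibi_slopes[:, None] * positions[None, :]
    if attn_bias_row is not None:                             # [num_heads, padded_max_seq_len]
        logits = logits + attn_bias_row[:, :seq_len]          # attn_bias_vec[token_idx]
    return logits

The paged kernel computes the same per-token sum while walking the KV blocks; whether ALiBi and the custom bias should ever be combined is still an open TODO in the hunk above.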
44 changes: 29 additions & 15 deletions csrc/attention/paged_attention_v1.cu
@@ -40,10 +40,10 @@
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);

// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
@@ -53,8 +53,9 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias, float k_scale, float v_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
@@ -73,7 +74,21 @@ void paged_attention_v1_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

const float* attn_bias_ptr =
attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
: nullptr;
const int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
if (attn_bias_ptr) {
const torch::Tensor& abias = attn_bias.value();
TORCH_CHECK(abias.dtype() == torch::kFloat32,
"Unsupported bias dtype: ", abias.dtype());
TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
"The last dimension of the attention bias must "
"match the block-aligned maximum sequence length (",
padded_max_seq_len,
"). However, the given dimensions are: ", abias.sizes());
}
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
@@ -82,13 +97,11 @@ void paged_attention_v1_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
const int logits_size = padded_max_seq_len * sizeof(float);
const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
const int shared_mem_size = std::max(logits_size, outputs_size);

dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
@@ -135,8 +148,8 @@ void paged_attention_v1_launcher(
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \
tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);

#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
@@ -176,7 +189,8 @@ void paged_attention_v1(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
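The launcher only accepts a float32 bias whose last dimension equals the block-aligned maximum sequence length, so callers must pad their bias (e.g. a T5 relative-position bias) up to that size. A small sketch of such padding follows; the helper name is hypothetical, and zero is assumed to be a safe filler because positions at or beyond each sequence's length are never attended to.

import torch

def pad_attn_bias(bias: torch.Tensor, block_size: int) -> torch.Tensor:
    # bias: [num_seqs, num_heads, max_seq_len], float32
    num_seqs, num_heads, max_seq_len = bias.shape
    # Same rounding as DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE
    padded_len = (max_seq_len + block_size - 1) // block_size * block_size
    padded = bias.new_zeros(num_seqs, num_heads, padded_len)
    padded[..., :max_seq_len] = bias
    return padded.contiguous()

A tensor produced this way would satisfy both TORCH_CHECKs above for any block size.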
45 changes: 31 additions & 14 deletions csrc/attention/paged_attention_v2.cu
@@ -36,10 +36,11 @@
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \
attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
@@ -54,8 +55,9 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias, float k_scale, float v_scale,
const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
@@ -74,7 +76,21 @@ void paged_attention_v2_launcher(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

const float* attn_bias_ptr =
attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
: nullptr;
const int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
if (attn_bias_ptr) {
const torch::Tensor& abias = attn_bias.value();
TORCH_CHECK(abias.dtype() == torch::kFloat32,
"Unsupported bias dtype: ", abias.dtype());
TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
"The last dimension of the attention bias must "
"match the block-aligned maximum sequence length (",
padded_max_seq_len,
"). However, the given dimensions are: ", abias.sizes());
}
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
@@ -86,16 +102,16 @@ void paged_attention_v2_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int logits_size = PARTITION_SIZE * sizeof(float);
const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

// For paged attention v2 kernel.
dim3 grid(num_heads, num_seqs, max_num_partitions);
int shared_mem_size = std::max(logits_size, outputs_size);
const int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -142,7 +158,7 @@ void paged_attention_v2_launcher(
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);

@@ -187,7 +203,8 @@ void paged_attention_v2(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
10 changes: 8 additions & 2 deletions csrc/cpu/attention.cpp
@@ -459,14 +459,17 @@ void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
TORCH_CHECK(!attn_bias.has_value(),
"CPU backend does not support custom attention bias.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -781,14 +784,17 @@ void paged_attention_v2(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& attn_bias,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
TORCH_CHECK(!attn_bias.has_value(),
"CPU backend does not support custom attention bias.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
5 changes: 3 additions & 2 deletions csrc/cpu/torch_bindings.cpp
@@ -24,12 +24,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
// TODO: support attn_bias on the CPU backend
ops.def(
"paged_attention_v1("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
@@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
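With the updated schema, the optional bias sits immediately after alibi_slopes in the argument list. Below is an illustrative call from the Python side, assuming the extension is registered under torch.ops._C as in vLLM's custom-ops wrapper; passing None keeps the previous behaviour, and the CPU backend rejects a non-None bias via the TORCH_CHECK added in csrc/cpu/attention.cpp.

import torch

# Argument order follows the op schema defined above; all variable names are placeholders.
torch.ops._C.paged_attention_v1(
    out, query, key_cache, value_cache,
    num_kv_heads, scale, block_tables, seq_lens,
    block_size, max_seq_len,
    alibi_slopes,   # Tensor? or None
    attn_bias,      # Tensor? [num_seqs, num_heads, padded_max_seq_len] float32, or None
    kv_cache_dtype, k_scale, v_scale,
    tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
    blocksparse_block_size, blocksparse_head_sliding_step,
)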