diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index daedaadb1a7..98645a0cd4c 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -118,6 +118,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -138,6 +139,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index eb216dc8baf..d19771a19cc 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,6 +104,8 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -154,6 +156,14 @@ __device__ void paged_attention_kernel( const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence + // and current head. + const float* attn_bias_vec = + attn_bias == nullptr + ? nullptr + : attn_bias + seq_idx * num_heads * padded_max_seq_len + + head_idx * padded_max_seq_len; + // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread // group fetch or compute 16 bytes at a time. For example, if the size of a @@ -293,8 +303,10 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot( q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. + // Add the ALiBi bias if slopes are given, then add custom bias if given. + // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -512,6 +524,8 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
const int q_stride, const int kv_block_stride, const int kv_head_stride, const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -520,9 +534,9 @@ __global__ void paged_attention_v1_kernel( KV_DTYPE, IS_BLOCK_SPARSE>( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, + max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -548,6 +562,8 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -555,10 +571,10 @@ __global__ void paged_attention_v2_kernel( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, + padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 9b3a5c4b101..c32e852204d 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -29,21 +29,21 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \ - BLOCK_SIZE, NUM_THREADS, \ - KV_DTYPE, IS_BLOCK_SPARSE>), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ - NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \ - <<<grid, block, shared_mem_size, stream>>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \ + BLOCK_SIZE, NUM_THREADS, \ + KV_DTYPE, IS_BLOCK_SPARSE>), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ + NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \ + <<<grid, block, shared_mem_size, stream>>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template <typename T, typename CACHE_T, int BLOCK_SIZE, @@ -53,7 +53,8 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, + const std::optional<torch::Tensor>& alibi_slopes, + const std::optional<torch::Tensor>& attn_bias, torch::Tensor& k_scale, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { @@ -73,7 +74,21 @@ void paged_attention_v1_launcher( alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr()) + : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). 
However, the given dimensions are: ", abias.sizes()); + } T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); @@ -84,13 +99,11 @@ void paged_attention_v1_launcher( const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int logits_size = padded_max_seq_len * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); @@ -137,8 +150,8 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ IS_BLOCK_SPARSE>( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \ + tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ @@ -179,6 +192,7 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, + const std::optional<torch::Tensor>& attn_bias, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9935359e02f..ccfd6cd60f5 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -36,8 +36,9 @@ <<<grid, block, shared_mem_size, stream>>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ + attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \ + kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ @@ -54,7 +55,8 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, + const std::optional<torch::Tensor>& alibi_slopes, + const std::optional<torch::Tensor>& attn_bias, torch::Tensor& k_scale, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, 
const int blocksparse_head_sliding_step) { @@ -74,7 +76,21 @@ void paged_attention_v2_launcher( alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr()) + : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); + } T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); @@ -88,16 +104,16 @@ void paged_attention_v2_launcher( const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + const int logits_size = PARTITION_SIZE * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // For paged attention v2 kernel. dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); // For paged attention v2 reduce kernel. 
dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); @@ -144,7 +160,7 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -190,6 +206,7 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, + const std::optional<torch::Tensor>& attn_bias, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index b9764056e8a..4c67a775a0b 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -459,7 +459,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, + int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, + const c10::optional<torch::Tensor>& attn_bias, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, @@ -467,6 +468,8 @@ void paged_attention_v1( const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -782,6 +785,7 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, + const std::optional<torch::Tensor>& attn_bias, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, @@ -789,6 +793,8 @@ void paged_attention_v2( const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 5d1c5f4c83d..9ddb837d182 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -24,12 +24,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. + // TODO attn_bias on cpu ops.def( "paged_attention_v1(" " Tensor! 
out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/csrc/ops.h b/csrc/ops.h index e39d4ef3188..6a467fd6f37 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,7 +33,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, + int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, + const c10::optional<torch::Tensor>& attn_bias, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, @@ -45,7 +46,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, + int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, + const c10::optional<torch::Tensor>& attn_bias, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c03806f430a..29e91762c74 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +43,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? 
attn_bias," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b667d8d9e03..314db5db121 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -39,6 +39,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +USE_CUSTOM_ATTN_BIAS = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ @@ -62,16 +63,11 @@ def ref_masked_attention( def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], -) -> None: + output: torch.Tensor, query: torch.Tensor, num_queries_per_kv: int, + key_cache: torch.Tensor, value_cache: torch.Tensor, + block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float, + alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[List[torch.Tensor]]) -> None: num_query_heads = query.shape[1] num_kv_heads = value_cache.shape[1] head_size = value_cache.shape[2] @@ -104,15 +100,17 @@ def ref_single_query_cached_kv_attention( keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - alibi_bias = None + bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) + bias = alibi_bias + if attn_bias is not None: + bias = attn_bias[i] if bias is None else bias + attn_bias[i] + out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -124,6 +122,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("use_custom_attn_bias", USE_CUSTOM_ATTN_BIAS) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @@ -136,6 +135,7 @@ def test_paged_attention( num_heads: Tuple[int, int], head_size: int, use_alibi: bool, + use_custom_attn_bias: bool, block_size: int, dtype: torch.dtype, kv_cache_dtype: str, @@ -155,7 +155,7 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None + alibi_slopes, attn_bias = None, None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) @@ -163,9 +163,30 @@ def test_paged_attention( seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) + attn_bias_list = None + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + if use_custom_attn_bias: + # NOTE (NickLucche) each sequence can have a different bias, + # depending on its len, but it *must* be padded to the block + # aligned max_seq_len and of type float32! 
+ attn_bias_list = [ + torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) + for seq_len in seq_lens + ] + block_aligned_max_seq_len = max_num_blocks_per_seq * block_size + attn_bias = torch.empty( + num_seqs, + num_query_heads, + 1, + block_aligned_max_seq_len, # padded dim + device=device, + dtype=torch.float) + + for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): + # first `seq_len` entries of the bias matrix for each head/seq + attn_bias[i, :, :, :seq_len] = bias # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables_lst: List[List[int]] = [] for _ in range(num_seqs): block_table = [ @@ -201,6 +222,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -209,7 +231,7 @@ def test_paged_attention( opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -242,18 +264,20 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -307,17 +331,10 @@ def test_paged_attention( value_cache = dequantized_value_cache ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - ) + ref_single_query_cached_kv_attention(ref_output, query, num_queries_per_kv, + key_cache, value_cache, block_tables, + seq_lens, scale, alibi_slopes, + attn_bias_list) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index e653d34d00e..bedff278f95 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -230,6 +230,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -267,6 +268,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/tests/models/encoder_decoder/language/conftest.py b/tests/models/encoder_decoder/language/conftest.py new file mode 100644 index 00000000000..c318dcb783c --- /dev/null +++ b/tests/models/encoder_decoder/language/conftest.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional, Type + +from transformers import 
AutoModelForSeq2SeqLM + +from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, + HfRunner, VllmRunner) +from ...utils import check_logprobs_close +from .utils import vllm_to_hf_output + + +def compare_hf_vllm_logprobs( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, + vllm_runner_kwargs: Optional[Dict[str, Any]] = None, + hf_tokens_to_skip: int = 0) -> None: + ''' + Test the provided model for a variety of encoder/decoder input prompts, + by validating it against the corresponding HuggingFace (HF) model. + + Arguments: + + * hf_runner: HuggingFace (HF) test model runner + * vllm_runner: vLLM test model runner + * prompts: explicit encoder/decoder prompts to test (typically provided + by the example_encoder_decoder_prompts fixture) + * model: the HF ID of the model under test + * dtype: the tensor datatype to employ + * max_tokens + * num_logprobs + * decoder_prompt_type: key into the example_encoder_decoder_prompts + dictionary; selects specific encoder/decoder + prompt scenarios to test + + A note on using HF BART as a baseline for validating vLLM BART, + specifically when the decoder prompt is None. + + The HF GenerationMixin's default behavior is to force the first + decoded token to be <BOS> if the prompt does not already contain + <BOS> (this is accomplished using a logit + processor setting.) + + So when we use HF BART as our baseline for comparison, note that + when the user provides a request with a None decoder prompt + (i.e. a singleton encoder prompt, or else an explicit encoder/ + decoder prompt with the decoder sub-prompt set to None), HF and + vLLM handle this in different ways: + + * HF will (1) tokenize the None prompt as an empty token-list, + (2) append <decoder-start-token> to the beginning, yielding + [<decoder-start-token>], (3) pass this token list to the model, and + then (4) after computing logits during prefill, override the model + logits & force <BOS> to be the first generated token. + + * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder- + start-token to the beginning, yielding [<decoder-start-token><BOS>], + (3) pass these tokens to the model & proceed with generation. + + The net effect is that compared to vLLM, the list of HF *decoded* tokens + will contain one more initial <BOS> than the vLLM generated tokens, + because vLLM's <BOS> token is injected into the prompt rather than into + the generated output. This is in spite of the fact that overall, the + complete sequences (prompt + decoded tokens) produced by vLLM will match + HF. + + So when we use HF decoded token output to validate vLLM's decoded token + output, the testing process must account for the difference in decoded + token sequences between vLLM and HF specifically in the + decoder-prompt-is-None case. + + One option is to disable the logit processor feature that forces the + <BOS> token to be decoded (forced_bos_token_id = None), eliminating + the problem entirely. However this is not "normal" BART usage. + + The other option is - only in the decoder-prompt-is-None case - to + discard the first decoded token from the HF output before comparing it + to vLLM. + + To that end, `hf_tokens_to_skip` must be set to the number of HF decoded + tokens to skip during the process of validating the vLLM decoded output.
+ ''' + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default). + + # Note: currently encoder/decoder models are only compatible with + # enforce_eager=True. Normally this is not a problem because + # for encoder/decoder models vLLM will + # default to enforce_eager=True if enforce_eager + # is left unspecified. However, the + # VllmRunner test fixture (which wraps around the LLM class) defaults to + # enforce_eager=False (a behavior which a number of already-existing + # decoder-only unit tests expect), so when testing an encoder/decoder + # model we must explicitly specify enforce_eager=True in the VllmRunner + # constructor. + if not vllm_runner_kwargs: + vllm_runner_kwargs = dict() + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + **vllm_runner_kwargs) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_tokens_to_skip, + ) diff --git a/tests/models/encoder_decoder/language/language/__init__.py b/tests/models/encoder_decoder/language/language/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 81b629fdcf1..69a9a2a99e9 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -3,170 +3,12 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
""" -from typing import List, Optional, Tuple, Type - import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.sequence import SampleLogprobs - -from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) -from ....utils import multi_gpu_test -from ...utils import check_logprobs_close - - -def vllm_to_hf_output( - vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "</s>" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "<s>" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - ''' - Test the vLLM BART model for a variety of encoder/decoder input prompts, - by validating it against HuggingFace (HF) BART. - - Arguments: - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - * model: the HF ID of the specific BART variant under test - * dtype: the tensor datatype to employ - * max_tokens - * num_logprobs - * decoder_prompt_type: key into the example_encoder_decoder_prompts - dictionary; selects specific encoder/decoder - prompt scenarios to test +from tests.utils import multi_gpu_test - A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - - The HF GenerationMixin's default behavior is to force the first - decoded token to be <BOS> if the prompt does not already contain - <BOS> (this is accomplished using a logit - processor setting.) - - So when we use HF BART as our baseline for comparison, note that - when the user provides a request with a None decoder prompt - (i.e. a singleton encoder prompt, or else an explicit encoder/ - decoder prompt with the decoder sub-prompt set to None), HF and - vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, - (2) append <decoder-start-token> to the beginning, yielding - [<decoder-start-token>], (3) pass this token list to the model, and - then (4) after computing logits during prefill, override the model - logits & force <BOS> to be the first generated token. - - * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder- - start-token to the beginning, yielding [<decoder-start-token><BOS>], - (3) pass these tokens to the model & proceed with generation. - - The net effect is that compared to vLLM, the list of HF *decoded* tokens - will contain one more initial <BOS> than the vLLM generated tokens, - because vLLM's <BOS> token is injected into the prompt rather than into - the generated output. This is in spite of the fact that overall, the - complete sequences (prompt + decoded tokens) produced by vLLM will match - HF. - - So when we use HF decoded token output to validate vLLM's decoded token - output, the testing process must account for the difference in decoded - token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. 
- - One option is to disable the logit processor feature that forces the - <BOS> token to be decoded (forced_bos_token_id = None), eliminating - the problem entirely. However this is not "normal" BART usage. - - The other option is - only in the decoder-prompt-is-None case - to - discard the first decoded token from the HF output before comparing it - to vLLM. - - To that end, when testing the scenario where the decoder prompt is None - (and only in that one scenario), this test skips the first HF decoded - token during the process of validating the vLLM decoded output. - ''' - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default). - - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-exisitng - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) +from ....conftest import DecoderPromptType +from .conftest import compare_hf_vllm_logprobs @pytest.mark.parametrize( @@ -184,7 +26,7 @@ def run_test( def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - run_test( + compare_hf_vllm_logprobs( hf_runner, vllm_runner, example_encoder_decoder_prompts[decoder_prompt_type], @@ -194,7 +36,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, max_tokens=max_tokens, num_logprobs=num_logprobs, tensor_parallel_size=1, - ) + hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE)) @multi_gpu_test(num_gpus=2) @@ -209,7 +51,7 @@ def test_models_distributed(hf_runner, vllm_runner, distributed_executor_backend, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - run_test( + compare_hf_vllm_logprobs( hf_runner, vllm_runner, example_encoder_decoder_prompts[decoder_prompt_type], diff --git a/tests/models/encoder_decoder/language/test_t5.py 
b/tests/models/encoder_decoder/language/test_t5.py new file mode 100644 index 00000000000..48f61210a73 --- /dev/null +++ b/tests/models/encoder_decoder/language/test_t5.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of HF and vLLM for T5 models using greedy sampling. +Based on tests/models/encoder_decoder/language/test_bart.py. + +Run `pytest tests/models/encoder_decoder/language/test_t5.py`. +""" +import pytest + +from tests.utils import multi_gpu_test +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) + +from ....conftest import DecoderPromptType +from .conftest import compare_hf_vllm_logprobs + + +@pytest.mark.parametrize( + "model", + [ + pytest.param("Finnish-NLP/ul2-tiny-nl6-finnish"), + # pytest.param("google/flan-t5-base"), + ], +) +@pytest.mark.parametrize("vllm_kwargs", [{"max_model_len": 512}]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +# TODO custom prompt here generate high entropy output, causing +# differences in sampled tokens. +@pytest.mark.parametrize("decoder_prompt_type", + [DecoderPromptType.NONE, DecoderPromptType.EMPTY_STR]) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type, + vllm_kwargs) -> None: + # Model only supported on xformers backend as of now. + with global_force_attn_backend_context_manager(_Backend.XFORMERS): + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + vllm_runner_kwargs=vllm_kwargs) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", ["google/t5-small"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.NONE]) +def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + with global_force_attn_backend_context_manager(_Backend.XFORMERS): + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/encoder_decoder/language/utils.py b/tests/models/encoder_decoder/language/utils.py new file mode 100644 index 00000000000..ae91b160686 --- /dev/null +++ b/tests/models/encoder_decoder/language/utils.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional, Tuple + +from vllm.sequence import SampleLogprobs + +from ....conftest import DecoderPromptType + + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "</s>" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "<s>" + hf_output_str 
+ + return output_ids, hf_output_str, out_logprobs diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a6823501676..98d7c9f875c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -49,6 +49,7 @@ def paged_attention_v1( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -60,8 +61,8 @@ def paged_attention_v1( ) -> None: torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, + seq_lens, block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -81,6 +82,7 @@ def paged_attention_v2( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -93,7 +95,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + alibi_slopes, attn_bias, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9765e7881ad..ef8009a3cc8 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -443,6 +443,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, tp_rank=self.tp_rank, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02bff57a62b..8f3c94920be 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -837,6 +837,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0b..196248682d7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -159,8 +159,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. + # when alibi slopes or custom attention bias are used. + # It is because of a limitation from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None self.encoder_attn_bias: Optional[List[AttentionBias]] = None @@ -292,7 +292,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: def _get_attn_bias( attn_metadata: XFormersMetadata, attn_type: str, -) -> Optional[AttentionBias]: +) -> Optional[List[AttentionBias]]: ''' Extract appropriate attention bias from attention metadata according to attention type. 
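A note on the decode-path bias layout wired up above: the launchers check that a custom `attn_bias` is float32 and that its last dimension equals the block-aligned maximum sequence length, and `attention_kernels.cuh` indexes it as `[num_seqs, num_heads, padded_max_seq_len]` via `seq_idx * num_heads * padded_max_seq_len + head_idx * padded_max_seq_len`. The sketch below shows one way to pad per-sequence biases into that layout, mirroring what `tests/kernels/test_attention.py` does earlier in this diff; the helper name is hypothetical and not part of this change.

```python
from typing import List

import torch


def build_decode_attn_bias(per_seq_biases: List[torch.Tensor],
                           block_size: int, max_seq_len: int,
                           num_heads: int) -> torch.Tensor:
    """Pad per-sequence biases into the layout the paged attention kernels expect.

    per_seq_biases holds one float32 tensor per sequence with shape
    [num_heads, seq_len_i]. The result has shape
    [num_seqs, num_heads, padded_max_seq_len], where
    padded_max_seq_len = ceil(max_seq_len / block_size) * block_size is the
    value validated by the TORCH_CHECK in paged_attention_v1/v2.cu.
    """
    padded_max_seq_len = -(-max_seq_len // block_size) * block_size
    attn_bias = torch.zeros(len(per_seq_biases), num_heads,
                            padded_max_seq_len, dtype=torch.float32)
    for i, bias in enumerate(per_seq_biases):
        # Only the first seq_len_i positions correspond to real tokens; the
        # zero tail exists so the kernel's block-aligned addressing is valid.
        attn_bias[i, :, :bias.shape[-1]] = bias
    return attn_bias
```

The result is what would be passed as the new `attn_bias` argument of `ops.paged_attention_v1/v2` (or `PagedAttention.forward_decode`); the unit test builds the equivalent `[num_seqs, num_heads, 1, padded_max_seq_len]` tensor, which is the same memory layout with a singleton query dimension.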
@@ -548,6 +548,7 @@ def forward( assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens + attn_bias = _get_attn_bias(attn_metadata, attn_type) if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -555,6 +556,10 @@ def forward( # normal attention. # block tables are empty if the prompt does not have a cached # prefix. + # As prefill metadata are cached on first call, we need to make + # sure attn_bias is up to date. + if attn_bias: + _set_attn_bias(prefill_meta, attn_bias, attn_type) out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) assert out.shape == output[:num_prefill_query_tokens].shape @@ -583,6 +588,7 @@ def forward( prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, + _get_attn_bias(attn_metadata, attn_type), self.sliding_window, layer._k_scale, layer._v_scale, @@ -600,6 +606,13 @@ def forward( block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + if attn_bias: + assert len( + attn_bias + ) == 1, "PagedAttention expects a single bias to be provided\ + for all input sequences." + + attn_bias = attn_bias[0] output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, @@ -611,6 +624,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + attn_bias, layer._k_scale, layer._v_scale, ) @@ -706,6 +720,7 @@ def _run_memory_efficient_xformers_forward( else: raise ValueError("Unknown AttentionType: %s", attn_type) + assert isinstance(attn_bias, BlockDiagonalMask) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -719,10 +734,10 @@ def _run_memory_efficient_xformers_forward( _set_attn_bias(attn_metadata, attn_bias, attn_type) - # No alibi slopes. + # No alibi slopes and no multi-sequence custom attention bias. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. - if self.alibi_slopes is None: + if self.alibi_slopes is None and len(attn_bias) == 1: # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) @@ -736,14 +751,16 @@ def _run_memory_efficient_xformers_forward( scale=self.scale) return out.view_as(original_query) - # Attention with alibi slopes. + # Attention with alibi slopes or multiple custom attention bias. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. 
- assert attn_metadata.seq_lens is not None output = torch.empty_like(original_query) + seq_lens = attn_metadata.encoder_seq_lens \ + if attn_type == AttentionType.ENCODER else attn_metadata.seq_lens + assert seq_lens start = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens): + for i, seq_len in enumerate(seq_lens): end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 598ceea130d..150c99fc97f 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -105,6 +105,7 @@ def forward_decode( block_size, max_context_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2c60bd0c38d..89ae554bff9 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 - from dataclasses import dataclass from typing import List, Optional, Tuple @@ -97,6 +96,7 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], k_scale: torch.Tensor, v_scale: torch.Tensor, tp_rank: int = 0, @@ -142,6 +142,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -180,6 +181,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -205,11 +207,13 @@ def forward_prefix( context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], sliding_window: Optional[int], k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> torch.Tensor: output = torch.empty_like(query) + assert attn_bias is None, "Bias for prefix not yet enabled" context_attention_fwd( query, key, diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 53f89996f0f..4a25a6c37b3 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -91,7 +91,7 @@ def get_decoder_start_token_id(self) -> Optional[int]: return dec_start_token_id - def _get_default_enc_dec_decoder_prompt(self) -> List[int]: + def _maybe_get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' Specifically for encoder/decoder models: generate a default decoder prompt for when @@ -124,8 +124,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' bos_token_id = self.get_bos_token_id() - assert bos_token_id is not None - return [bos_token_id] + return [] if bos_token_id is None else [bos_token_id] def _prepare_decoder_input_ids_for_generation( self, @@ -157,7 +156,8 @@ def _prepare_decoder_input_ids_for_generation( if decoder_input_ids is None: # no decoder prompt input -> # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + decoder_input_ids = self._maybe_get_default_enc_dec_decoder_prompt( + ) if (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c2d0fae7056..51c04737bb5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -104,6 +104,8 @@ "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", 
"Florence2ForConditionalGeneration"), # noqa: E501 + "T5Model": ("t5", "T5ForConditionalGeneration"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py new file mode 100644 index 00000000000..70a2f2a726e --- /dev/null +++ b/vllm/model_executor/models/t5.py @@ -0,0 +1,790 @@ +# SPDX-License-Identifier: Apache-2.0 +# Derived from T5 implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch T5 model.""" + +import math +import re +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import T5Config +# TODO best way to handle xformers imports? +from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias + +# TODO func should be in backend interface +from vllm.attention.backends.xformers import (XFormersMetadata, _get_attn_bias, + _set_attn_bias) +from vllm.attention.layer import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import maybe_prefix + + +class T5LayerNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. + No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states) -> torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is + # also known as Root Mean Square Layer Normalization + # https://arxiv.org/abs/1910.07467 thus variance is calculated w/o mean + # and there is no bias. Additionally we want to make sure that the + # accumulation for half-precision inputs is done in fp32. 
+ + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(nn.Module): + + def __init__(self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + # if ( + # isinstance(self.wo.weight, torch.Tensor) + # and hidden_states.dtype != self.wo.weight.dtype + # and self.wo.weight.dtype != torch.int8 + # ): + # hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + + def __init__(self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + # Should not run in fp16 unless mixed-precision is used, + # see https://github.com/huggingface/transformers/issues/20287. + self.wo = RowParallelLinear(config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) + hidden_linear, _ = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + + def __init__(self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense( + config, quant_config=quant_config) + else: + self.DenseReluDense = T5DenseActDense(config, + quant_config=quant_config) + + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward(self, hidden_states) -> torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(nn.Module): + + def __init__(self, + config: T5Config, + attn_type: AttentionType, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.attn_type = attn_type + # Cross-attention has no relative pos encoding anyway + self.is_decoder = attn_type == AttentionType.DECODER + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = \ + config.relative_attention_num_buckets + self.relative_attention_max_distance = \ + config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + assert cache_config + # Alternatively we can get it from kv_cache size in 
fwd. + self.block_size = cache_config.block_size + + # Partition heads across multiple tensor parallel GPUs. + tp_world_size = get_tensor_model_parallel_world_size() + assert config.num_heads % tp_world_size == 0 + self.n_heads = config.num_heads // tp_world_size + + self.inner_dim = self.n_heads * self.key_value_proj_dim + # No GQA in t5. + self.n_kv_heads = self.n_heads + + self.qkv_proj = QKVParallelLinear(self.d_model, + self.d_model // self.n_heads, + self.n_heads, + self.n_kv_heads, + bias=False, + quant_config=quant_config) + + # NOTE (NickLucche) T5 employs a scaled weight initialization scheme + # instead of scaling attention scores directly. + self.attn = Attention(self.n_heads, + config.d_kv, + 1.0, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type) + + # Only the first SelfAttention block in encoder decoder has this + # embedding layer, the others reuse its output. + if self.has_relative_attention_bias: + self.relative_attention_bias = \ + VocabParallelEmbedding(self.relative_attention_num_buckets, + self.n_heads, + org_num_embeddings=\ + self.relative_attention_num_buckets, + quant_config=quant_config) + self.out_proj = RowParallelLinear( + self.inner_dim, + self.d_model, + bias=False, + quant_config=quant_config, + ) + + @staticmethod + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, + i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative + positions are invalid. We use smaller buckets for small absolute + relative_position and larger buckets for larger absolute + relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same + bucket. 
This should allow for more graceful generalization to longer + sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """# noqa: E501 + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins + # in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) / + math.log(max_distance / max_exact) * + (num_buckets - max_exact)).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1)) + + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) + return relative_buckets + + def compute_bias(self, + query_length, + key_length, + device=None) -> torch.Tensor: + """Compute binned relative position bias""" + # TODO possible tp issue? + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, + dtype=torch.long, + device=device)[:, None] + memory_position = torch.arange(key_length, + dtype=torch.long, + device=device)[None, :] + # max_seq_len, nh + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + x = values.permute([2, 0, 1]).unsqueeze( + 0) # shape (1, num_heads, query_length, key_length) + return x + + def forward( + self, + hidden_states: torch.Tensor, # (num_tokens, d_model) + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO auto-selection of xformers backend when t5 is detected + assert isinstance(attn_metadata, XFormersMetadata) + num_seqs = len( + attn_metadata.seq_lens) if attn_metadata.seq_lens else len( + attn_metadata.encoder_seq_lens) + qkv, _ = self.qkv_proj(hidden_states) + # Projection of 'own' hidden state (self-attention). No GQA here. + q, k, v = qkv.split(self.inner_dim, dim=-1) + + # NOTE (NickLucche) Attn bias is computed once per encoder or decoder + # forward, on the first call to T5Attention.forward. Subsequent + # *self-attention* layers will reuse it. + attn_bias = _get_attn_bias(attn_metadata, self.attn_type) + if self.attn_type == AttentionType.ENCODER_DECODER: + # Projection of encoder's hidden states, cross-attention. 
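+            # T5 applies no relative-position bias to cross-attention; this branch
+            # only projects the encoder output (prefill) or reuses the
+            # cross-attention KV cache (decode), as handled below.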
+ if encoder_hidden_states is None: + # Decode phase, kv already cached + assert attn_metadata.num_prefills == 0 + k = None + v = None + else: + assert attn_metadata.num_prefills > 0 + # Prefill phase (first decoder forward), caching kv + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv_enc.split(self.inner_dim, dim=-1) + # No custom attention bias must be set when running cross attn. + assert attn_bias is None + + # Not compatible with CP here (as all encoder-decoder models), + # as it assumes homogeneous batch (prefills or decodes). + elif self.has_relative_attention_bias: + assert attn_bias is None # to be recomputed + # Self-attention. Compute T5 relative positional encoding. + # The bias term is computed on longest sequence in batch. Biases + # for shorter sequences are slices of the longest. + # TODO xformers-specific code. + align_to = 8 + # bias expected shape: (num_seqs, NH, L, L_pad) for prefill, + # (num_seqs, NH, 1, L_pad) for decodes. + if self.attn_type == AttentionType.ENCODER: + # Encoder prefill stage, uses xFormers, hence sequence + # padding/alignment to 8 is required. + seq_len = attn_metadata.max_encoder_seq_len + padded_seq_len = (seq_len + align_to - + 1) // align_to * align_to + # TODO (NickLucche) avoid extra copy on repeat, + # provide multiple slices of same memory + position_bias = self.compute_bias(seq_len, + padded_seq_len).repeat( + num_seqs, 1, 1, 1) + # xFormers expects a list of biases, one matrix per sequence. + # As each sequence gets its own bias, no masking is required. + attn_bias = [ + p[None, :, :sq, :sq] for p, sq in zip( + position_bias, attn_metadata.encoder_seq_lens) + ] + elif attn_metadata.prefill_metadata: + # Decoder prefill stage, uses xFormers, hence sequence + # padding/alignment to 8 is required. First decoder step, + # seq_len is usually 1, but one can prepend different start + # tokens prior to generation. + seq_len = attn_metadata.max_prefill_seq_len + # ->align + padded_seq_len = (seq_len + align_to - + 1) // align_to * align_to + position_bias = self.compute_bias(seq_len, + padded_seq_len).repeat( + num_seqs, 1, 1, 1) + # Causal mask for prefill. + attn_bias = [ + LowerTriangularMaskWithTensorBias(pb[None, :, :sq, :sq]) + for pb, sq in zip(position_bias, attn_metadata.seq_lens) + ] + else: + # Decoder decoding stage, uses PagedAttention, hence sequence + # padding/alignment to `block_size` is required. Expected + # number of queries is always 1 (MQA not supported). + seq_len = attn_metadata.max_decode_seq_len + block_aligned_seq_len = (seq_len + self.block_size - 1 + ) // self.block_size * self.block_size + + # TODO bf16 bias support in PagedAttention. + position_bias = self.compute_bias( + seq_len, block_aligned_seq_len).float() + # Bias for the last query, the one at current decoding step. + position_bias = position_bias[:, :, -1:, :].repeat( + num_seqs, 1, 1, 1) + # No explicit masking required, this is done inside the + # paged attention kernel based on the sequence length. + attn_bias = [position_bias] + + # NOTE Assign bias term on metadata based on attn type: + # ENCODER->`encoder_attn_bias`, DECODER->`attn_bias`. + _set_attn_bias(attn_metadata, attn_bias, self.attn_type) + elif not self.has_relative_attention_bias: + # Encoder/Decoder Self-Attention Layer, attn bias already cached. 
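+            # Only the first self-attention layer of each stack owns a
+            # relative_attention_bias table; the remaining layers reuse the bias it
+            # cached on attn_metadata (retrieved via _get_attn_bias above).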
+ assert attn_bias is not None + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class T5LayerSelfAttention(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.SelfAttention = T5Attention( + config, + AttentionType.DECODER + if "decoder" in prefix else AttentionType.ENCODER, + has_relative_attention_bias=has_relative_attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.SelfAttention") + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=None, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + + def __init__(self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.EncDecAttention = T5Attention(config, + AttentionType.ENCODER_DECODER, + has_relative_attention_bias=False, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.EncDecAttention") + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + + def __init__(self, + config: T5Config, + is_decoder: bool, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.is_decoder = is_decoder + self.self_attn = T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + + if self.is_decoder: + self.cross_attn = T5LayerCrossAttention( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.cross_attn") + + self.ffn = T5LayerFF(config, quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + if self.is_decoder: + hidden_states = self.cross_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + # Apply Feed Forward layer + hidden_states = self.ffn(hidden_states) + return 
hidden_states + + +class T5Stack(nn.Module): + + def __init__(self, + config: T5Config, + is_decoder: bool, + n_layers: int, + embed_tokens=None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.embed_tokens = embed_tokens + # Only the first block has relative positional encoding. + self.blocks = nn.ModuleList([ + T5Block(config, + is_decoder=is_decoder, + has_relative_attention_bias=i == 0, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}") for i in range(n_layers) + ]) + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for idx, block in enumerate(self.blocks): + hidden_states = block( + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5Model(nn.Module): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: T5Config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.padding_idx = config.pad_token_id + self.shared = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size) + + self.encoder = T5Stack(config, + False, + config.num_layers, + self.shared, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = T5Stack(config, + True, + config.num_decoder_layers, + self.shared, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.shared(input_ids) + + def forward(self, input_ids: torch.Tensor, encoder_input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input: on a regular generate call, the encoder + # runs once, on the prompt. Subsequent decoder calls reuse output + # `encoder_hidden_states`. + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + # Clear attention bias state. 
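+            # Biases cached during the encoder pass are tied to this step's encoder
+            # shapes; reset them so the decoder pass below recomputes its own.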
+ attn_metadata.attn_bias = None + attn_metadata.encoder_attn_bias = None + attn_metadata.cross_attn_bias = None + + decoder_outputs = self.decoder( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + # When capturing CUDA Graph + attn_metadata.attn_bias = None + attn_metadata.encoder_attn_bias = None + attn_metadata.cross_attn_bias = None + return decoder_outputs + + +class T5ForConditionalGeneration(nn.Module): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", + "lm_head.weight" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: T5Config = vllm_config.model_config.hf_config + self.model_dim = config.d_model + self.config = config + self.unpadded_vocab_size = config.vocab_size + if lora_config := vllm_config.lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.model = T5Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + # Although not in config, this is the default for hf models. + if self.config.tie_word_embeddings: + self.lm_head = self.model.shared + # in transformers this is smt more explicit, as in (after load) + # self.lm_head.weight = self.model.shared.weight + else: + self.lm_head = ParallelLMHead(self.unpadded_vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501 + hidden_states = hidden_states * (self.model_dim**-0.5) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.shared(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + return self.model(input_ids, encoder_input_ids, kv_caches, + attn_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + model_params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + renamed_reg = [ + (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'), + (re.compile(r'decoder.block\.(\d+)\.layer\.1'), + r'decoder.blocks.\1.cross_attn'), + (re.compile(r'decoder.block\.(\d+)\.layer\.2'), + r'decoder.blocks.\1.ffn'), + # encoder has no cross-attn, but rather self-attention+ffn. 
+ (re.compile(r'encoder.block\.(\d+)\.layer\.1'), + r'encoder.blocks.\1.ffn'), + (re.compile(r'\.o\.'), r'.out_proj.'), + ] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj.", ".q.", "q"), + (".qkv_proj.", ".k.", "k"), + (".qkv_proj.", ".v.", "v") + ] + + for name, loaded_weight in weights: + # No relative position attn bias on cross attention. + if name in self._keys_to_ignore_on_load_unexpected: + continue + + # Handle some renaming + for reg in renamed_reg: + name = re.sub(*reg, name) + + top_module, _ = name.split('.', 1) + if top_module != 'lm_head': + name = f"model.{name}" + + # Split q/k/v layers to unified QKVParallelLinear + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = model_params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Not a q/k/v layer. + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params
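For illustration only (not part of the patch): a minimal sketch of how the renaming rules in load_weights rewrite a HuggingFace T5 checkpoint name into this module tree. The sample parameter name below is an assumed example, not taken from the diff.

import re

# Renaming rules copied from T5ForConditionalGeneration.load_weights above.
renamed_reg = [
    (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'),
    (re.compile(r'decoder.block\.(\d+)\.layer\.1'), r'decoder.blocks.\1.cross_attn'),
    (re.compile(r'decoder.block\.(\d+)\.layer\.2'), r'decoder.blocks.\1.ffn'),
    (re.compile(r'encoder.block\.(\d+)\.layer\.1'), r'encoder.blocks.\1.ffn'),
    (re.compile(r'\.o\.'), r'.out_proj.'),
]

# Hypothetical HF checkpoint name: cross-attention output projection of decoder block 3.
name = "decoder.block.3.layer.1.EncDecAttention.o.weight"
for pattern, repl in renamed_reg:
    name = pattern.sub(repl, name)
# Non-lm_head weights are additionally prefixed with "model." in load_weights.
print(name)  # decoder.blocks.3.cross_attn.EncDecAttention.out_proj.weight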