From 5e59703d2875f87f9677582eada90d8a8334692e Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 2 Dec 2024 18:02:18 +0000 Subject: [PATCH 01/17] wip Signed-off-by: NickLucche --- .../kernels/benchmark_paged_attention.py | 2 + csrc/attention/attention_kernels.cuh | 50 +++++++++---- csrc/attention/paged_attention_v1.cu | 47 +++++++----- csrc/attention/paged_attention_v2.cu | 19 +++-- csrc/cpu/attention.cpp | 10 ++- csrc/cpu/torch_bindings.cpp | 5 +- csrc/ops.h | 6 +- csrc/torch_bindings.cpp | 4 +- tests/kernels/test_attention.py | 75 ++++++++++++------- vllm/_custom_ops.py | 8 +- vllm/attention/backends/blocksparse_attn.py | 1 + vllm/attention/backends/rocm_flash_attn.py | 1 + vllm/attention/backends/xformers.py | 2 + vllm/attention/ops/paged_attn.py | 5 ++ 14 files changed, 157 insertions(+), 78 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b855ac..120b8ffe9c657 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -114,6 +114,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -134,6 +135,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438f0b01..25de77c324c62 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,6 +104,7 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -153,6 +154,21 @@ __device__ void paged_attention_kernel( const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + // TODO check if indexing still makes sense + // seq_len indexes on 'max_seq_lens' dim, + // it's like renaming dim you get attn_bias: seq_len x num_kv_heads x seq_len + // TODO each seq can have different len (seq_lens) but only one bias!! + // NOTE (NickLucche) `max_seq_len` bias values for current sequence and current head + const float* attn_bias_vec = + attn_bias == nullptr + ? nullptr + : attn_bias + seq_idx * num_heads * num_seq_blocks * BLOCK_SIZE + + head_idx * num_seq_blocks * BLOCK_SIZE; + // : attn_bias + seq_idx * num_kv_heads * num_seq_blocks * BLOCK_SIZE + + // const float* attn_bias_vec = attn_bias == nullptr + // ? nullptr + // : attn_bias + seq_idx * num_kv_heads * seq_len + + // kv_head_idx * seq_len; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread @@ -293,8 +309,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. + // NOTE here each thread adds its own alibi (one per head..) 
like I am + // sure not the whole group needs to do so Add the ALiBi bias if slopes + // are given. + // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -512,17 +532,18 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, const int q_stride, + const int kv_block_stride, const int kv_head_stride, const float k_scale, + const float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, + max_num_blocks_per_seq, alibi_slopes, attn_bias, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -548,15 +569,16 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, const int q_stride, + const int kv_block_stride, const int kv_head_stride, const float k_scale, + const float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 27321148f6dda..13a10221db425 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -29,20 +29,20 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, attn_bias_ptr, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. @@ -53,8 +53,9 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -73,7 +74,12 @@ void paged_attention_v1_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; + if (attn_bias_ptr){ + TORCH_CHECK(attn_bias.value().dtype() == torch::kFloat32, "Unsupported bias dtype: ", attn_bias.value().dtype()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); @@ -135,8 +141,8 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \ + tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ @@ -176,7 +182,8 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index a453b2243e48c..80e1d7cb962df 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -36,9 +36,9 @@ <<>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ + attn_bias_ptr, q_stride, kv_block_stride, kv_head_stride, k_scale, \ + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ @@ -54,8 +54,9 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -74,6 +75,9 @@ void paged_attention_v2_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + const float* attn_bias_ptr = + attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -142,7 +146,7 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -187,7 +191,8 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ef5b14088c63b..eb33c66953a6e 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -459,7 +459,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -467,6 +468,8 @@ void paged_attention_v1( TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -781,7 +784,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -789,6 +793,8 @@ void paged_attention_v2( TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 74e4d8189d403..3cfa289848e21 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -24,12 +24,13 @@ 
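  // NOTE: both schemas below gain an optional `Tensor? attn_bias` argument
  // immediately after `alibi_slopes`, matching the positional order used by
  // the Python wrappers in vllm/_custom_ops.py.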
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. + // TODO attn_bias on cpu ops.def( "paged_attention_v1(" " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/csrc/ops.h b/csrc/ops.h index 9efd9b0c24700..87412b0eb746c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,7 +33,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -44,7 +45,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 956258c1001d3..5b23441cba8fd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +43,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? 
attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 3e3c0668198ad..6282a3abd6887 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -18,7 +18,8 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +MAX_SEQ_LEN = 16 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -29,6 +30,7 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # FlashAttention forward only supports head dimension at most 128 @@ -37,6 +39,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +USE_CUSTOM_ATTN_BIAS = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ @@ -60,16 +63,11 @@ def ref_masked_attention( def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], -) -> None: + output: torch.Tensor, query: torch.Tensor, num_queries_per_kv: int, + key_cache: torch.Tensor, value_cache: torch.Tensor, + block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float, + alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[List[torch.Tensor]]) -> None: num_query_heads = query.shape[1] num_kv_heads = value_cache.shape[1] head_size = value_cache.shape[2] @@ -102,15 +100,19 @@ def ref_single_query_cached_kv_attention( keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - alibi_bias = None + bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. 
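            # For a decode query at position seq_len - 1, the ALiBi bias for
            # key token j is alibi_slope * (j - seq_len + 1): zero for the most
            # recent token and proportional to the distance for older ones.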
position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) + bias = alibi_bias + if attn_bias is not None: + # TODO test alibi + bias + bias = attn_bias[i] if bias is None else bias + attn_bias[i] + # print(f"ATTN BIAS {i}: {attn_bias[i]}") + out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -122,6 +124,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("use_custom_attn_bias", USE_CUSTOM_ATTN_BIAS) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @@ -134,12 +137,17 @@ def test_paged_attention( num_heads: Tuple[int, int], head_size: int, use_alibi: bool, + use_custom_attn_bias: bool, block_size: int, dtype: torch.dtype, kv_cache_dtype: str, seed: int, device: str, ) -> None: + # num_heads = (2, 2) + # num_seqs = 2 + # head_size = 32 + if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -153,7 +161,7 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None + alibi_slopes, attn_bias = None, None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) @@ -161,6 +169,20 @@ def test_paged_attention( seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) + attn_bias_list = None + if use_custom_attn_bias: + # NOTE (NickLucche) each sequence can have a different bias, + # depending on its len, but it *must* be float (f32)! + attn_bias_list = [torch.randn(num_query_heads, + 1, + seq_len, + dtype=torch.float) for seq_len in seq_lens] + attn_bias = torch.empty(num_seqs, num_query_heads, 1, max_seq_len, device=device, dtype=torch.float) + + for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): + # first seq_len entries of the bias for each head/seq + attn_bias[i, :, :, :seq_len] = bias + # print("bias shape", attn_bias.shape) # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size @@ -186,6 +208,7 @@ def test_paged_attention( # Call the paged attention kernel. 
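    # The custom attn_bias handed to the kernel below is float32 with one bias
    # row per (seq, head), padded along the last dim to max_seq_len; the kernel
    # only reads each row up to that sequence's true length.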
output = torch.empty_like(query) + # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -199,19 +222,23 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, ) + # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): + assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -240,6 +267,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -249,6 +277,7 @@ def test_paged_attention( (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -305,17 +334,11 @@ def test_paged_attention( value_cache = dequantized_value_cache ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - ) + ref_single_query_cached_kv_attention(ref_output, query, num_queries_per_kv, + key_cache, value_cache, block_tables, + seq_lens, scale, alibi_slopes, + attn_bias_list) + # print("\nREF OUT", ref_output) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index afb350591e562..e5225311f0540 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -74,6 +74,7 @@ def paged_attention_v1( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -85,8 +86,8 @@ def paged_attention_v1( ) -> None: torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, + seq_lens, block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -106,6 +107,7 @@ def paged_attention_v2( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -118,7 +120,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + alibi_slopes, attn_bias, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 7089d59392c36..91de16b4608d4 100644 --- 
a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -440,6 +440,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias k_scale, v_scale, tp_rank=self.tp_rank, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index a91a5af5c3d58..1793830f8dd0b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -628,6 +628,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias k_scale, v_scale, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 8c8ca8520a9db..419f174b753ab 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -579,6 +579,7 @@ def forward( prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, + prefill_meta.attn_bias, self.sliding_window, k_scale, v_scale, @@ -607,6 +608,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + decode_meta.attn_bias, # TODO or cross_attn_bias?? k_scale, v_scale, ) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 076f151ffcb61..bc651fa0dc326 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -95,6 +95,7 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], k_scale: float, v_scale: float, tp_rank: int = 0, @@ -140,6 +141,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -178,6 +180,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -203,11 +206,13 @@ def forward_prefix( context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], sliding_window: Optional[int], k_scale: float, v_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) + assert attn_bias is None, "Bias for prefix not yet enabled" context_attention_fwd( query, key, From 416412d5b672e77e111f78d25a9b3a3e295cb2bc Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:16:11 +0000 Subject: [PATCH 02/17] add working kernel with padded_max_seq_len as arg Signed-off-by: NickLucche --- csrc/attention/attention_kernels.cuh | 56 +++++++++++++--------------- csrc/attention/paged_attention_v1.cu | 38 ++++++++++--------- csrc/attention/paged_attention_v2.cu | 17 +++++++-- 3 files changed, 59 insertions(+), 52 deletions(-) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 25de77c324c62..08f9882f65f09 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,7 +104,8 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
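    // padded_max_seq_len is max_seq_len rounded up to a multiple of BLOCK_SIZE
    // (DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE) and is the stride
    // between consecutive (seq, head) rows of attn_bias.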
const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -154,21 +155,14 @@ __device__ void paged_attention_kernel( const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - // TODO check if indexing still makes sense - // seq_len indexes on 'max_seq_lens' dim, - // it's like renaming dim you get attn_bias: seq_len x num_kv_heads x seq_len - // TODO each seq can have different len (seq_lens) but only one bias!! - // NOTE (NickLucche) `max_seq_len` bias values for current sequence and current head + + // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence + // and current head. const float* attn_bias_vec = attn_bias == nullptr ? nullptr - : attn_bias + seq_idx * num_heads * num_seq_blocks * BLOCK_SIZE + - head_idx * num_seq_blocks * BLOCK_SIZE; - // : attn_bias + seq_idx * num_kv_heads * num_seq_blocks * BLOCK_SIZE + - // const float* attn_bias_vec = attn_bias == nullptr - // ? nullptr - // : attn_bias + seq_idx * num_kv_heads * seq_len + - // kv_head_idx * seq_len; + : attn_bias + seq_idx * num_heads * padded_max_seq_len + + head_idx * padded_max_seq_len; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread @@ -309,9 +303,7 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); - // NOTE here each thread adds its own alibi (one per head..) like I am - // sure not the whole group needs to do so Add the ALiBi bias if slopes - // are given. + // Add the ALiBi bias if slopes are given, then add custom bias if given. // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; @@ -532,17 +524,18 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const float* __restrict__ attn_bias, const int q_stride, - const int kv_block_stride, const int kv_head_stride, const float k_scale, - const float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
+ const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, attn_bias, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -569,18 +562,19 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const float* __restrict__ attn_bias, const int q_stride, - const int kv_block_stride, const int kv_head_stride, const float k_scale, - const float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, - q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 13a10221db425..cbe5d1dd6f3f6 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -29,21 +29,21 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, attn_bias_ptr, q_stride, kv_block_stride, \ - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template (attn_bias.value().data_ptr()) : nullptr; - if (attn_bias_ptr){ - TORCH_CHECK(attn_bias.value().dtype() == torch::kFloat32, "Unsupported bias dtype: ", attn_bias.value().dtype()); + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, + "Unexpected attn_bias shape: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 80e1d7cb962df..2b25a6afe765f 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -37,9 +37,10 @@ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ - attn_bias_ptr, q_stride, kv_block_stride, kv_head_stride, k_scale, \ - v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ + attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -78,7 +79,13 @@ void paged_attention_v2_launcher( const float* attn_bias_ptr = attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; - + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, + "Unexpected attn_bias shape: ", abias.sizes()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); @@ -91,6 +98,8 @@ void paged_attention_v2_launcher( constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); From 1d1f2a0487bad8e5e4ca02b7c78b51a1dfc621b5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:22:50 +0000 Subject: [PATCH 03/17] add attn_bias case to pagedattn tests Signed-off-by: NickLucche --- tests/kernels/test_attention.py | 51 ++++++++++++++------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 6282a3abd6887..0fb374266b628 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -18,8 +18,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -MAX_SEQ_LEN = 16 +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -30,7 +29,6 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing -# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # FlashAttention forward only supports head dimension at most 128 @@ -109,9 +107,7 @@ def ref_single_query_cached_kv_attention( 1, 1, -1) bias = alibi_bias if attn_bias is not None: - # TODO test alibi + bias bias = attn_bias[i] if bias is None else bias + attn_bias[i] - # print(f"ATTN BIAS {i}: {attn_bias[i]}") out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -144,10 +140,6 @@ def test_paged_attention( seed: int, device: str, ) -> None: - # num_heads = (2, 2) - # num_seqs = 2 - # head_size = 32 - if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -173,16 +165,20 @@ def test_paged_attention( if use_custom_attn_bias: # NOTE (NickLucche) each sequence can have a different bias, # depending on its len, but it *must* be float (f32)! 
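        # e.g. with seq_lens = [5, 12] the per-sequence biases have shapes
        # [num_query_heads, 1, 5] and [num_query_heads, 1, 12]; they are
        # scattered into a single [num_seqs, num_query_heads, 1, max_seq_len]
        # tensor, and the padding past each seq_len is never read by the kernel.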
- attn_bias_list = [torch.randn(num_query_heads, - 1, - seq_len, - dtype=torch.float) for seq_len in seq_lens] - attn_bias = torch.empty(num_seqs, num_query_heads, 1, max_seq_len, device=device, dtype=torch.float) + attn_bias_list = [ + torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) + for seq_len in seq_lens + ] + attn_bias = torch.empty(num_seqs, + num_query_heads, + 1, + max_seq_len, + device=device, + dtype=torch.float) for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): - # first seq_len entries of the bias for each head/seq + # first `seq_len` entries of the bias matrix for each head/seq attn_bias[i, :, :, :seq_len] = bias - # print("bias shape", attn_bias.shape) # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size @@ -208,7 +204,6 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) - # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -227,18 +222,15 @@ def test_paged_attention( k_scale, v_scale, ) - # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - attn_bias, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): - assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -273,14 +265,14 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - attn_bias, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -338,7 +330,6 @@ def test_paged_attention( key_cache, value_cache, block_tables, seq_lens, scale, alibi_slopes, attn_bias_list) - # print("\nREF OUT", ref_output) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two From 7fb263d15863f20f068d105e8899ae76d2bea9a3 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:32:29 +0000 Subject: [PATCH 04/17] format Signed-off-by: NickLucche --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- vllm/attention/backends/blocksparse_attn.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 2 +- vllm/attention/backends/xformers.py | 4 +++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 120b8ffe9c657..53d1f5803cf20 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -114,7 +114,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, 
alibi_slopes, - None, # TODO add custom bias + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 91de16b4608d4..ee37ff476f810 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -440,7 +440,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - None, # TODO support attn_bias + None, # TODO support attn_bias k_scale, v_scale, tp_rank=self.tp_rank, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1793830f8dd0b..337a7b185c4c6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -628,7 +628,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - None, # TODO support attn_bias + None, # TODO support attn_bias k_scale, v_scale, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 419f174b753ab..a5cd2ff8ba217 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -608,7 +608,9 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - decode_meta.attn_bias, # TODO or cross_attn_bias?? + # TODO (NickLucche) cross_attn_bias not needed for T5-like + # models, abstract bias selection if needed. + decode_meta.attn_bias, k_scale, v_scale, ) From ac6bf63ae5c70cd3ddafff47d1e5eeb0d18b266d Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:57:23 +0000 Subject: [PATCH 05/17] format Signed-off-by: NickLucche --- vllm/attention/backends/xformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a5cd2ff8ba217..a8aafcc049608 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -608,7 +608,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - # TODO (NickLucche) cross_attn_bias not needed for T5-like + # TODO (NickLucche) cross_attn_bias not needed for T5-like # models, abstract bias selection if needed. decode_meta.attn_bias, k_scale, From 5c47f43fc80e78cdd6e91a8ed3449ec7eac240e5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 27 Dec 2024 17:21:43 +0000 Subject: [PATCH 06/17] enforce last dim of attn bias to be block aligned Signed-off-by: NickLucche --- csrc/attention/paged_attention_v1.cu | 17 ++++++++++------- csrc/attention/paged_attention_v2.cu | 21 ++++++++++++--------- tests/kernels/test_attention.py | 19 +++++++++++-------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index cbe5d1dd6f3f6..0b04b55a4e13a 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -77,12 +77,17 @@ void paged_attention_v1_launcher( const float* attn_bias_ptr = attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; if (attn_bias_ptr) { const torch::Tensor& abias = attn_bias.value(); TORCH_CHECK(abias.dtype() == torch::kFloat32, "Unsupported bias dtype: ", abias.dtype()); - TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, - "Unexpected attn_bias shape: ", abias.sizes()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -92,13 +97,11 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int logits_size = padded_max_seq_len * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 2b25a6afe765f..5eeba75d5cf1c 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -79,12 +79,17 @@ void paged_attention_v2_launcher( const float* attn_bias_ptr = attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; if (attn_bias_ptr) { const torch::Tensor& abias = attn_bias.value(); TORCH_CHECK(abias.dtype() == torch::kFloat32, "Unsupported bias dtype: ", abias.dtype()); - TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, - "Unexpected attn_bias shape: ", abias.sizes()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -97,18 +102,16 @@ void paged_attention_v2_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + const int logits_size = PARTITION_SIZE * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // For paged attention v2 kernel. dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); // For paged attention v2 reduce kernel. 
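  // Each (head, seq) reduce block combines max_num_partitions partial results,
  // so it needs shared memory for two floats per partition (the partition's
  // max logit and its exp-sum).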
dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 0fb374266b628..b9cfe6437183f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -162,26 +162,29 @@ def test_paged_attention( max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) attn_bias_list = None + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size if use_custom_attn_bias: # NOTE (NickLucche) each sequence can have a different bias, - # depending on its len, but it *must* be float (f32)! + # depending on its len, but it *must* be padded to the block + # aligned max_seq_len and of type float32! attn_bias_list = [ torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) for seq_len in seq_lens ] - attn_bias = torch.empty(num_seqs, - num_query_heads, - 1, - max_seq_len, - device=device, - dtype=torch.float) + block_aligned_max_seq_len = max_num_blocks_per_seq * block_size + attn_bias = torch.empty( + num_seqs, + num_query_heads, + 1, + block_aligned_max_seq_len, # padded dim + device=device, + dtype=torch.float) for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): # first `seq_len` entries of the bias matrix for each head/seq attn_bias[i, :, :, :seq_len] = bias # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables_lst: List[List[int]] = [] for _ in range(num_seqs): block_table = [ From f97939fdb5791954c0f4d106bac7079a26c62661 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 2 Dec 2024 18:02:18 +0000 Subject: [PATCH 07/17] wip Signed-off-by: NickLucche --- csrc/attention/attention_kernels.cuh | 15 +++++++++++++++ tests/kernels/test_attention.py | 11 ++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 08f9882f65f09..a83bbf8b7648f 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -155,6 +155,21 @@ __device__ void paged_attention_kernel( const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + // TODO check if indexing still makes sense + // seq_len indexes on 'max_seq_lens' dim, + // it's like renaming dim you get attn_bias: seq_len x num_kv_heads x seq_len + // TODO each seq can have different len (seq_lens) but only one bias!! + // NOTE (NickLucche) `max_seq_len` bias values for current sequence and current head + const float* attn_bias_vec = + attn_bias == nullptr + ? nullptr + : attn_bias + seq_idx * num_heads * num_seq_blocks * BLOCK_SIZE + + head_idx * num_seq_blocks * BLOCK_SIZE; + // : attn_bias + seq_idx * num_kv_heads * num_seq_blocks * BLOCK_SIZE + + // const float* attn_bias_vec = attn_bias == nullptr + // ? nullptr + // : attn_bias + seq_idx * num_kv_heads * seq_len + + // kv_head_idx * seq_len; // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence // and current head. 
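For reference, a minimal sketch of the bias layout the kernels above assume (the sizes are made up for illustration and `torch` is the only dependency): one contiguous float32 row per (sequence, head) pair, padded on the last dimension to the block-aligned maximum sequence length, addressed by the same flat offset the CUDA code computes.

import torch

# Hypothetical sizes, for illustration only.
num_seqs, num_heads, block_size = 2, 4, 16
seq_lens = [5, 20]
max_seq_len = max(seq_lens)
# Mirrors DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE in the launcher.
padded_max_seq_len = ((max_seq_len + block_size - 1) // block_size) * block_size  # -> 32

# One bias row per (seq, head); positions >= seq_lens[i] are never read.
attn_bias = torch.zeros(num_seqs, num_heads, padded_max_seq_len, dtype=torch.float32)
for i, seq_len in enumerate(seq_lens):
    attn_bias[i, :, :seq_len] = torch.randn(num_heads, seq_len)

# The kernel locates the row for (seq_idx, head_idx) at the flat offset
#   seq_idx * num_heads * padded_max_seq_len + head_idx * padded_max_seq_len.
flat = attn_bias.reshape(-1)
seq_idx, head_idx = 1, 2
start = (seq_idx * num_heads + head_idx) * padded_max_seq_len
assert torch.equal(flat[start:start + padded_max_seq_len], attn_bias[seq_idx, head_idx])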
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b9cfe6437183f..de5aba65d50d0 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -18,7 +18,8 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +MAX_SEQ_LEN = 16 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -29,6 +30,7 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # FlashAttention forward only supports head dimension at most 128 @@ -140,6 +142,10 @@ def test_paged_attention( seed: int, device: str, ) -> None: + # num_heads = (2, 2) + # num_seqs = 2 + # head_size = 32 + if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -207,6 +213,7 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) + # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -225,6 +232,7 @@ def test_paged_attention( k_scale, v_scale, ) + # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, @@ -234,6 +242,7 @@ def test_paged_attention( and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): + assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape From f8df36a8f8844c6ba0a276a9e8b59624c52032f4 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 2 Dec 2024 18:02:18 +0000 Subject: [PATCH 08/17] wip Signed-off-by: NickLucche --- .../offline_inference_encoder_decoder.py | 58 +- .../encoder_decoder/language/conftest.py | 147 ++++ .../encoder_decoder/language/test_bart.py | 162 +--- .../encoder_decoder/language/test_t5.py | 265 +++++++ .../models/encoder_decoder/language/utils.py | 16 + vllm/attention/backends/xformers.py | 30 +- vllm/inputs/preprocess.py | 3 + vllm/model_executor/models/registry.py | 2 + vllm/model_executor/models/t5.py | 697 ++++++++++++++++++ 9 files changed, 1170 insertions(+), 210 deletions(-) create mode 100644 tests/models/encoder_decoder/language/conftest.py create mode 100644 tests/models/encoder_decoder/language/test_t5.py create mode 100644 tests/models/encoder_decoder/language/utils.py create mode 100644 vllm/model_executor/models/t5.py diff --git a/examples/offline_inference/offline_inference_encoder_decoder.py b/examples/offline_inference/offline_inference_encoder_decoder.py index 0f266d7918853..f386ebb6c2176 100644 --- a/examples/offline_inference/offline_inference_encoder_decoder.py +++ b/examples/offline_inference/offline_inference_encoder_decoder.py @@ -11,8 +11,10 @@ # Create a BART encoder/decoder model instance llm = LLM( - model="facebook/bart-large-cnn", + # model="facebook/bart-large-cnn", + model="google-t5/t5-small", dtype=dtype, + enforce_eager=True ) # Get BART tokenizer @@ -24,41 +26,9 @@ # encoder/decoder model. 
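# NOTE: T5 checkpoints are trained with task prefixes, which is why the updated
# example below phrases the request as "translate English to German: <text>".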
# # - Helpers for building prompts -text_prompt_raw = "Hello, my name is" -text_prompt = TextPrompt(prompt="The president of the United States is") -tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( - prompt="The capital of France is")) -# - Pass a single prompt to encoder/decoder model -# (implicitly encoder input prompt); -# decoder input prompt is assumed to be None +to_translate = "My name is Azeem and I live in India" +text_prompt_raw = "translate English to German: "+to_translate -single_text_prompt_raw = text_prompt_raw # Pass a string directly -single_text_prompt = text_prompt # Pass a TextPrompt -single_tokens_prompt = tokens_prompt # Pass a TokensPrompt - -# - Pass explicit encoder and decoder input prompts within one data structure. -# Encoder and decoder prompts can both independently be text or tokens, with -# no requirement that they be the same prompt type. Some example prompt-type -# combinations are shown below, note that these are not exhaustive. - -enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt string directly, & - # pass decoder prompt tokens - encoder_prompt=single_text_prompt_raw, - decoder_prompt=single_tokens_prompt, -) -enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( - # Pass TextPrompt to encoder, and - # pass decoder prompt string directly - encoder_prompt=single_text_prompt, - decoder_prompt=single_text_prompt_raw, -) -enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( - # Pass encoder prompt tokens directly, and - # pass TextPrompt to decoder - encoder_prompt=single_tokens_prompt, - decoder_prompt=single_text_prompt, -) # - Finally, here's a useful helper function for zipping encoder and # decoder prompts together into a list of ExplicitEncoderDecoderPrompt @@ -69,19 +39,21 @@ # - Let's put all of the above example prompts together into one list # which we will pass to the encoder/decoder LLM. -prompts = [ - single_text_prompt_raw, single_text_prompt, single_tokens_prompt, - enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 -] + zipped_prompt_list +# prompts = [ +# single_text_prompt_raw, single_text_prompt, single_tokens_prompt, +# enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 +# ] + zipped_prompt_list +prompts = [text_prompt_raw]#, "Se ni mondo"] print(prompts) # Create a sampling params object. sampling_params = SamplingParams( - temperature=0, - top_p=1.0, - min_tokens=0, - max_tokens=20, + temperature=0.2, + max_tokens=100, + # top_p=1.0, + # min_tokens=0, + # max_tokens=20, ) # Generate output tokens from the prompts. 
The output is a list of diff --git a/tests/models/encoder_decoder/language/conftest.py b/tests/models/encoder_decoder/language/conftest.py new file mode 100644 index 0000000000000..a751f73dfa9a6 --- /dev/null +++ b/tests/models/encoder_decoder/language/conftest.py @@ -0,0 +1,147 @@ +from transformers import AutoModelForSeq2SeqLM +from ....conftest import (DecoderPromptType, HfRunner, VllmRunner, ExplicitEncoderDecoderPrompt) +from typing import List, Optional, Tuple, Type, Dict, Any +from ...utils import check_logprobs_close +from .utils import vllm_to_hf_output + + + +# TODO docs +def compare_hf_vllm_logprobs( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, + vllm_runner_kwargs: Optional[Dict[str, Any]] = dict(), +) -> None: + ''' + Test the vLLM BART model for a variety of encoder/decoder input prompts, + by validating it against HuggingFace (HF) BART. + + Arguments: + + * hf_runner: HuggingFace (HF) test model runner + * vllm_runner: vLLM test model runner + * example_encoder_decoder_prompts: test fixture which provides a + dictionary of dummy prompts + * model: the HF ID of the specific BART variant under test + * dtype: the tensor datatype to employ + * max_tokens + * num_logprobs + * decoder_prompt_type: key into the example_encoder_decoder_prompts + dictionary; selects specific encoder/decoder + prompt scenarios to test + + A note on using HF BART as a baseline for validating vLLM BART, + specifically when the decoder prompt is None. + + The HF GenerationMixin's default behavior is to force the first + decoded token to be if the prompt does not already contain + (this is accomplished using a logit + processor setting.) + + So when we use HF BART as our baseline for comparison, note that + when the user provides a request with a None decoder prompt + (i.e. a singleton encoder prompt, or else an explicit encoder/ + decoder prompt with the decoder sub-prompt set to None), HF and + vLLM handle this in different ways: + + * HF will (1) tokenize the None prompt as an empty token-list, + (2) append to the beginning, yielding + [], (3) pass this token list to the model, and + then (4) after computing logits during prefill, override the model + logits & force to be the first generated token. + + * vLLM will (1) tokenize the None prompt as [], (2) append decoder- + start-token to the beginning, yielding [], + (3) pass these tokens to the model & proceed with generation. + + The net effect is that compared to vLLM, the list of HF *decoded* tokens + will contain one more initial than the vLLM generated tokens, + because vLLM's token is injected into the prompt rather than into + the generated output. This is in spite of the fact that overall, the + complete sequences (prompt + decoded tokens) produced by vLLM will match + HF. + + So when we use HF decoded token output to validate vLLM's decoded token + output, the testing process must account for the difference in decoded + token sequences between vLLM and HF specifically in the + decoder-prompt-is-None case. + + One option is to disable the logit processor feature that forces the + token to be decoded (forced_bos_token_id = None), eliminating + the problem entirely. However this is not "normal" BART usage. 
+ + The other option is - only in the decoder-prompt-is-None case - to + discard the first decoded token from the HF output before comparing it + to vLLM. + + To that end, when testing the scenario where the decoder prompt is None + (and only in that one scenario), this test skips the first HF decoded + token during the process of validating the vLLM decoded output. + ''' + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default). + + # Note: currently encoder/decoder models are only compatible with + # enforce_eager=True. Normally this is not a problem because + # for encoder/decoder models vLLM will + # default to enforce_eager=True if enforce_eager + # is left unspecified. However, the + # VllmRunner test fixture (which wraps around the LLM class) defaults to + # enforce_eager=False (a behavior which a number of already-exisitng + # decoder-only unit tests expect), so when testing an encoder/decoder + # model we must explicitly specify enforce_eager=True in the VllmRunner + # constructor. + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, **vllm_runner_kwargs) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + hf_skip_tokens = (1 + if decoder_prompt_type == DecoderPromptType.NONE else 0) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) \ No newline at end of file diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 10aba8427944f..b43edb6fcb5f8 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -5,167 +5,11 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.sequence import SampleLogprobs from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, HfRunner, VllmRunner) from ....utils import multi_gpu_test -from ...utils import check_logprobs_close - - -def vllm_to_hf_output( - vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "" + hf_output_str - - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt[str, str]], - 
decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -) -> None: - ''' - Test the vLLM BART model for a variety of encoder/decoder input prompts, - by validating it against HuggingFace (HF) BART. - - Arguments: - - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - * model: the HF ID of the specific BART variant under test - * dtype: the tensor datatype to employ - * max_tokens - * num_logprobs - * decoder_prompt_type: key into the example_encoder_decoder_prompts - dictionary; selects specific encoder/decoder - prompt scenarios to test - - A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - - The HF GenerationMixin's default behavior is to force the first - decoded token to be if the prompt does not already contain - (this is accomplished using a logit - processor setting.) - - So when we use HF BART as our baseline for comparison, note that - when the user provides a request with a None decoder prompt - (i.e. a singleton encoder prompt, or else an explicit encoder/ - decoder prompt with the decoder sub-prompt set to None), HF and - vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, - (2) append to the beginning, yielding - [], (3) pass this token list to the model, and - then (4) after computing logits during prefill, override the model - logits & force to be the first generated token. - - * vLLM will (1) tokenize the None prompt as [], (2) append decoder- - start-token to the beginning, yielding [], - (3) pass these tokens to the model & proceed with generation. - - The net effect is that compared to vLLM, the list of HF *decoded* tokens - will contain one more initial than the vLLM generated tokens, - because vLLM's token is injected into the prompt rather than into - the generated output. This is in spite of the fact that overall, the - complete sequences (prompt + decoded tokens) produced by vLLM will match - HF. - - So when we use HF decoded token output to validate vLLM's decoded token - output, the testing process must account for the difference in decoded - token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. - - One option is to disable the logit processor feature that forces the - token to be decoded (forced_bos_token_id = None), eliminating - the problem entirely. However this is not "normal" BART usage. - - The other option is - only in the decoder-prompt-is-None case - to - discard the first decoded token from the HF output before comparing it - to vLLM. - - To that end, when testing the scenario where the decoder prompt is None - (and only in that one scenario), this test skips the first HF decoded - token during the process of validating the vLLM decoded output. - ''' - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default). - - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. 
Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-exisitng - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) +from .utils import compare_hf_vllm_logprobs @pytest.mark.parametrize( @@ -183,7 +27,7 @@ def run_test( def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - run_test( + compare_hf_vllm_logprobs( hf_runner, vllm_runner, example_encoder_decoder_prompts[decoder_prompt_type], @@ -208,7 +52,7 @@ def test_models_distributed(hf_runner, vllm_runner, distributed_executor_backend, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - run_test( + compare_hf_vllm_logprobs( hf_runner, vllm_runner, example_encoder_decoder_prompts[decoder_prompt_type], diff --git a/tests/models/encoder_decoder/language/test_t5.py b/tests/models/encoder_decoder/language/test_t5.py new file mode 100644 index 0000000000000..1ff43defbc726 --- /dev/null +++ b/tests/models/encoder_decoder/language/test_t5.py @@ -0,0 +1,265 @@ +"""Compare the outputs of HF and vLLM for T5 models using greedy sampling. +Based on tests/models/encoder_decoder/language/test_bart.py. + +Run `pytest tests/models/encoder_decoder/language/test_t5.py`. 
+""" +from typing import List, Optional, Tuple, Type + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from tests.kernels.utils import make_test_metadata +from vllm.attention.layer import Attention +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.config import set_current_vllm_config + +from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, + HfRunner, VllmRunner) +from ....utils import multi_gpu_test +from .conftest import compare_hf_vllm_logprobs +import torch +from vllm.model_executor.models.t5 import T5Attention, T5Config, AttentionType +from vllm.platforms import current_platform + +@pytest.mark.parametrize( + "model", + [ + # pytest.param("google/t5-small", + # marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("google-t5/t5-small"), + ], +) +@pytest.mark.parametrize( + "vllm_kwargs", + [{ + "max_model_len": 512 + }] + ) +@pytest.mark.parametrize("dtype", ["float"])#, "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type, vllm_kwargs) -> None: + # TODO force backend + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + vllm_runner_kwargs=vllm_kwargs + ) + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +@pytest.fixture +def dist_init(): + from vllm.distributed import init_distributed_environment, cleanup_dist_env_and_memory, initialize_model_parallel + import tempfile + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield + cleanup_dist_env_and_memory() + +# TODO more cases +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +def test_t5_bias_attention(dtype, dist_init) -> None: + import random + seed = 0 + MAX_SEQ_LEN = 32 + block_size = 16 + NUM_BLOCKS = 4321 + current_platform.seed_everything(seed) + config = T5Config() + + # setup kv caches + head_size = config.d_kv + num_heads = (config.num_heads, config.num_heads) + num_seqs = 1 + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + # query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + # query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + # seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. 
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables_lst: List[List[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, dtype=torch.int) + + # Create the KV caches. + kv_cache_dtype = 'auto' + from vllm.utils import create_kv_caches_with_random + key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + 'cuda') + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + k_scale = v_scale = 1.0 + + + from vllm.attention.selector import _Backend + x = torch.randn(num_seqs, max_seq_len, config.d_model, device='cuda', dtype=torch.float) + with global_force_attn_backend_context_manager(_Backend.XFORMERS): + + from vllm.attention.backends.xformers import XFormersMetadata + from vllm.attention.backends.xformers import XFormersBackend + from vllm import LLM + + from vllm.forward_context import set_forward_context + from vllm.config import VllmConfig + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + encoder_seq_start_loc = torch.zeros(len(seq_lens) + 1, + dtype=torch.int32, + device='cuda') + meta = XFormersBackend.make_metadata( + seq_lens=None,#seq_lens, + max_decode_seq_len=0, num_prefills=None, + num_prefill_tokens=None, num_decode_tokens=0, + seq_lens_tensor=None,#torch.tensor(seq_lens), + slot_mapping=None,#torch.zeros(1), + multi_modal_placeholder_index_maps=None, + max_prefill_seq_len=None,#MAX_SEQ_LEN, + use_cuda_graph=False, + context_lens_tensor=None, + # no block tables on encoder forward + block_tables=torch.tensor([]).cuda(), + # block_tables=block_tables, + num_encoder_tokens=sum(seq_lens), encoder_seq_lens=seq_lens,encoder_seq_lens_tensor=torch.tensor(seq_lens).cuda(), + max_encoder_seq_len=max(seq_lens), encoder_seq_start_loc=encoder_seq_start_loc) + # # NOTE use compute_bias here + # attn_bias = t5_attn.compute_bias(MAX_SEQ_LEN, MAX_SEQ_LEN) + + # same weights should be loaded + # TODO load model without engine overhead + llm = LLM(model="google-t5/t5-small", load_format='safetensors', enforce_eager=True, dtype='float') + model = llm.llm_engine.model_executor.driver_worker.model_runner.model + t5_attn = model.model.encoder.blocks[0].self_attn.SelfAttention + print("\nTYPE", type(t5_attn)) + # TODO decoder + # FIXME this is kinda close, maybe issue is not with xformers custom bias attn + # t5_attn = T5Attention(config, AttentionType.ENCODER, has_relative_attention_bias=True).cuda() + # t5_attn.has_relative_attention_bias = False + assert t5_attn.has_relative_attention_bias + from transformers import T5Tokenizer, T5ForConditionalGeneration + from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention + hfmodel = T5ForConditionalGeneration.from_pretrained('google-t5/t5-small', return_dict=True) + print("My T5", t5_attn) + # this must be set to call attn.impl.forward + # vllm_config.compilation_config.static_forward_context[".attn"] = t5_attn.attn + vllm_config.compilation_config.static_forward_context["model.encoder.blocks.0.self_attn.SelfAttention.attn"] = t5_attn.attn + hf_attn = hfmodel.encoder.block[0].layer[0].SelfAttention.cuda() + # hf_attn.has_relative_attention_bias = False + assert hf_attn.has_relative_attention_bias + # hf_attn = HFT5Attention(config, has_relative_attention_bias=True).cuda() + + + with 
set_forward_context(meta, vllm_config): + # kv_cache for xformers [2, num_blocks, block_size * num_kv_heads * head_size] + kvc = torch.stack([key_cache.reshape(NUM_BLOCKS, -1), value_cache.reshape(NUM_BLOCKS, -1)], 0) + output = t5_attn(x, kvc, meta) + ref_output, *_ = hf_attn(x) + + atol, rtol = 1e-3, 1e-5 + # torch.testing.assert_close(output, ref_output.squeeze(), atol=atol, rtol=rtol) + + # **cross attn** + t5_attn = model.model.decoder.blocks[0].cross_attn.EncDecAttention + print("\nTYPE", type(t5_attn)) + assert not t5_attn.has_relative_attention_bias + vllm_config.compilation_config.static_forward_context["model.decoder.blocks.0.cross_attn.EncDecAttention.attn"] = t5_attn.attn + hf_attn = hfmodel.decoder.block[0].layer[1].EncDecAttention.cuda() + assert not hf_attn.has_relative_attention_bias + + meta = XFormersBackend.make_metadata( + seq_lens=seq_lens, + max_decode_seq_len=MAX_SEQ_LEN, num_prefills=0, + num_prefill_tokens=0, num_decode_tokens=1, + max_prefill_seq_len=None, + seq_lens_tensor=torch.tensor(seq_lens), + slot_mapping=None,#torch.zeros(1), + multi_modal_placeholder_index_maps=None, + use_cuda_graph=False, + context_lens_tensor=None, + block_tables=torch.tensor([]).cuda(), + # block_tables=block_tables + ) + + + with set_forward_context(meta, vllm_config): + output = t5_attn(x, kvc, meta) + ref_output, *_ = hf_attn(x) + + torch.testing.assert_close(output, ref_output.squeeze(), atol=atol, rtol=rtol) + + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", ["google/t5-small"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/encoder_decoder/language/utils.py b/tests/models/encoder_decoder/language/utils.py new file mode 100644 index 0000000000000..c4828683dafdb --- /dev/null +++ b/tests/models/encoder_decoder/language/utils.py @@ -0,0 +1,16 @@ +from typing import List, Optional, Tuple +from ....conftest import (DecoderPromptType) +from vllm.sequence import SampleLogprobs + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a8aafcc049608..6b4839bd90771 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -285,7 +285,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: def _get_attn_bias( attn_metadata: XFormersMetadata, attn_type: str, -) -> Optional[AttentionBias]: +) -> 
Optional[List[AttentionBias]]: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -551,6 +551,8 @@ def forward( # normal attention. # block tables are empty if the prompt does not have a cached # prefix. + # TODO this should be forwarded when splitting prefill/decode_meta + _set_attn_bias(prefill_meta, _get_attn_bias(attn_metadata, attn_type), attn_type) out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) assert out.shape == output[:num_prefill_query_tokens].shape @@ -579,7 +581,7 @@ def forward( prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, - prefill_meta.attn_bias, + _get_attn_bias(attn_metadata, attn_type), self.sliding_window, k_scale, v_scale, @@ -597,6 +599,10 @@ def forward( block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias: + attn_bias = attn_bias[0] + # print("Bias shape", attn_bias.shape) output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, @@ -608,9 +614,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - # TODO (NickLucche) cross_attn_bias not needed for T5-like - # models, abstract bias selection if needed. - decode_meta.attn_bias, + attn_bias, k_scale, v_scale, ) @@ -662,6 +666,7 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. + # FIXME this is None should be rel pos encoding attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: @@ -718,15 +723,23 @@ def _run_memory_efficient_xformers_forward( attn_metadata.seq_lens) _set_attn_bias(attn_metadata, attn_bias, attn_type) + + # if isinstance(attn_bias[0], torch.Tensor): + # print("IS THIS WORKING PREFILL shape", [b.shape for b in attn_bias]) + # print("IS THIS WORKING PREFILL stride", [b.stride() for b in attn_bias]) + # print("QUERY shape", query.shape, key.shape) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. - if self.alibi_slopes is None: + # TODO refactor custom attn bias must not go here + if self.alibi_slopes is None and len(attn_bias)==1: # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) + # if isinstance(attn_bias[0], torch.Tensor): + # print("RUNNING SINGLE ATTN BIAS VERSION WITH", attn_bias[0].shape) out = xops.memory_efficient_attention_forward( query, key, @@ -735,15 +748,16 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) return out.view_as(original_query) + # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. 
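# Sketch of the "one prompt at a time" pattern described in the FIXME above, for a
# packed batch where every sequence carries its own bias tensor. It uses
# torch.nn.functional.scaled_dot_product_attention as a stand-in for
# xops.memory_efficient_attention_forward (xformers expects a (1, L, H, D) layout
# rather than the (1, H, L, D) used here); all sizes are illustrative.
import torch
import torch.nn.functional as F

num_heads, head_size = 4, 8
seq_lens = [5, 3]                                   # tokens of all sequences packed together
q = k = v = torch.randn(sum(seq_lens), num_heads, head_size)
biases = [torch.randn(1, num_heads, L, L) for L in seq_lens]  # one bias per sequence

out = torch.empty_like(q)
start = 0
for bias, L in zip(biases, seq_lens):
    end = start + L
    # Slice out one sequence, move heads in front, add a batch dim.
    qi, ki, vi = (t[start:end].transpose(0, 1).unsqueeze(0) for t in (q, k, v))
    oi = F.scaled_dot_product_attention(qi, ki, vi, attn_mask=bias)
    out[start:end] = oi.squeeze(0).transpose(0, 1)
    start = end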
- assert attn_metadata.seq_lens is not None output = torch.empty_like(original_query) + seq_lens = attn_metadata.encoder_seq_lens if attn_type == AttentionType.ENCODER else attn_metadata.seq_lens start = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens): + for i, seq_len in enumerate(seq_lens): end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a738ffe18e3ae..5e35ca02bf811 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -122,6 +122,9 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' bos_token_id = self.get_bos_token_id() + if bos_token_id is None: + # TODO do I have to make another config to set pad id as bos? T5 has no bos..pad is used in transformers too + bos_token_id = 0 assert bos_token_id is not None return [bos_token_id] diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 62840b8c1bcda..44dd813911c1f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -100,6 +100,8 @@ "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 + "T5Model": ("t5", "T5ForConditionalGeneration"), + "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py new file mode 100644 index 0000000000000..8867989e1f1cb --- /dev/null +++ b/vllm/model_executor/models/t5.py @@ -0,0 +1,697 @@ +# Derived from T5 implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
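# Quick check (a sketch; needs `transformers` and network access) of the values behind
# the bos_token_id fallback in the preprocessor change above: T5 checkpoints define no
# BOS token and start decoding from the pad token, id 0, which doubles as
# decoder_start_token_id, so 0 is the natural stand-in when bos_token_id is None.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google-t5/t5-small")
print(cfg.bos_token_id, cfg.pad_token_id, cfg.decoder_start_token_id)  # None 0 0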
+"""PyTorch T5 model.""" + +import math +from typing import List, Optional, Tuple, Union, Set, Iterable +import re + +import torch +from torch import nn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import (ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.attention.layer import Attention, AttentionType, AttentionMetadata +from vllm.config import CacheConfig +from transformers import T5Config +from vllm.attention.backends.xformers import XFormersMetadata +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import get_sampler, SamplerOutput +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors +from vllm.model_executor.sampling_metadata import SamplingMetadata +from .utils import maybe_prefix + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states)->torch.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states)->torch.Tensor: + hidden_states, _ = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + # if ( + # isinstance(self.wo.weight, torch.Tensor) + # and hidden_states.dtype != self.wo.weight.dtype + # and self.wo.weight.dtype != torch.int8 + # ): + # hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) + self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.act = get_act_fn(config.dense_act_fn) + + def forward(self, hidden_states)->torch.Tensor: + hidden_gelu = self.act(self.wi_0(hidden_states)[0]) 
+ hidden_linear, _ = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + + # TODO + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + # if ( + # isinstance(self.wo.weight, torch.Tensor) + # and hidden_states.dtype != self.wo.weight.dtype + # and self.wo.weight.dtype != torch.int8 + # ): + # hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states, _ = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config, quant_config=quant_config) + else: + self.DenseReluDense = T5DenseActDense(config, quant_config=quant_config) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states)->torch.Tensor: + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + forwarded_states + return hidden_states + + +class T5Attention(nn.Module): + def __init__( + self, + config: T5Config, + attn_type: AttentionType, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "" + ): + super().__init__() + self.prefix = prefix + self.attn_type = attn_type + # Cross-attention has no relative pos encoding anyway + self.is_decoder = attn_type == AttentionType.DECODER + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + # No GQA in t5. + self.n_kv_heads = self.n_heads + + self.qkv_proj = QKVParallelLinear(self.d_model, self.d_model // self.n_heads, self.n_heads, self.n_kv_heads, bias=False, quant_config=quant_config) + + # TODO refactor in utils shared with bart + # if self.total_num_kv_heads >= tp_world_size: + # # Number of KV heads is greater than TP size, so we partition + # # the KV heads across multiple tensor parallel GPUs. + # assert self.total_num_kv_heads % tp_world_size == 0 + # else: + # # Number of KV heads is less than TP size, so we replicate + # # the KV heads across multiple tensor parallel GPUs. + # assert tp_world_size % self.total_num_kv_heads == 0 + # self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + + # NOTE (NickLucche) T5 employs a scaled weight initialization scheme + # instead of scaling attention scores directly. + self.attn = Attention(self.n_heads, config.d_kv, 1.0, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn") + + # Only the first SelfAttention block in encoder decoder has this + # embedding layer, the others re-use its output. 
+ if self.has_relative_attention_bias: + print(f"EMBEDDING {self.relative_attention_num_buckets} TO {self.n_heads}, max dist {self.relative_attention_max_distance}") + self.relative_attention_bias = VocabParallelEmbedding(self.relative_attention_num_buckets,\ + self.n_heads, org_num_embeddings=self.relative_attention_num_buckets, quant_config=quant_config) + # self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.out_proj = RowParallelLinear( + self.inner_dim, + self.d_model, + bias=False, + quant_config=quant_config, + ) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None)->torch.Tensor: + """Compute binned relative position bias""" + # TODO possible tp issue? 
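# Worked example of the bucketing used by compute_bias (assuming this module is
# importable as vllm.model_executor.models.t5 once the file exists): offsets close to
# zero each get their own bucket, distant offsets share logarithmically sized buckets,
# and the two directions use separate halves of the 32 buckets.
import torch
from vllm.model_executor.models.t5 import T5Attention

rel_pos = torch.arange(-8, 9).unsqueeze(0)     # memory_position - query_position
buckets = T5Attention._relative_position_bucket(
    rel_pos, bidirectional=True, num_buckets=32, max_distance=128)
print(buckets)  # nearby offsets map 1:1; positive offsets land in the upper half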
+ if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + # max_seq_len, nh + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + x = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return x + + def forward( + self, + hidden_states: torch.Tensor, # (num_tokens, d_model) + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + )->torch.Tensor: + # TODO auto-selection of xformers backend when t5 is detected + assert isinstance(attn_metadata, XFormersMetadata) + is_profile_run = kv_cache.numel() == 0 + if not is_profile_run: + # TODO xformers only + block_size = kv_cache.shape[2] // self.inner_dim + num_seqs = len(attn_metadata.seq_lens) if attn_metadata.seq_lens else len(attn_metadata.encoder_seq_lens) + qkv, _ = self.qkv_proj(hidden_states) + # Projection of 'own' hidden state (self-attention). No GQA here. + q, k, v = qkv.split(self.inner_dim, dim=-1) + + # NOTE (NickLucche) Attn bias is computed once per encoder or decoder + # forward, on the first call to T5Attention.forward. Subsequent + # *self-attention* layers will re-use it. + # TODO func should be in backend interface + from vllm.attention.backends.xformers import _get_attn_bias, _set_attn_bias + attn_bias = _get_attn_bias(attn_metadata, self.attn_type) + if self.attn_type == AttentionType.ENCODER_DECODER: + # Projection of encoder's hidden states, cross-attention. + if encoder_hidden_states is None: + # Decode phase, kv already cached + assert attn_metadata.num_prefills == 0 + k = None + v = None + else: + assert attn_metadata.num_prefills > 0 + # Prefill phase (first decode forward), caching kv + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv_enc.split(self.inner_dim, dim=-1) + # No custom attention bias must be set when running cross attn. + assert attn_bias is None + # Skip when profiling. + # FIXME should be enabled on profiling run to assess memory of bias. + # TODO NOT compatible with CP here, assumes homogeneous batch + elif self.has_relative_attention_bias and not is_profile_run: + assert attn_bias is None # to be recomputed + # Self-attention. T5 relative positional encoding. + # Compute bias based on longest sequence in batch. Biases for + # shorter sequences are subsets of the longest. + if self.attn_type == AttentionType.ENCODER: + seq_len = attn_metadata.max_encoder_seq_len + else: + # Decoder can receive both prefill and decoding requests + seq_len = attn_metadata.max_prefill_seq_len if attn_metadata.prefill_metadata else attn_metadata.max_decode_seq_len + block_aligned_seq_len = (seq_len + block_size - 1) // block_size * block_size + + # TODO xformers-specific, attention bias are to be provided as a list. 
+ align_to = 8 + print("IN", hidden_states.shape) + # TODO chunked prefill + # what I want: (num_seqs, NH, L, L_pad) for prefill, (num_seqs, NH, 1, L_pad) for decodes + if self.attn_type == AttentionType.ENCODER: + # NOTE seq padding needed for xformers! + # HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])` + padded_seq_len = (seq_len + align_to - 1) // align_to * align_to + # FIXME blockdiagonal mask with tensor bias?? + pb = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) + print(f"[{self.prefix}][ENCODING, xformers aligned] Seq len compute bias/shape", seq_len, padded_seq_len) + # print("SEQ LEN", attn_metadata.encoder_seq_lens, pb.shape) xformers needs a list, one matrix per sequence + attn_metadata.encoder_attn_bias = [p[:, :sq, :sq].unsqueeze(0) for p, sq in zip(pb, attn_metadata.encoder_seq_lens)] + print("Bias shape", attn_metadata.encoder_attn_bias[0].shape) + elif attn_metadata.decode_metadata is None: # TODO join with statement above + # NOTE first decoder step, prefill: seq len here is usually 1, but one can prepend different start tokens prior to generation. XFormers is used. + print(f"[{self.prefix}][DECODER but xformers prefill]", attn_metadata.max_decode_seq_len, attn_metadata.max_prefill_seq_len) + padded_seq_len = (seq_len + align_to - 1) // align_to * align_to + # position_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1) + # this is always 1 (seqlen) but needs filling!! + # position_bias = self.compute_bias(seq_len, align_to).repeat(num_seqs, 1, 1, 1) + position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) + print(f"[{self.prefix}][DECODER but xformers prefill] Seq len compute bias/shape", seq_len, position_bias.shape) + # ->align + # position_bias = position_bias.repeat(1, 1, align_to, align_to) + # print("DECODER RUN BUT NO METADATA", position_bias.shape) + # TODO debug, can it be removed? + # position_bias[:, :, 1:, 1:] = torch.finfo(position_bias.dtype).min + # attn_metadata.attn_bias = [pb[None, :, :seq_len, :seq_len] for pb in position_bias] + attn_metadata.attn_bias = [pb[None, :, :seq_len, :seq_len] for pb, sq in zip(position_bias, attn_metadata.seq_lens)] + # attn_metadata.attn_bias = [position_bias[:, :, :1, :seq_len]] + print("[DECODER but xformers prefill] Bias shape", attn_metadata.attn_bias[0].shape) + else: + # Repeat along dim0: (num_seqs, n_heads, 1, L) + # TODO (NickLucche): allow single bias for whole batch to avoid extra-copy. + # TODO this needs to be block aligned!! + # position_bias = self.compute_bias(seq_len, block_aligned_seq_len).repeat(num_seqs, 1, 1, 1) + position_bias = self.compute_bias(1, block_aligned_seq_len).repeat(num_seqs, 1, 1, 1) + # position_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1) + print(f"[{self.prefix}][DECODING] Seq len compute bias/shape", seq_len, seq_len) + # position_bias[:, :, seq_len:, seq_len:] = torch.finfo(position_bias.dtype).min + attn_metadata.attn_bias = [position_bias] + # attn_metadata.attn_bias = [position_bias[:, :, :seq_len, :seq_len]] + print("Bias shape", attn_metadata.attn_bias[0].shape) + # TODO set attn bias + elif not self.has_relative_attention_bias and not is_profile_run: + # Encoder/Decoder Self-Attention Layer, attn bias already cached. 
+ assert attn_bias is not None + + # masking extra bias entries is done in the kernel + # mask = (seq_range >= prompt_lens).unsqueeze(1).unsqueeze(3) + # position_bias *= mask + # xformers masking + # for i in range(batch_size): + # input_metadata.attn_bias[ + # i, :, :, + # input_metadata.prompt_lens[i]:, ] = torch.finfo( + # input_metadata.attn_bias.dtype).min + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=self.attn_type) + # if not self.is_decoder: + # print("\n\nSAVING\n\n") + # torch.save(attn_output, "tensor_vllm.pth") + # FIXME all equal until the next line + output, _ = self.out_proj(attn_output) + return output + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): + super().__init__() + self.SelfAttention = T5Attention(config, AttentionType.DECODER if "decoder" in prefix else AttentionType.ENCODER, + has_relative_attention_bias=has_relative_attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.SelfAttention") + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + )->torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=None, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str=""): + super().__init__() + self.EncDecAttention = T5Attention(config, AttentionType.ENCODER_DECODER, has_relative_attention_bias=False, cache_config=cache_config, quant_config=quant_config,prefix=f"{prefix}.EncDecAttention") + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + )->torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + def __init__(self, config: T5Config, is_decoder: bool, has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + + prefix: str = ""): + super().__init__() + self.is_decoder = is_decoder + self.self_attn = T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, cache_config=cache_config, quant_config=quant_config,prefix=f"{prefix}.self_attn") + + if self.is_decoder: + self.cross_attn = T5LayerCrossAttention(config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.cross_attn") + + self.ffn = T5LayerFF(config, quant_config=quant_config) + # TODO remove + self.prefix = prefix + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + 
encoder_hidden_states: Optional[torch.Tensor]=None, + )->torch.Tensor: + + hidden_states = self.self_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # print(f"[{self.prefix}] HIDDEN inblock", hidden_states.shape, hidden_states.mean()) + if self.is_decoder: + hidden_states = self.cross_attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + # print(f"[{self.prefix}] HIDDEN inblock", hidden_states.shape, hidden_states.mean()) + + # Apply Feed Forward layer + hidden_states = self.ffn(hidden_states) + # print(f"[{self.prefix}] HIDDEN inblock", hidden_states.shape, hidden_states.mean()) + return hidden_states + + +class T5Stack(nn.Module): + def __init__(self, config: T5Config, is_decoder: bool, n_layers: int, embed_tokens=None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str=""): + super().__init__() + self.embed_tokens = embed_tokens + # Only the first block has relative positional encoding. + self.blocks = nn.ModuleList( + [T5Block(config, is_decoder=is_decoder, has_relative_attention_bias=i==0, + cache_config=cache_config,quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.prefix = prefix # TODO remove + + + def forward( + self, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor]=None + )-> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for idx, block in enumerate(self.blocks): + # print(f"[{self.prefix}] HIDDEN", hidden_states.shape, hidden_states.mean()) + hidden_states = block( + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + # print(f"[{self.prefix}] HIDDEN out", hidden_states.shape, hidden_states.mean()) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5Model(nn.Module): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): + super().__init__() + config: T5Config = vllm_config.model_config.hf_config + # TODO lora + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.padding_idx = config.pad_token_id # TODO decoding token + self.shared = VocabParallelEmbedding(config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size) + + self.encoder = T5Stack(config, False, config.num_layers, self.shared, cache_config=cache_config,quant_config=quant_config,prefix=f"{prefix}.encoder") + # assert config.num_layers == config.num_decoder_layers + self.decoder = T5Stack(config, True, config.num_decoder_layers, self.shared, cache_config=cache_config,quant_config=quant_config, prefix=f"{prefix}.decoder") + + def get_input_embeddings(self, input_ids: torch.Tensor)->torch.Tensor: + return self.shared(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + encoder_input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata + ) ->torch.Tensor: + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input: on a regular 
generate call, the encoder + # runs once, on the prompt. Subsequent decoder calls re-use output + # `encoder_hidden_states`. + print("Running on encoder input ids", encoder_input_ids.shape, "on this many sequences", len(attn_metadata.encoder_seq_lens)) + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + # Clear attention bias state. + attn_metadata.attn_bias = None + attn_metadata.encoder_attn_bias = None + attn_metadata.cross_attn_bias = None + print("ENC OUT HIDDEN", encoder_hidden_states.shape, encoder_hidden_states.mean()) + print("Running on decoder input ids (0 as input token)", input_ids) + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + print("DEC OUT HIDDEN", decoder_outputs.shape, decoder_outputs.mean()) + return decoder_outputs + + +class T5ForConditionalGeneration(nn.Module): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): + super().__init__() + config: T5Config = vllm_config.model_config.hf_config + self.model_dim = config.d_model + self.config = config + self.unpadded_vocab_size = config.vocab_size + # TODO + # if lora_config := vllm_config.lora_config: + # self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.model = T5Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # Although not in config, this is the default for hf models. 
+ if self.config.tie_word_embeddings: + self.lm_head = self.model.shared + # in transformers this is smt more explicit, as in (after load) + # self.lm_head.weight = self.model.shared.weight + else: + self.lm_head = ParallelLMHead(self.unpadded_vocab_size, config.d_model, org_num_embeddings=config.vocab_size) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = get_sampler() + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + hidden_states = hidden_states * (self.model_dim**-0.5) + print("hidden states input", hidden_states.shape) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def get_input_embeddings(self, input_ids: torch.Tensor)->torch.Tensor: + return self.model.shared(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, +) -> torch.Tensor: + return self.model(input_ids, encoder_input_ids, kv_caches, attn_metadata) + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]] + ): + model_params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + renamed_reg = [ + (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'), + (re.compile(r'decoder.block\.(\d+)\.layer\.1'), r'decoder.blocks.\1.cross_attn'), + (re.compile(r'decoder.block\.(\d+)\.layer\.2'), r'decoder.blocks.\1.ffn'), + # encoder has no cross-attn, but rather self-attention+ffn. + (re.compile(r'encoder.block\.(\d+)\.layer\.1'), r'encoder.blocks.\1.ffn'), + (re.compile(r'\.o\.'), r'.out_proj.'), + ] + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj.", ".q.", "q"), + (".qkv_proj.", ".k.", "k"), + (".qkv_proj.", ".v.", "v") + ] + + for name, loaded_weight in weights: + # No relative position attn bias on cross attention. + if name == "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight": + continue + + # Handle some renaming + for reg in renamed_reg: + name = re.sub(*reg, name) + + top_module, _ = name.split('.', 1) + if top_module != 'lm_head': + name = f"model.{name}" + + # Split q/k/v layers to unified QKVParallelLinear + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = model_params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Not a q/k/v layer. 
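
# --- Editor's sketch (illustration only, not part of the diff): how the
# renaming regexes above are expected to map HuggingFace T5 checkpoint names
# onto the module tree defined in this file (blocks.N.self_attn / cross_attn /
# ffn, plus the "model." prefix for everything except lm_head). The two
# checkpoint names below are hypothetical examples in the usual t5-small layout.
import re

renamed_reg = [
    (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'),
    (re.compile(r'decoder.block\.(\d+)\.layer\.1'), r'decoder.blocks.\1.cross_attn'),
    (re.compile(r'decoder.block\.(\d+)\.layer\.2'), r'decoder.blocks.\1.ffn'),
    (re.compile(r'encoder.block\.(\d+)\.layer\.1'), r'encoder.blocks.\1.ffn'),
    (re.compile(r'\.o\.'), r'.out_proj.'),
]

def rename(name: str) -> str:
    for pattern, repl in renamed_reg:
        name = pattern.sub(repl, name)
    # Everything except lm_head lives under the `model.` submodule.
    return name if name.split('.', 1)[0] == 'lm_head' else f"model.{name}"

# Self-attention query weight; the ".q." piece is then folded into the fused
# qkv_proj by stacked_params_mapping with shard_id "q".
assert rename("encoder.block.0.layer.0.SelfAttention.q.weight") == \
    "model.encoder.blocks.0.self_attn.SelfAttention.q.weight"
# Cross-attention output projection.
assert rename("decoder.block.3.layer.1.EncDecAttention.o.weight") == \
    "model.decoder.blocks.3.cross_attn.EncDecAttention.out_proj.weight"
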
+ param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params \ No newline at end of file From 0d7b0c50d97de41e667dc7604a8a639f5fb20906 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 7 Jan 2025 11:32:25 +0000 Subject: [PATCH 09/17] first working version :) --- vllm/model_executor/models/t5.py | 109 ++++++++++++++----------------- 1 file changed, 48 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 8867989e1f1cb..4dcb95aeb9ee3 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -39,6 +39,10 @@ from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata from .utils import maybe_prefix +# TODO best way to handle xformers imports? +from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias +# TODO func should be in backend interface +from vllm.attention.backends.xformers import _get_attn_bias, _set_attn_bias class T5LayerNorm(nn.Module): @@ -276,8 +280,6 @@ def forward( # NOTE (NickLucche) Attn bias is computed once per encoder or decoder # forward, on the first call to T5Attention.forward. Subsequent # *self-attention* layers will re-use it. - # TODO func should be in backend interface - from vllm.attention.backends.xformers import _get_attn_bias, _set_attn_bias attn_bias = _get_attn_bias(attn_metadata, self.attn_type) if self.attn_type == AttentionType.ENCODER_DECODER: # Projection of encoder's hidden states, cross-attention. @@ -288,86 +290,71 @@ def forward( v = None else: assert attn_metadata.num_prefills > 0 - # Prefill phase (first decode forward), caching kv + # Prefill phase (first decoder forward), caching kv qkv_enc, _ = self.qkv_proj(encoder_hidden_states) _, k, v = qkv_enc.split(self.inner_dim, dim=-1) # No custom attention bias must be set when running cross attn. assert attn_bias is None - # Skip when profiling. + # FIXME should be enabled on profiling run to assess memory of bias. - # TODO NOT compatible with CP here, assumes homogeneous batch + # TODO NOT compatible with CP here (as all encoder-decoder models), + # as it assumes homogeneous batch (prefills or decodes). elif self.has_relative_attention_bias and not is_profile_run: assert attn_bias is None # to be recomputed - # Self-attention. T5 relative positional encoding. - # Compute bias based on longest sequence in batch. Biases for - # shorter sequences are subsets of the longest. - if self.attn_type == AttentionType.ENCODER: - seq_len = attn_metadata.max_encoder_seq_len - else: - # Decoder can receive both prefill and decoding requests - seq_len = attn_metadata.max_prefill_seq_len if attn_metadata.prefill_metadata else attn_metadata.max_decode_seq_len - block_aligned_seq_len = (seq_len + block_size - 1) // block_size * block_size - - # TODO xformers-specific, attention bias are to be provided as a list. + # Self-attention. Compute T5 relative positional encoding. + # The bias term is computed on longest sequence in batch. Biases + # for shorter sequences are slices of the longest. + # TODO xformers-specific code. align_to = 8 - print("IN", hidden_states.shape) - # TODO chunked prefill # what I want: (num_seqs, NH, L, L_pad) for prefill, (num_seqs, NH, 1, L_pad) for decodes if self.attn_type == AttentionType.ENCODER: - # NOTE seq padding needed for xformers! 
- # HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])` + # Encoder prefill stage, uses xFormers, hence sequence + # padding/alignment to 8 is required. + seq_len = attn_metadata.max_encoder_seq_len padded_seq_len = (seq_len + align_to - 1) // align_to * align_to - # FIXME blockdiagonal mask with tensor bias?? - pb = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) + # TODO (NickLucche) avoid extra copy on repeat, provide multiple slices of same memory + position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) print(f"[{self.prefix}][ENCODING, xformers aligned] Seq len compute bias/shape", seq_len, padded_seq_len) - # print("SEQ LEN", attn_metadata.encoder_seq_lens, pb.shape) xformers needs a list, one matrix per sequence - attn_metadata.encoder_attn_bias = [p[:, :sq, :sq].unsqueeze(0) for p, sq in zip(pb, attn_metadata.encoder_seq_lens)] - print("Bias shape", attn_metadata.encoder_attn_bias[0].shape) - elif attn_metadata.decode_metadata is None: # TODO join with statement above - # NOTE first decoder step, prefill: seq len here is usually 1, but one can prepend different start tokens prior to generation. XFormers is used. + # xFormers expects a list of biases, one matrix per sequence. + # As each sequence gets its own bias, no masking is required. + attn_bias = [p[None, :, :sq, :sq] for p, sq in zip(position_bias, attn_metadata.encoder_seq_lens)] + elif attn_metadata.prefill_metadata: + # Decoder prefill stage, uses xFormers, hence sequence + # padding/alignment to 8 is required. First decoder step, + # seq_len is usually 1, but one can prepend different start + # tokens prior to generation. + seq_len = attn_metadata.max_prefill_seq_len print(f"[{self.prefix}][DECODER but xformers prefill]", attn_metadata.max_decode_seq_len, attn_metadata.max_prefill_seq_len) + # ->align padded_seq_len = (seq_len + align_to - 1) // align_to * align_to - # position_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1) - # this is always 1 (seqlen) but needs filling!! - # position_bias = self.compute_bias(seq_len, align_to).repeat(num_seqs, 1, 1, 1) position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) print(f"[{self.prefix}][DECODER but xformers prefill] Seq len compute bias/shape", seq_len, position_bias.shape) - # ->align - # position_bias = position_bias.repeat(1, 1, align_to, align_to) - # print("DECODER RUN BUT NO METADATA", position_bias.shape) - # TODO debug, can it be removed? - # position_bias[:, :, 1:, 1:] = torch.finfo(position_bias.dtype).min - # attn_metadata.attn_bias = [pb[None, :, :seq_len, :seq_len] for pb in position_bias] - attn_metadata.attn_bias = [pb[None, :, :seq_len, :seq_len] for pb, sq in zip(position_bias, attn_metadata.seq_lens)] - # attn_metadata.attn_bias = [position_bias[:, :, :1, :seq_len]] - print("[DECODER but xformers prefill] Bias shape", attn_metadata.attn_bias[0].shape) + print(f"[{self.prefix}][DECODER but xformers prefill] Seq lens vs Encoder seqlens", attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + # Causal mask for prefill. 
+ attn_bias = [LowerTriangularMaskWithTensorBias(pb[None, :, :sq, :sq]) for pb, sq in zip(position_bias, attn_metadata.seq_lens)] else: - # Repeat along dim0: (num_seqs, n_heads, 1, L) - # TODO (NickLucche): allow single bias for whole batch to avoid extra-copy. - # TODO this needs to be block aligned!! - # position_bias = self.compute_bias(seq_len, block_aligned_seq_len).repeat(num_seqs, 1, 1, 1) - position_bias = self.compute_bias(1, block_aligned_seq_len).repeat(num_seqs, 1, 1, 1) - # position_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1) - print(f"[{self.prefix}][DECODING] Seq len compute bias/shape", seq_len, seq_len) - # position_bias[:, :, seq_len:, seq_len:] = torch.finfo(position_bias.dtype).min - attn_metadata.attn_bias = [position_bias] - # attn_metadata.attn_bias = [position_bias[:, :, :seq_len, :seq_len]] - print("Bias shape", attn_metadata.attn_bias[0].shape) - # TODO set attn bias + # Decoder decoding stage, uses PagedAttention, hence sequence + # padding/alignment to `block_size` is required. Expected + # number of queries is always 1 (MQA not supported). + seq_len = attn_metadata.max_decode_seq_len + block_aligned_seq_len = (seq_len + block_size - 1) // block_size * block_size + + position_bias = self.compute_bias(seq_len, block_aligned_seq_len) + # Bias for the last query, the one at current decoding step. + position_bias = position_bias[:, :, -1:, :].repeat(num_seqs, 1, 1, 1) + print(f"[{self.prefix}]****[DECODING]***** Seq len compute bias/shape", seq_len, block_aligned_seq_len, position_bias.shape) + # No explicit masking required, this is done inside the + # paged attention kernel based on the sequence length. + attn_bias = [position_bias] + + # NOTE Assign bias term on metadata based on attn type: + # ENCODER->`encoder_attn_bias`, DECODER->`attn_bias` + _set_attn_bias(attn_metadata, attn_bias, self.attn_type) + print("Bias shape", attn_bias[0].shape) elif not self.has_relative_attention_bias and not is_profile_run: # Encoder/Decoder Self-Attention Layer, attn bias already cached. assert attn_bias is not None - # masking extra bias entries is done in the kernel - # mask = (seq_range >= prompt_lens).unsqueeze(1).unsqueeze(3) - # position_bias *= mask - # xformers masking - # for i in range(batch_size): - # input_metadata.attn_bias[ - # i, :, :, - # input_metadata.prompt_lens[i]:, ] = torch.finfo( - # input_metadata.attn_bias.dtype).min - attn_output = self.attn(q, k, v, From 43eca38665d88c6f3cef4d070e639e46231979d7 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 7 Jan 2025 16:36:43 +0000 Subject: [PATCH 10/17] clean up Signed-off-by: NickLucche --- vllm/model_executor/models/t5.py | 70 +++++++++----------------------- 1 file changed, 19 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index 4dcb95aeb9ee3..f831a85ca7713 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from .utils import maybe_prefix # TODO best way to handle xformers imports? 
from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias @@ -146,7 +147,6 @@ def __init__( prefix: str = "" ): super().__init__() - self.prefix = prefix self.attn_type = attn_type # Cross-attention has no relative pos encoding anyway self.is_decoder = attn_type == AttentionType.DECODER @@ -155,25 +155,18 @@ def __init__( self.relative_attention_max_distance = config.relative_attention_max_distance self.d_model = config.d_model self.key_value_proj_dim = config.d_kv - self.n_heads = config.num_heads - self.dropout = config.dropout_rate + + # Partition heads across multiple tensor parallel GPUs. + tp_world_size = get_tensor_model_parallel_world_size() + assert config.num_heads % tp_world_size == 0 + self.n_heads = config.num_heads // tp_world_size + self.inner_dim = self.n_heads * self.key_value_proj_dim # No GQA in t5. self.n_kv_heads = self.n_heads self.qkv_proj = QKVParallelLinear(self.d_model, self.d_model // self.n_heads, self.n_heads, self.n_kv_heads, bias=False, quant_config=quant_config) - # TODO refactor in utils shared with bart - # if self.total_num_kv_heads >= tp_world_size: - # # Number of KV heads is greater than TP size, so we partition - # # the KV heads across multiple tensor parallel GPUs. - # assert self.total_num_kv_heads % tp_world_size == 0 - # else: - # # Number of KV heads is less than TP size, so we replicate - # # the KV heads across multiple tensor parallel GPUs. - # assert tp_world_size % self.total_num_kv_heads == 0 - # self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) - # NOTE (NickLucche) T5 employs a scaled weight initialization scheme # instead of scaling attention scores directly. self.attn = Attention(self.n_heads, config.d_kv, 1.0, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn") @@ -181,10 +174,8 @@ def __init__( # Only the first SelfAttention block in encoder decoder has this # embedding layer, the others re-use its output. if self.has_relative_attention_bias: - print(f"EMBEDDING {self.relative_attention_num_buckets} TO {self.n_heads}, max dist {self.relative_attention_max_distance}") self.relative_attention_bias = VocabParallelEmbedding(self.relative_attention_num_buckets,\ self.n_heads, org_num_embeddings=self.relative_attention_num_buckets, quant_config=quant_config) - # self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.out_proj = RowParallelLinear( self.inner_dim, self.d_model, @@ -314,7 +305,6 @@ def forward( padded_seq_len = (seq_len + align_to - 1) // align_to * align_to # TODO (NickLucche) avoid extra copy on repeat, provide multiple slices of same memory position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) - print(f"[{self.prefix}][ENCODING, xformers aligned] Seq len compute bias/shape", seq_len, padded_seq_len) # xFormers expects a list of biases, one matrix per sequence. # As each sequence gets its own bias, no masking is required. attn_bias = [p[None, :, :sq, :sq] for p, sq in zip(position_bias, attn_metadata.encoder_seq_lens)] @@ -324,12 +314,9 @@ def forward( # seq_len is usually 1, but one can prepend different start # tokens prior to generation. 
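
# --- Editor's sketch (illustration only, not part of the diff): the align-to-8
# arithmetic used above. xFormers requires the last dimension of a tensor
# attn_bias to be backed by storage whose row length is a multiple of 8, so the
# bias is allocated at the padded length and then sliced back down; the slice
# exposes only seq_len columns while keeping the padded memory layout.
import torch

def pad_to(x: int, align_to: int = 8) -> int:
    return (x + align_to - 1) // align_to * align_to

num_heads, seq_len = 8, 13
padded_seq_len = pad_to(seq_len)                  # 16
full = torch.zeros(1, num_heads, seq_len, padded_seq_len)
bias = full[:, :, :, :seq_len]                    # logical (1, nh, 13, 13) view
assert bias.shape == (1, num_heads, seq_len, seq_len)
assert bias.stride(2) == padded_seq_len           # rows stay 8-aligned in memory
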
seq_len = attn_metadata.max_prefill_seq_len - print(f"[{self.prefix}][DECODER but xformers prefill]", attn_metadata.max_decode_seq_len, attn_metadata.max_prefill_seq_len) # ->align padded_seq_len = (seq_len + align_to - 1) // align_to * align_to position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) - print(f"[{self.prefix}][DECODER but xformers prefill] Seq len compute bias/shape", seq_len, position_bias.shape) - print(f"[{self.prefix}][DECODER but xformers prefill] Seq lens vs Encoder seqlens", attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) # Causal mask for prefill. attn_bias = [LowerTriangularMaskWithTensorBias(pb[None, :, :sq, :sq]) for pb, sq in zip(position_bias, attn_metadata.seq_lens)] else: @@ -339,18 +326,17 @@ def forward( seq_len = attn_metadata.max_decode_seq_len block_aligned_seq_len = (seq_len + block_size - 1) // block_size * block_size - position_bias = self.compute_bias(seq_len, block_aligned_seq_len) + # TODO bf16 bias support in PagedAttention. + position_bias = self.compute_bias(seq_len, block_aligned_seq_len).float() # Bias for the last query, the one at current decoding step. position_bias = position_bias[:, :, -1:, :].repeat(num_seqs, 1, 1, 1) - print(f"[{self.prefix}]****[DECODING]***** Seq len compute bias/shape", seq_len, block_aligned_seq_len, position_bias.shape) # No explicit masking required, this is done inside the # paged attention kernel based on the sequence length. attn_bias = [position_bias] # NOTE Assign bias term on metadata based on attn type: - # ENCODER->`encoder_attn_bias`, DECODER->`attn_bias` + # ENCODER->`encoder_attn_bias`, DECODER->`attn_bias`. _set_attn_bias(attn_metadata, attn_bias, self.attn_type) - print("Bias shape", attn_bias[0].shape) elif not self.has_relative_attention_bias and not is_profile_run: # Encoder/Decoder Self-Attention Layer, attn bias already cached. 
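
# --- Editor's sketch (illustration only, not part of the diff): the bias shape
# the decode branch above ends up with. PagedAttention consumes one bias row
# per sequence (a single query token at each decode step), with the key
# dimension rounded up to a multiple of the KV-cache block size; columns past a
# sequence's true length are ignored inside the kernel via the sequence length.
import torch

num_seqs, num_heads, block_size = 4, 8, 16
max_decode_seq_len = 37
block_aligned_seq_len = (max_decode_seq_len + block_size - 1) // block_size * block_size  # 48

full_bias = torch.zeros(1, num_heads, max_decode_seq_len, block_aligned_seq_len)
decode_bias = full_bias[:, :, -1:, :].float().repeat(num_seqs, 1, 1, 1)
assert decode_bias.shape == (num_seqs, num_heads, 1, block_aligned_seq_len)
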
assert attn_bias is not None @@ -361,10 +347,6 @@ def forward( kv_cache, attn_metadata, attn_type=self.attn_type) - # if not self.is_decoder: - # print("\n\nSAVING\n\n") - # torch.save(attn_output, "tensor_vllm.pth") - # FIXME all equal until the next line output, _ = self.out_proj(attn_output) return output @@ -437,8 +419,6 @@ def __init__(self, config: T5Config, is_decoder: bool, has_relative_attention_bi self.cross_attn = T5LayerCrossAttention(config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.cross_attn") self.ffn = T5LayerFF(config, quant_config=quant_config) - # TODO remove - self.prefix = prefix def forward( self, @@ -453,7 +433,6 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, ) - # print(f"[{self.prefix}] HIDDEN inblock", hidden_states.shape, hidden_states.mean()) if self.is_decoder: hidden_states = self.cross_attn( hidden_states=hidden_states, @@ -461,11 +440,9 @@ def forward( attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) - # print(f"[{self.prefix}] HIDDEN inblock", hidden_states.shape, hidden_states.mean()) # Apply Feed Forward layer hidden_states = self.ffn(hidden_states) - # print(f"[{self.prefix}] HIDDEN inblock", hidden_states.shape, hidden_states.mean()) return hidden_states @@ -483,7 +460,6 @@ def __init__(self, config: T5Config, is_decoder: bool, n_layers: int, embed_toke prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)] ) self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - self.prefix = prefix # TODO remove def forward( @@ -496,14 +472,12 @@ def forward( hidden_states = self.embed_tokens(input_ids) for idx, block in enumerate(self.blocks): - # print(f"[{self.prefix}] HIDDEN", hidden_states.shape, hidden_states.mean()) hidden_states = block( hidden_states=hidden_states, kv_cache=kv_caches[idx], attn_metadata=attn_metadata, encoder_hidden_states=encoder_hidden_states, ) - # print(f"[{self.prefix}] HIDDEN out", hidden_states.shape, hidden_states.mean()) hidden_states = self.final_layer_norm(hidden_states) return hidden_states @@ -514,16 +488,17 @@ class T5Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): super().__init__() config: T5Config = vllm_config.model_config.hf_config - # TODO lora cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - - self.padding_idx = config.pad_token_id # TODO decoding token + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.padding_idx = config.pad_token_id self.shared = VocabParallelEmbedding(config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size) self.encoder = T5Stack(config, False, config.num_layers, self.shared, cache_config=cache_config,quant_config=quant_config,prefix=f"{prefix}.encoder") - # assert config.num_layers == config.num_decoder_layers self.decoder = T5Stack(config, True, config.num_decoder_layers, self.shared, cache_config=cache_config,quant_config=quant_config, prefix=f"{prefix}.decoder") def get_input_embeddings(self, input_ids: torch.Tensor)->torch.Tensor: @@ -543,7 +518,6 @@ def forward( # are provided as input: on a regular generate call, the encoder # runs once, on the prompt. Subsequent decoder calls re-use output # `encoder_hidden_states`. 
- print("Running on encoder input ids", encoder_input_ids.shape, "on this many sequences", len(attn_metadata.encoder_seq_lens)) encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, kv_caches=kv_caches, attn_metadata=attn_metadata) @@ -551,16 +525,12 @@ def forward( attn_metadata.attn_bias = None attn_metadata.encoder_attn_bias = None attn_metadata.cross_attn_bias = None - print("ENC OUT HIDDEN", encoder_hidden_states.shape, encoder_hidden_states.mean()) - print("Running on decoder input ids (0 as input token)", input_ids) - # decoder outputs consists of - # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, kv_caches=kv_caches, attn_metadata=attn_metadata) - print("DEC OUT HIDDEN", decoder_outputs.shape, decoder_outputs.mean()) return decoder_outputs @@ -576,9 +546,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): self.model_dim = config.d_model self.config = config self.unpadded_vocab_size = config.vocab_size - # TODO - # if lora_config := vllm_config.lora_config: - # self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if lora_config := vllm_config.lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.model = T5Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) # Although not in config, this is the default for hf models. @@ -602,7 +571,6 @@ def compute_logits( # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 hidden_states = hidden_states * (self.model_dim**-0.5) - print("hidden states input", hidden_states.shape) logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits @@ -655,7 +623,7 @@ def load_weights( for name, loaded_weight in weights: # No relative position attn bias on cross attention. - if name == "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight": + if name in self._keys_to_ignore_on_load_unexpected: continue # Handle some renaming From b481f5d7fd728e82bb32b2644da0c696ae2b8d58 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 7 Jan 2025 16:38:41 +0000 Subject: [PATCH 11/17] address missing bos token case Signed-off-by: NickLucche --- vllm/inputs/preprocess.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 5e35ca02bf811..5f56566e7a3a5 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -89,7 +89,7 @@ def get_decoder_start_token_id(self) -> Optional[int]: return dec_start_token_id - def _get_default_enc_dec_decoder_prompt(self) -> List[int]: + def _maybe_get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' Specifically for encoder/decoder models: generate a default decoder prompt for when @@ -122,11 +122,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' bos_token_id = self.get_bos_token_id() - if bos_token_id is None: - # TODO do I have to make another config to set pad id as bos? 
T5 has no bos..pad is used in transformers too - bos_token_id = 0 - assert bos_token_id is not None - return [bos_token_id] + return [] if bos_token_id is None else [bos_token_id] def _prepare_decoder_input_ids_for_generation( self, @@ -158,7 +154,7 @@ def _prepare_decoder_input_ids_for_generation( if decoder_input_ids is None: # no decoder prompt input -> # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + decoder_input_ids = self._maybe_get_default_enc_dec_decoder_prompt() if (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): From 67bdbbc7855122431b6d17e6054cc1b503bce061 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Jan 2025 10:30:50 +0000 Subject: [PATCH 12/17] format and clean up Signed-off-by: NickLucche --- vllm/attention/backends/xformers.py | 36 +-- vllm/inputs/preprocess.py | 3 +- vllm/model_executor/models/t5.py | 459 ++++++++++++++++++---------- 3 files changed, 313 insertions(+), 185 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6b4839bd90771..79b69c9e351a3 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -157,8 +157,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. + # when alibi slopes or custom attention bias are used. + # It is because of a limitation from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None self.encoder_attn_bias: Optional[List[AttentionBias]] = None @@ -544,6 +544,7 @@ def forward( assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens + attn_bias = _get_attn_bias(attn_metadata, attn_type) if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -551,8 +552,10 @@ def forward( # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - # TODO this should be forwarded when splitting prefill/decode_meta - _set_attn_bias(prefill_meta, _get_attn_bias(attn_metadata, attn_type), attn_type) + # As prefill metadata are cached on first call, we need to make + # sure attn_bias is up to date. + if attn_bias: + _set_attn_bias(prefill_meta, attn_bias, attn_type) out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) assert out.shape == output[:num_prefill_query_tokens].shape @@ -599,10 +602,11 @@ def forward( block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias: + assert len( + attn_bias + ) == 1, "PagedAttention expects a single bias to be provided for all input sequences." attn_bias = attn_bias[0] - # print("Bias shape", attn_bias.shape) output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, @@ -666,7 +670,6 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. 
- # FIXME this is None should be rel pos encoding attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: @@ -710,7 +713,8 @@ def _run_memory_efficient_xformers_forward( attn_metadata.seq_lens) else: raise ValueError("Unknown AttentionType: %s", attn_type) - + + assert isinstance(attn_bias, BlockDiagonalMask) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -723,23 +727,15 @@ def _run_memory_efficient_xformers_forward( attn_metadata.seq_lens) _set_attn_bias(attn_metadata, attn_bias, attn_type) - - # if isinstance(attn_bias[0], torch.Tensor): - # print("IS THIS WORKING PREFILL shape", [b.shape for b in attn_bias]) - # print("IS THIS WORKING PREFILL stride", [b.stride() for b in attn_bias]) - # print("QUERY shape", query.shape, key.shape) - # No alibi slopes. + # No alibi slopes and no multi-sequence custom attention bias. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. - # TODO refactor custom attn bias must not go here - if self.alibi_slopes is None and len(attn_bias)==1: + if self.alibi_slopes is None and len(attn_bias) == 1: # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) - # if isinstance(attn_bias[0], torch.Tensor): - # print("RUNNING SINGLE ATTN BIAS VERSION WITH", attn_bias[0].shape) out = xops.memory_efficient_attention_forward( query, key, @@ -748,14 +744,14 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) return out.view_as(original_query) - - # Attention with alibi slopes. + # Attention with alibi slopes or multiple custom attention bias. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. 
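
# --- Editor's sketch (illustration only, not part of the diff): the
# one-prompt-at-a-time pattern referred to above. Query tokens of all sequences
# are packed along a single dimension, so each per-sequence attention call
# slices out [start:end) and pairs it with that sequence's own bias entry.
import torch

hidden = 32
seq_lens = [5, 2, 9]
tokens = torch.randn(sum(seq_lens), hidden)   # packed (num_tokens, hidden)
outputs = torch.empty_like(tokens)

start = 0
for i, seq_len in enumerate(seq_lens):
    end = start + seq_len
    chunk = tokens[start:end]                 # this sequence's query tokens
    # ... attention for sequence i with attn_bias[i] would run here ...
    outputs[start:end] = chunk
    start = end
assert start == tokens.shape[0]
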
output = torch.empty_like(original_query) seq_lens = attn_metadata.encoder_seq_lens if attn_type == AttentionType.ENCODER else attn_metadata.seq_lens + assert seq_lens start = 0 for i, seq_len in enumerate(seq_lens): end = start + seq_len diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 5f56566e7a3a5..7c84fabf2653d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -154,7 +154,8 @@ def _prepare_decoder_input_ids_for_generation( if decoder_input_ids is None: # no decoder prompt input -> # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._maybe_get_default_enc_dec_decoder_prompt() + decoder_input_ids = self._maybe_get_default_enc_dec_decoder_prompt( + ) if (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index f831a85ca7713..b0af947ada7eb 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -17,7 +17,7 @@ """PyTorch T5 model.""" import math -from typing import List, Optional, Tuple, Union, Set, Iterable +from typing import List, Optional, Tuple, Set, Iterable import re import torch @@ -26,7 +26,8 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import (ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.layers.activation import get_act_fn from vllm.attention.layer import Attention, AttentionType, AttentionMetadata from vllm.config import CacheConfig @@ -47,6 +48,7 @@ class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. @@ -55,14 +57,16 @@ def __init__(self, hidden_size, eps=1e-6): self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states)->torch.Tensor: + def forward(self, hidden_states) -> torch.Tensor: # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -72,13 +76,19 @@ def forward(self, hidden_states)->torch.Tensor: class T5DenseActDense(nn.Module): - def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + + def __init__(self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False) - self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config) self.act = get_act_fn(config.dense_act_fn) - def forward(self, hidden_states)->torch.Tensor: + def forward(self, hidden_states) -> torch.Tensor: hidden_states, _ = self.wi(hidden_states) hidden_states = self.act(hidden_states) # if ( @@ -92,14 +102,26 @@ def forward(self, hidden_states)->torch.Tensor: class T5DenseGatedActDense(nn.Module): - def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + + def __init__(self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.wi_0 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) - self.wi_1 = ColumnParallelLinear(config.d_model, config.d_ff, bias=False, quant_config=quant_config) - self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, quant_config=quant_config) + self.wi_0 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + self.wi_1 = ColumnParallelLinear(config.d_model, + config.d_ff, + bias=False, + quant_config=quant_config) + self.wo = RowParallelLinear(config.d_ff, + config.d_model, + bias=False, + quant_config=quant_config) self.act = get_act_fn(config.dense_act_fn) - def forward(self, hidden_states)->torch.Tensor: + def forward(self, hidden_states) -> torch.Tensor: hidden_gelu = self.act(self.wi_0(hidden_states)[0]) hidden_linear, _ = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear @@ -120,16 +142,22 @@ def forward(self, hidden_states)->torch.Tensor: class T5LayerFF(nn.Module): - def __init__(self, config: T5Config, quant_config: Optional[QuantizationConfig] = None): + + def __init__(self, + config: T5Config, + quant_config: Optional[QuantizationConfig] = None): super().__init__() if config.is_gated_act: - self.DenseReluDense = T5DenseGatedActDense(config, quant_config=quant_config) + self.DenseReluDense = T5DenseGatedActDense( + config, quant_config=quant_config) else: - self.DenseReluDense = T5DenseActDense(config, quant_config=quant_config) + self.DenseReluDense = T5DenseActDense(config, + quant_config=quant_config) - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) - def forward(self, hidden_states)->torch.Tensor: + def forward(self, hidden_states) -> torch.Tensor: forwarded_states = self.layer_norm(hidden_states) forwarded_states = 
self.DenseReluDense(forwarded_states) hidden_states = hidden_states + forwarded_states @@ -137,17 +165,16 @@ def forward(self, hidden_states)->torch.Tensor: class T5Attention(nn.Module): - def __init__( - self, - config: T5Config, - attn_type: AttentionType, - has_relative_attention_bias=False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "" - ): + + def __init__(self, + config: T5Config, + attn_type: AttentionType, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() - self.attn_type = attn_type + self.attn_type = attn_type # Cross-attention has no relative pos encoding anyway self.is_decoder = attn_type == AttentionType.DECODER self.has_relative_attention_bias = has_relative_attention_bias @@ -165,14 +192,24 @@ def __init__( # No GQA in t5. self.n_kv_heads = self.n_heads - self.qkv_proj = QKVParallelLinear(self.d_model, self.d_model // self.n_heads, self.n_heads, self.n_kv_heads, bias=False, quant_config=quant_config) + self.qkv_proj = QKVParallelLinear(self.d_model, + self.d_model // self.n_heads, + self.n_heads, + self.n_kv_heads, + bias=False, + quant_config=quant_config) - # NOTE (NickLucche) T5 employs a scaled weight initialization scheme + # NOTE (NickLucche) T5 employs a scaled weight initialization scheme # instead of scaling attention scores directly. - self.attn = Attention(self.n_heads, config.d_kv, 1.0, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn") - - # Only the first SelfAttention block in encoder decoder has this - # embedding layer, the others re-use its output. + self.attn = Attention(self.n_heads, + config.d_kv, + 1.0, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + # Only the first SelfAttention block in encoder decoder has this + # embedding layer, the others reuse its output. 
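
# --- Editor's note (illustration only, not part of the diff): why scale=1.0 is
# passed to Attention above. Standard attention computes softmax(Q.K^T / sqrt(d_kv));
# T5 folds that 1/sqrt(d_kv) factor into its weight initialization at training
# time, so the kernel must not apply it again. d_kv=64 below is just the
# t5-small head size, used for illustration.
import math

d_kv = 64
conventional_scale = 1.0 / math.sqrt(d_kv)   # what most models would pass
t5_scale = 1.0                               # what this T5 port passes
assert t5_scale != conventional_scale
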
if self.has_relative_attention_bias: self.relative_attention_bias = VocabParallelEmbedding(self.relative_attention_num_buckets,\ self.n_heads, org_num_embeddings=self.relative_attention_num_buckets, quant_config=quant_config) @@ -184,7 +221,10 @@ def __init__( ) @staticmethod - def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + def _relative_position_bucket(relative_position, + bidirectional=True, + num_buckets=32, + max_distance=128): """ Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 @@ -208,10 +248,12 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets = 0 if bidirectional: num_buckets //= 2 - relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_buckets += (relative_position > 0).to( + torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + relative_position = -torch.min(relative_position, + torch.zeros_like(relative_position)) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -220,24 +262,31 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) + torch.log(relative_position.float() / max_exact) / + math.log(max_distance / max_exact) * + (num_buckets - max_exact)).to(torch.long) relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) - ) + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1)) - relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + relative_buckets += torch.where(is_small, relative_position, + relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None)->torch.Tensor: + def compute_bias(self, + query_length, + key_length, + device=None) -> torch.Tensor: """Compute binned relative position bias""" # TODO possible tp issue? 
if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] - memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + context_position = torch.arange(query_length, + dtype=torch.long, + device=device)[:, None] + memory_position = torch.arange(key_length, + dtype=torch.long, + device=device)[None, :] # max_seq_len, nh relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -246,31 +295,36 @@ def compute_bias(self, query_length, key_length, device=None)->torch.Tensor: num_buckets=self.relative_attention_num_buckets, max_distance=self.relative_attention_max_distance, ) - values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) - x = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + x = values.permute([2, 0, 1]).unsqueeze( + 0) # shape (1, num_heads, query_length, key_length) return x def forward( self, - hidden_states: torch.Tensor, # (num_tokens, d_model) + hidden_states: torch.Tensor, # (num_tokens, d_model) kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, - )->torch.Tensor: - # TODO auto-selection of xformers backend when t5 is detected + ) -> torch.Tensor: + # TODO auto-selection of xformers backend when t5 is detected assert isinstance(attn_metadata, XFormersMetadata) is_profile_run = kv_cache.numel() == 0 if not is_profile_run: # TODO xformers only block_size = kv_cache.shape[2] // self.inner_dim - num_seqs = len(attn_metadata.seq_lens) if attn_metadata.seq_lens else len(attn_metadata.encoder_seq_lens) + num_seqs = len( + attn_metadata.seq_lens) if attn_metadata.seq_lens else len( + attn_metadata.encoder_seq_lens) qkv, _ = self.qkv_proj(hidden_states) - # Projection of 'own' hidden state (self-attention). No GQA here. + # Projection of 'own' hidden state (self-attention). No GQA here. q, k, v = qkv.split(self.inner_dim, dim=-1) # NOTE (NickLucche) Attn bias is computed once per encoder or decoder - # forward, on the first call to T5Attention.forward. Subsequent - # *self-attention* layers will re-use it. + # forward, on the first call to T5Attention.forward. Subsequent + # *self-attention* layers will reuse it. attn_bias = _get_attn_bias(attn_metadata, self.attn_type) if self.attn_type == AttentionType.ENCODER_DECODER: # Projection of encoder's hidden states, cross-attention. @@ -284,57 +338,73 @@ def forward( # Prefill phase (first decoder forward), caching kv qkv_enc, _ = self.qkv_proj(encoder_hidden_states) _, k, v = qkv_enc.split(self.inner_dim, dim=-1) - # No custom attention bias must be set when running cross attn. + # No custom attention bias must be set when running cross attn. assert attn_bias is None # FIXME should be enabled on profiling run to assess memory of bias. - # TODO NOT compatible with CP here (as all encoder-decoder models), + # TODO NOT compatible with CP here (as all encoder-decoder models), # as it assumes homogeneous batch (prefills or decodes). elif self.has_relative_attention_bias and not is_profile_run: - assert attn_bias is None # to be recomputed - # Self-attention. Compute T5 relative positional encoding. 
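
# --- Editor's sketch (illustration only, not part of the diff): the relative
# position matrix compute_bias() builds before bucketing. For query_length=4 and
# key_length=6 the entries are key_pos - query_pos; bucketing then maps them to
# a bounded set of indices into the (num_buckets, n_heads) embedding, and the
# final permute yields a (1, n_heads, query_length, key_length) bias.
import torch

query_length, key_length = 4, 6
context_position = torch.arange(query_length)[:, None]
memory_position = torch.arange(key_length)[None, :]
relative_position = memory_position - context_position
# tensor([[ 0,  1,  2,  3,  4,  5],
#         [-1,  0,  1,  2,  3,  4],
#         [-2, -1,  0,  1,  2,  3],
#         [-3, -2, -1,  0,  1,  2]])
assert relative_position.shape == (query_length, key_length)
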
- # The bias term is computed on longest sequence in batch. Biases + assert attn_bias is None # to be recomputed + # Self-attention. Compute T5 relative positional encoding. + # The bias term is computed on longest sequence in batch. Biases # for shorter sequences are slices of the longest. # TODO xformers-specific code. align_to = 8 - # what I want: (num_seqs, NH, L, L_pad) for prefill, (num_seqs, NH, 1, L_pad) for decodes + # bias expected shape: (num_seqs, NH, L, L_pad) for prefill, + # (num_seqs, NH, 1, L_pad) for decodes. if self.attn_type == AttentionType.ENCODER: - # Encoder prefill stage, uses xFormers, hence sequence + # Encoder prefill stage, uses xFormers, hence sequence # padding/alignment to 8 is required. seq_len = attn_metadata.max_encoder_seq_len - padded_seq_len = (seq_len + align_to - 1) // align_to * align_to + padded_seq_len = (seq_len + align_to - + 1) // align_to * align_to # TODO (NickLucche) avoid extra copy on repeat, provide multiple slices of same memory - position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) + position_bias = self.compute_bias(seq_len, + padded_seq_len).repeat( + num_seqs, 1, 1, 1) # xFormers expects a list of biases, one matrix per sequence. # As each sequence gets its own bias, no masking is required. - attn_bias = [p[None, :, :sq, :sq] for p, sq in zip(position_bias, attn_metadata.encoder_seq_lens)] + attn_bias = [ + p[None, :, :sq, :sq] for p, sq in zip( + position_bias, attn_metadata.encoder_seq_lens) + ] elif attn_metadata.prefill_metadata: # Decoder prefill stage, uses xFormers, hence sequence - # padding/alignment to 8 is required. First decoder step, - # seq_len is usually 1, but one can prepend different start - # tokens prior to generation. + # padding/alignment to 8 is required. First decoder step, + # seq_len is usually 1, but one can prepend different start + # tokens prior to generation. seq_len = attn_metadata.max_prefill_seq_len # ->align - padded_seq_len = (seq_len + align_to - 1) // align_to * align_to - position_bias = self.compute_bias(seq_len, padded_seq_len).repeat(num_seqs, 1, 1, 1) + padded_seq_len = (seq_len + align_to - + 1) // align_to * align_to + position_bias = self.compute_bias(seq_len, + padded_seq_len).repeat( + num_seqs, 1, 1, 1) # Causal mask for prefill. - attn_bias = [LowerTriangularMaskWithTensorBias(pb[None, :, :sq, :sq]) for pb, sq in zip(position_bias, attn_metadata.seq_lens)] + attn_bias = [ + LowerTriangularMaskWithTensorBias(pb[None, :, :sq, :sq]) + for pb, sq in zip(position_bias, attn_metadata.seq_lens) + ] else: # Decoder decoding stage, uses PagedAttention, hence sequence - # padding/alignment to `block_size` is required. Expected - # number of queries is always 1 (MQA not supported). + # padding/alignment to `block_size` is required. Expected + # number of queries is always 1 (MQA not supported). seq_len = attn_metadata.max_decode_seq_len - block_aligned_seq_len = (seq_len + block_size - 1) // block_size * block_size - + block_aligned_seq_len = (seq_len + block_size - + 1) // block_size * block_size + # TODO bf16 bias support in PagedAttention. - position_bias = self.compute_bias(seq_len, block_aligned_seq_len).float() + position_bias = self.compute_bias( + seq_len, block_aligned_seq_len).float() # Bias for the last query, the one at current decoding step. - position_bias = position_bias[:, :, -1:, :].repeat(num_seqs, 1, 1, 1) - # No explicit masking required, this is done inside the - # paged attention kernel based on the sequence length. 
+ position_bias = position_bias[:, :, -1:, :].repeat( + num_seqs, 1, 1, 1) + # No explicit masking required, this is done inside the + # paged attention kernel based on the sequence length. attn_bias = [position_bias] - - # NOTE Assign bias term on metadata based on attn type: + + # NOTE Assign bias term on metadata based on attn type: # ENCODER->`encoder_attn_bias`, DECODER->`attn_bias`. _set_attn_bias(attn_metadata, attn_bias, self.attn_type) elif not self.has_relative_attention_bias and not is_profile_run: @@ -342,32 +412,43 @@ def forward( assert attn_bias is not None attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=self.attn_type) + k, + v, + kv_cache, + attn_metadata, + attn_type=self.attn_type) output, _ = self.out_proj(attn_output) return output + class T5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False, cache_config: Optional[CacheConfig] = None, + + def __init__( + self, + config, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ): + prefix: str = "", + ): super().__init__() - self.SelfAttention = T5Attention(config, AttentionType.DECODER if "decoder" in prefix else AttentionType.ENCODER, - has_relative_attention_bias=has_relative_attention_bias, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.SelfAttention") - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.SelfAttention = T5Attention( + config, + AttentionType.DECODER + if "decoder" in prefix else AttentionType.ENCODER, + has_relative_attention_bias=has_relative_attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.SelfAttention") + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward( self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - )->torch.Tensor: + ) -> torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( hidden_states=normed_hidden_states, @@ -380,12 +461,21 @@ def forward( class T5LayerCrossAttention(nn.Module): - def __init__(self, config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str=""): + + def __init__(self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() - self.EncDecAttention = T5Attention(config, AttentionType.ENCODER_DECODER, has_relative_attention_bias=False, cache_config=cache_config, quant_config=quant_config,prefix=f"{prefix}.EncDecAttention") - self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.EncDecAttention = T5Attention(config, + AttentionType.ENCODER_DECODER, + has_relative_attention_bias=False, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.EncDecAttention") + self.layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward( self, @@ -393,7 +483,7 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, encoder_hidden_states: Optional[torch.Tensor] = None, - )->torch.Tensor: + ) -> torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( hidden_states=normed_hidden_states, @@ -406,17 +496,29 @@ def forward( class T5Block(nn.Module): - def __init__(self, config: 
T5Config, is_decoder: bool, has_relative_attention_bias=False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, + def __init__(self, + config: T5Config, + is_decoder: bool, + has_relative_attention_bias=False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() self.is_decoder = is_decoder - self.self_attn = T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, cache_config=cache_config, quant_config=quant_config,prefix=f"{prefix}.self_attn") - + self.self_attn = T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + if self.is_decoder: - self.cross_attn = T5LayerCrossAttention(config, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.cross_attn") + self.cross_attn = T5LayerCrossAttention( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.cross_attn") self.ffn = T5LayerFF(config, quant_config=quant_config) @@ -425,8 +527,8 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - encoder_hidden_states: Optional[torch.Tensor]=None, - )->torch.Tensor: + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: hidden_states = self.self_attn( hidden_states=hidden_states, @@ -447,28 +549,36 @@ def forward( class T5Stack(nn.Module): - def __init__(self, config: T5Config, is_decoder: bool, n_layers: int, embed_tokens=None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str=""): + + def __init__(self, + config: T5Config, + is_decoder: bool, + n_layers: int, + embed_tokens=None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.embed_tokens = embed_tokens # Only the first block has relative positional encoding. 
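
# --- Editor's sketch (illustration only, not part of the diff): the
# bias-caching pattern this stack relies on. Only blocks.0 owns a
# relative_attention_bias embedding; it writes the computed bias into the
# per-forward attention metadata, and every later self-attention layer reads
# the same cached value. The names below are illustrative stand-ins for the
# real metadata object and _get/_set_attn_bias helpers.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _FakeMetadata:
    attn_bias: Optional[List[str]] = None

def run_layer(layer_idx: int, meta: _FakeMetadata) -> List[str]:
    if layer_idx == 0:
        # First block computes the bias and caches it on the metadata.
        assert meta.attn_bias is None
        meta.attn_bias = [f"bias computed by layer {layer_idx}"]
    # Layers > 0 simply reuse whatever the first layer stored.
    return meta.attn_bias

meta = _FakeMetadata()
assert run_layer(0, meta) == run_layer(5, meta)
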
- self.blocks = nn.ModuleList( - [T5Block(config, is_decoder=is_decoder, has_relative_attention_bias=i==0, - cache_config=cache_config,quant_config=quant_config, - prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)] - ) - self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) - + self.blocks = nn.ModuleList([ + T5Block(config, + is_decoder=is_decoder, + has_relative_attention_bias=i == 0, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}") for i in range(n_layers) + ]) + self.final_layer_norm = T5LayerNorm(config.d_model, + eps=config.layer_norm_epsilon) def forward( - self, - input_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - encoder_hidden_states: Optional[torch.Tensor]=None - )-> torch.Tensor: + self, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None + ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) for idx, block in enumerate(self.blocks): @@ -483,40 +593,53 @@ def forward( class T5Model(nn.Module): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] - def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: T5Config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - + lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab self.padding_idx = config.pad_token_id - self.shared = VocabParallelEmbedding(config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size) - - self.encoder = T5Stack(config, False, config.num_layers, self.shared, cache_config=cache_config,quant_config=quant_config,prefix=f"{prefix}.encoder") - self.decoder = T5Stack(config, True, config.num_decoder_layers, self.shared, cache_config=cache_config,quant_config=quant_config, prefix=f"{prefix}.decoder") - - def get_input_embeddings(self, input_ids: torch.Tensor)->torch.Tensor: + self.shared = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size) + + self.encoder = T5Stack(config, + False, + config.num_layers, + self.shared, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = T5Stack(config, + True, + config.num_decoder_layers, + self.shared, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.shared(input_ids) - def forward( - self, - input_ids: torch.Tensor, - encoder_input_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata - ) ->torch.Tensor: + def forward(self, input_ids: torch.Tensor, encoder_input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: encoder_hidden_states = None if encoder_input_ids.numel() > 0: # Run encoder attention if a non-zero number of encoder tokens # are provided as input: on a regular generate call, the encoder - # runs once, on the prompt. 
Subsequent decoder calls re-use output + # runs once, on the prompt. Subsequent decoder calls reuse output # `encoder_hidden_states`. encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, kv_caches=kv_caches, @@ -538,9 +661,12 @@ class T5ForConditionalGeneration(nn.Module): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", + "lm_head.weight" + ] - def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: T5Config = vllm_config.model_config.hf_config self.model_dim = config.d_model @@ -549,14 +675,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix:str=""): if lora_config := vllm_config.lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.model = T5Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.model = T5Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) # Although not in config, this is the default for hf models. if self.config.tie_word_embeddings: self.lm_head = self.model.shared # in transformers this is smt more explicit, as in (after load) # self.lm_head.weight = self.model.shared.weight else: - self.lm_head = ParallelLMHead(self.unpadded_vocab_size, config.d_model, org_num_embeddings=config.vocab_size) + self.lm_head = ParallelLMHead(self.unpadded_vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -582,8 +711,8 @@ def sample( ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - - def get_input_embeddings(self, input_ids: torch.Tensor)->torch.Tensor: + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.shared(input_ids) def forward( @@ -597,21 +726,22 @@ def forward( encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, **kwargs, -) -> torch.Tensor: - return self.model(input_ids, encoder_input_ids, kv_caches, attn_metadata) + ) -> torch.Tensor: + return self.model(input_ids, encoder_input_ids, kv_caches, + attn_metadata) - def load_weights( - self, - weights: Iterable[Tuple[str, torch.Tensor]] - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): model_params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: Set[str] = set() renamed_reg = [ (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'), - (re.compile(r'decoder.block\.(\d+)\.layer\.1'), r'decoder.blocks.\1.cross_attn'), - (re.compile(r'decoder.block\.(\d+)\.layer\.2'), r'decoder.blocks.\1.ffn'), + (re.compile(r'decoder.block\.(\d+)\.layer\.1'), + r'decoder.blocks.\1.cross_attn'), + (re.compile(r'decoder.block\.(\d+)\.layer\.2'), + r'decoder.blocks.\1.ffn'), # encoder has no cross-attn, but rather self-attention+ffn. - (re.compile(r'encoder.block\.(\d+)\.layer\.1'), r'encoder.blocks.\1.ffn'), + (re.compile(r'encoder.block\.(\d+)\.layer\.1'), + r'encoder.blocks.\1.ffn'), (re.compile(r'\.o\.'), r'.out_proj.'), ] stacked_params_mapping = [ @@ -622,7 +752,7 @@ def load_weights( ] for name, loaded_weight in weights: - # No relative position attn bias on cross attention. + # No relative position attn bias on cross attention. 
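To make the `renamed_reg` substitutions above concrete, here is a small standalone sketch (not part of the patch) of how they map HF T5 checkpoint names onto this module layout; the sample weight names are illustrative HF-style names, not entries read from a real checkpoint:

import re

renamed_reg = [
    (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'),
    (re.compile(r'decoder.block\.(\d+)\.layer\.1'),
     r'decoder.blocks.\1.cross_attn'),
    (re.compile(r'decoder.block\.(\d+)\.layer\.2'), r'decoder.blocks.\1.ffn'),
    (re.compile(r'encoder.block\.(\d+)\.layer\.1'), r'encoder.blocks.\1.ffn'),
    (re.compile(r'\.o\.'), r'.out_proj.'),
]

def rename(name: str) -> str:
    # Apply each substitution in order, as the loader does with its list.
    for reg, new in renamed_reg:
        name = reg.sub(new, name)
    return name

# "decoder.block.3.layer.1.EncDecAttention.o.weight"
#   -> "decoder.blocks.3.cross_attn.EncDecAttention.out_proj.weight"
print(rename("decoder.block.3.layer.1.EncDecAttention.o.weight"))
# "encoder.block.0.layer.0.SelfAttention.q.weight"
#   -> "encoder.blocks.0.self_attn.SelfAttention.q.weight"
print(rename("encoder.block.0.layer.0.SelfAttention.q.weight"))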
if name in self._keys_to_ignore_on_load_unexpected: continue @@ -646,7 +776,8 @@ def load_weights( else: # Not a q/k/v layer. param = model_params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) - return loaded_params \ No newline at end of file + return loaded_params From bd264c7516b85218feb7f467b6d156c4b2beeaae Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Jan 2025 10:46:29 +0000 Subject: [PATCH 13/17] t5 tests Signed-off-by: NickLucche --- .../encoder_decoder/language/conftest.py | 50 ++-- .../encoder_decoder/language/test_bart.py | 9 +- .../encoder_decoder/language/test_t5.py | 252 +++--------------- .../models/encoder_decoder/language/utils.py | 3 +- 4 files changed, 67 insertions(+), 247 deletions(-) diff --git a/tests/models/encoder_decoder/language/conftest.py b/tests/models/encoder_decoder/language/conftest.py index a751f73dfa9a6..5f2b14741df03 100644 --- a/tests/models/encoder_decoder/language/conftest.py +++ b/tests/models/encoder_decoder/language/conftest.py @@ -1,29 +1,28 @@ from transformers import AutoModelForSeq2SeqLM -from ....conftest import (DecoderPromptType, HfRunner, VllmRunner, ExplicitEncoderDecoderPrompt) +from ....conftest import (DecoderPromptType, HfRunner, VllmRunner, + ExplicitEncoderDecoderPrompt) from typing import List, Optional, Tuple, Type, Dict, Any from ...utils import check_logprobs_close from .utils import vllm_to_hf_output - -# TODO docs def compare_hf_vllm_logprobs( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, - vllm_runner_kwargs: Optional[Dict[str, Any]] = dict(), -) -> None: + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, + vllm_runner_kwargs: Optional[Dict[str, Any]] = dict(), + hf_tokens_to_skip: int = 0) -> None: ''' - Test the vLLM BART model for a variety of encoder/decoder input prompts, - by validating it against HuggingFace (HF) BART. + Test the provided model for a variety of encoder/decoder input prompts, + by validating it against corresponding HuggingFace (HF). Arguments: @@ -83,9 +82,8 @@ def compare_hf_vllm_logprobs( discard the first decoded token from the HF output before comparing it to vLLM. - To that end, when testing the scenario where the decoder prompt is None - (and only in that one scenario), this test skips the first HF decoded - token during the process of validating the vLLM decoded output. + To that end, `hf_tokens_to_skip` must be set to the number of HF decoded + tokens to skip during the process of validating the vLLM decoded output. ''' # NOTE: take care of the order. run vLLM first, and then run HF. 
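As a concrete illustration of the `hf_tokens_to_skip` contract described in the docstring above: in the decoder-prompt-is-None case the two decoded-token lists differ only by the leading decoder-start/BOS token that HF emits as output while vLLM folds it into the prompt. A tiny sketch with toy token ids (not taken from the test suite):

# Toy token ids; 0 stands for the forced decoder-start/BOS token.
hf_decoded = [0, 71, 12, 5]      # HF: the forced BOS shows up in the *output*
vllm_decoded = [71, 12, 5]       # vLLM: the BOS lives in the prompt instead

hf_tokens_to_skip = 1            # decoder-prompt-is-None case
assert hf_decoded[hf_tokens_to_skip:] == vllm_decoded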
@@ -107,7 +105,8 @@ def compare_hf_vllm_logprobs( dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, **vllm_runner_kwargs) as vllm_model: + enforce_eager=True, + **vllm_runner_kwargs) as vllm_model: vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( prompts, max_tokens, num_logprobs) @@ -132,9 +131,6 @@ def compare_hf_vllm_logprobs( **hf_kwargs, )) - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) - check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=[ @@ -143,5 +139,5 @@ def compare_hf_vllm_logprobs( ], name_0="hf", name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) \ No newline at end of file + num_outputs_0_skip_tokens=hf_tokens_to_skip, + ) diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index b43edb6fcb5f8..6e8aebf00f4f7 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -2,14 +2,11 @@ Run `pytest tests/models/encoder_decoder/language/test_bart.py`. """ -from typing import List, Optional, Tuple, Type - import pytest -from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) +from ....conftest import DecoderPromptType from ....utils import multi_gpu_test -from .utils import compare_hf_vllm_logprobs +from .conftest import compare_hf_vllm_logprobs @pytest.mark.parametrize( @@ -37,7 +34,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, max_tokens=max_tokens, num_logprobs=num_logprobs, tensor_parallel_size=1, - ) + hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE)) @multi_gpu_test(num_gpus=2) diff --git a/tests/models/encoder_decoder/language/test_t5.py b/tests/models/encoder_decoder/language/test_t5.py index 1ff43defbc726..873e4c314d5b5 100644 --- a/tests/models/encoder_decoder/language/test_t5.py +++ b/tests/models/encoder_decoder/language/test_t5.py @@ -3,72 +3,47 @@ Run `pytest tests/models/encoder_decoder/language/test_t5.py`. 
""" -from typing import List, Optional, Tuple, Type - import pytest -from transformers import AutoModelForSeq2SeqLM - -from tests.kernels.utils import make_test_metadata -from vllm.attention.layer import Attention from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.config import set_current_vllm_config -from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, - HfRunner, VllmRunner) +from ....conftest import DecoderPromptType from ....utils import multi_gpu_test from .conftest import compare_hf_vllm_logprobs import torch -from vllm.model_executor.models.t5 import T5Attention, T5Config, AttentionType -from vllm.platforms import current_platform +from vllm.attention.selector import _Backend + @pytest.mark.parametrize( "model", [ - # pytest.param("google/t5-small", - # marks=[pytest.mark.core_model, pytest.mark.cpu_model]), pytest.param("google-t5/t5-small"), + pytest.param("google/flan-t5-base"), ], ) -@pytest.mark.parametrize( - "vllm_kwargs", - [{ - "max_model_len": 512 - }] - ) -@pytest.mark.parametrize("dtype", ["float"])#, "bfloat16"]) +@pytest.mark.parametrize("vllm_kwargs", [{"max_model_len": 512}]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +# TODO custom prompt here generate high entropy output, causing +# differences in sampled tokens. +@pytest.mark.parametrize("decoder_prompt_type", + [DecoderPromptType.NONE, DecoderPromptType.EMPTY_STR]) def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type, vllm_kwargs) -> None: - # TODO force backend - compare_hf_vllm_logprobs( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - vllm_runner_kwargs=vllm_kwargs - ) - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out + dtype, max_tokens, num_logprobs, decoder_prompt_type, + vllm_kwargs) -> None: + # Model only supported on xformers backend as of now. 
+ with global_force_attn_backend_context_manager(_Backend.XFORMERS): + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + vllm_runner_kwargs=vllm_kwargs) @pytest.fixture @@ -87,156 +62,6 @@ def dist_init(): yield cleanup_dist_env_and_memory() -# TODO more cases -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -def test_t5_bias_attention(dtype, dist_init) -> None: - import random - seed = 0 - MAX_SEQ_LEN = 32 - block_size = 16 - NUM_BLOCKS = 4321 - current_platform.seed_everything(seed) - config = T5Config() - - # setup kv caches - head_size = config.d_kv - num_heads = (config.num_heads, config.num_heads) - num_seqs = 1 - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - # query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - # query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - - seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - seq_lens[-1] = MAX_SEQ_LEN - max_seq_len = max(seq_lens) - # seq_lens = torch.tensor(seq_lens, dtype=torch.int) - - # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables_lst: List[List[int]] = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables_lst.append(block_table) - - block_tables = torch.tensor(block_tables_lst, dtype=torch.int) - - # Create the KV caches. - kv_cache_dtype = 'auto' - from vllm.utils import create_kv_caches_with_random - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - 'cuda') - key_cache, value_cache = key_caches[0], value_caches[0] - - # Using default kv_scale - k_scale = v_scale = 1.0 - - - from vllm.attention.selector import _Backend - x = torch.randn(num_seqs, max_seq_len, config.d_model, device='cuda', dtype=torch.float) - with global_force_attn_backend_context_manager(_Backend.XFORMERS): - - from vllm.attention.backends.xformers import XFormersMetadata - from vllm.attention.backends.xformers import XFormersBackend - from vllm import LLM - - from vllm.forward_context import set_forward_context - from vllm.config import VllmConfig - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - encoder_seq_start_loc = torch.zeros(len(seq_lens) + 1, - dtype=torch.int32, - device='cuda') - meta = XFormersBackend.make_metadata( - seq_lens=None,#seq_lens, - max_decode_seq_len=0, num_prefills=None, - num_prefill_tokens=None, num_decode_tokens=0, - seq_lens_tensor=None,#torch.tensor(seq_lens), - slot_mapping=None,#torch.zeros(1), - multi_modal_placeholder_index_maps=None, - max_prefill_seq_len=None,#MAX_SEQ_LEN, - use_cuda_graph=False, - context_lens_tensor=None, - # no block tables on encoder forward - block_tables=torch.tensor([]).cuda(), - # block_tables=block_tables, - num_encoder_tokens=sum(seq_lens), encoder_seq_lens=seq_lens,encoder_seq_lens_tensor=torch.tensor(seq_lens).cuda(), - max_encoder_seq_len=max(seq_lens), encoder_seq_start_loc=encoder_seq_start_loc) - # # NOTE use compute_bias here - # attn_bias = t5_attn.compute_bias(MAX_SEQ_LEN, MAX_SEQ_LEN) - - # same weights should be loaded - # TODO load model without engine overhead - 
llm = LLM(model="google-t5/t5-small", load_format='safetensors', enforce_eager=True, dtype='float') - model = llm.llm_engine.model_executor.driver_worker.model_runner.model - t5_attn = model.model.encoder.blocks[0].self_attn.SelfAttention - print("\nTYPE", type(t5_attn)) - # TODO decoder - # FIXME this is kinda close, maybe issue is not with xformers custom bias attn - # t5_attn = T5Attention(config, AttentionType.ENCODER, has_relative_attention_bias=True).cuda() - # t5_attn.has_relative_attention_bias = False - assert t5_attn.has_relative_attention_bias - from transformers import T5Tokenizer, T5ForConditionalGeneration - from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention - hfmodel = T5ForConditionalGeneration.from_pretrained('google-t5/t5-small', return_dict=True) - print("My T5", t5_attn) - # this must be set to call attn.impl.forward - # vllm_config.compilation_config.static_forward_context[".attn"] = t5_attn.attn - vllm_config.compilation_config.static_forward_context["model.encoder.blocks.0.self_attn.SelfAttention.attn"] = t5_attn.attn - hf_attn = hfmodel.encoder.block[0].layer[0].SelfAttention.cuda() - # hf_attn.has_relative_attention_bias = False - assert hf_attn.has_relative_attention_bias - # hf_attn = HFT5Attention(config, has_relative_attention_bias=True).cuda() - - - with set_forward_context(meta, vllm_config): - # kv_cache for xformers [2, num_blocks, block_size * num_kv_heads * head_size] - kvc = torch.stack([key_cache.reshape(NUM_BLOCKS, -1), value_cache.reshape(NUM_BLOCKS, -1)], 0) - output = t5_attn(x, kvc, meta) - ref_output, *_ = hf_attn(x) - - atol, rtol = 1e-3, 1e-5 - # torch.testing.assert_close(output, ref_output.squeeze(), atol=atol, rtol=rtol) - - # **cross attn** - t5_attn = model.model.decoder.blocks[0].cross_attn.EncDecAttention - print("\nTYPE", type(t5_attn)) - assert not t5_attn.has_relative_attention_bias - vllm_config.compilation_config.static_forward_context["model.decoder.blocks.0.cross_attn.EncDecAttention.attn"] = t5_attn.attn - hf_attn = hfmodel.decoder.block[0].layer[1].EncDecAttention.cuda() - assert not hf_attn.has_relative_attention_bias - - meta = XFormersBackend.make_metadata( - seq_lens=seq_lens, - max_decode_seq_len=MAX_SEQ_LEN, num_prefills=0, - num_prefill_tokens=0, num_decode_tokens=1, - max_prefill_seq_len=None, - seq_lens_tensor=torch.tensor(seq_lens), - slot_mapping=None,#torch.zeros(1), - multi_modal_placeholder_index_maps=None, - use_cuda_graph=False, - context_lens_tensor=None, - block_tables=torch.tensor([]).cuda(), - # block_tables=block_tables - ) - - - with set_forward_context(meta, vllm_config): - output = t5_attn(x, kvc, meta) - ref_output, *_ = hf_attn(x) - - torch.testing.assert_close(output, ref_output.squeeze(), atol=atol, rtol=rtol) - @multi_gpu_test(num_gpus=2) @@ -245,21 +70,22 @@ def test_t5_bias_attention(dtype, dist_init) -> None: @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.NONE]) def test_models_distributed(hf_runner, vllm_runner, example_encoder_decoder_prompts, distributed_executor_backend, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - compare_hf_vllm_logprobs( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - 
num_logprobs=num_logprobs, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - ) + with global_force_attn_backend_context_manager(_Backend.XFORMERS): + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/encoder_decoder/language/utils.py b/tests/models/encoder_decoder/language/utils.py index c4828683dafdb..fd8c81c38b471 100644 --- a/tests/models/encoder_decoder/language/utils.py +++ b/tests/models/encoder_decoder/language/utils.py @@ -2,6 +2,7 @@ from ....conftest import (DecoderPromptType) from vllm.sequence import SampleLogprobs + def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], decoder_prompt_type: DecoderPromptType, @@ -13,4 +14,4 @@ def vllm_to_hf_output( if decoder_prompt_type == DecoderPromptType.NONE: hf_output_str = "" + hf_output_str - return output_ids, hf_output_str, out_logprobs \ No newline at end of file + return output_ids, hf_output_str, out_logprobs From 2d5b4fbc30439453b08c5e54ec8ce4664467d040 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Jan 2025 13:50:51 +0000 Subject: [PATCH 14/17] format Signed-off-by: NickLucche --- .../offline_inference_encoder_decoder.py | 58 +++- .../encoder_decoder/language/conftest.py | 12 +- .../language/language/__init__.py | 0 .../language/language/conftest.py | 143 ++++++++ .../language/language/test_bart.py | 63 ++++ .../language/language/test_t5.py | 305 ++++++++++++++++++ .../language/language/utils.py | 17 + .../encoder_decoder/language/test_bart.py | 3 +- .../encoder_decoder/language/test_t5.py | 26 +- .../models/encoder_decoder/language/utils.py | 4 +- vllm/attention/backends/xformers.py | 9 +- vllm/model_executor/models/t5.py | 105 +++--- 12 files changed, 650 insertions(+), 95 deletions(-) create mode 100644 tests/models/encoder_decoder/language/language/__init__.py create mode 100644 tests/models/encoder_decoder/language/language/conftest.py create mode 100644 tests/models/encoder_decoder/language/language/test_bart.py create mode 100644 tests/models/encoder_decoder/language/language/test_t5.py create mode 100644 tests/models/encoder_decoder/language/language/utils.py diff --git a/examples/offline_inference/offline_inference_encoder_decoder.py b/examples/offline_inference/offline_inference_encoder_decoder.py index f386ebb6c2176..0f266d7918853 100644 --- a/examples/offline_inference/offline_inference_encoder_decoder.py +++ b/examples/offline_inference/offline_inference_encoder_decoder.py @@ -11,10 +11,8 @@ # Create a BART encoder/decoder model instance llm = LLM( - # model="facebook/bart-large-cnn", - model="google-t5/t5-small", + model="facebook/bart-large-cnn", dtype=dtype, - enforce_eager=True ) # Get BART tokenizer @@ -26,9 +24,41 @@ # encoder/decoder model. 
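The hunk that follows restores the BART prompt-building section of this example, which mentions a helper for zipping encoder and decoder prompts into `ExplicitEncoderDecoderPrompt` instances. A hedged sketch of that usage (the import path and the two-list call shape follow the example's elided header and are assumptions here):

from vllm.inputs import ExplicitEncoderDecoderPrompt, zip_enc_dec_prompts

# Build explicit pairs from parallel lists of encoder and decoder prompts.
zipped_prompt_list = zip_enc_dec_prompts(
    ["An encoder prompt", "Another encoder prompt"],
    ["A decoder prompt", "Another decoder prompt"])

# Hand-built equivalent for a single pair.
single_pair = ExplicitEncoderDecoderPrompt(
    encoder_prompt="An encoder prompt",
    decoder_prompt="A decoder prompt")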
# # - Helpers for building prompts -to_translate = "My name is Azeem and I live in India" -text_prompt_raw = "translate English to German: "+to_translate +text_prompt_raw = "Hello, my name is" +text_prompt = TextPrompt(prompt="The president of the United States is") +tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( + prompt="The capital of France is")) +# - Pass a single prompt to encoder/decoder model +# (implicitly encoder input prompt); +# decoder input prompt is assumed to be None +single_text_prompt_raw = text_prompt_raw # Pass a string directly +single_text_prompt = text_prompt # Pass a TextPrompt +single_tokens_prompt = tokens_prompt # Pass a TokensPrompt + +# - Pass explicit encoder and decoder input prompts within one data structure. +# Encoder and decoder prompts can both independently be text or tokens, with +# no requirement that they be the same prompt type. Some example prompt-type +# combinations are shown below, note that these are not exhaustive. + +enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt string directly, & + # pass decoder prompt tokens + encoder_prompt=single_text_prompt_raw, + decoder_prompt=single_tokens_prompt, +) +enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( + # Pass TextPrompt to encoder, and + # pass decoder prompt string directly + encoder_prompt=single_text_prompt, + decoder_prompt=single_text_prompt_raw, +) +enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt tokens directly, and + # pass TextPrompt to decoder + encoder_prompt=single_tokens_prompt, + decoder_prompt=single_text_prompt, +) # - Finally, here's a useful helper function for zipping encoder and # decoder prompts together into a list of ExplicitEncoderDecoderPrompt @@ -39,21 +69,19 @@ # - Let's put all of the above example prompts together into one list # which we will pass to the encoder/decoder LLM. -# prompts = [ -# single_text_prompt_raw, single_text_prompt, single_tokens_prompt, -# enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 -# ] + zipped_prompt_list +prompts = [ + single_text_prompt_raw, single_text_prompt, single_tokens_prompt, + enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 +] + zipped_prompt_list -prompts = [text_prompt_raw]#, "Se ni mondo"] print(prompts) # Create a sampling params object. sampling_params = SamplingParams( - temperature=0.2, - max_tokens=100, - # top_p=1.0, - # min_tokens=0, - # max_tokens=20, + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, ) # Generate output tokens from the prompts. 
The output is a list of diff --git a/tests/models/encoder_decoder/language/conftest.py b/tests/models/encoder_decoder/language/conftest.py index 5f2b14741df03..b3cb094401a40 100644 --- a/tests/models/encoder_decoder/language/conftest.py +++ b/tests/models/encoder_decoder/language/conftest.py @@ -1,7 +1,9 @@ +from typing import Any, Dict, List, Optional, Type + from transformers import AutoModelForSeq2SeqLM -from ....conftest import (DecoderPromptType, HfRunner, VllmRunner, - ExplicitEncoderDecoderPrompt) -from typing import List, Optional, Tuple, Type, Dict, Any + +from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, + HfRunner, VllmRunner) from ...utils import check_logprobs_close from .utils import vllm_to_hf_output @@ -18,7 +20,7 @@ def compare_hf_vllm_logprobs( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, - vllm_runner_kwargs: Optional[Dict[str, Any]] = dict(), + vllm_runner_kwargs: Optional[Dict[str, Any]] = None, hf_tokens_to_skip: int = 0) -> None: ''' Test the provided model for a variety of encoder/decoder input prompts, @@ -101,6 +103,8 @@ def compare_hf_vllm_logprobs( # decoder-only unit tests expect), so when testing an encoder/decoder # model we must explicitly specify enforce_eager=True in the VllmRunner # constructor. + if not vllm_runner_kwargs: + vllm_runner_kwargs = dict() with vllm_runner(model, dtype=dtype, tensor_parallel_size=tensor_parallel_size, diff --git a/tests/models/encoder_decoder/language/language/__init__.py b/tests/models/encoder_decoder/language/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/language/language/conftest.py b/tests/models/encoder_decoder/language/language/conftest.py new file mode 100644 index 0000000000000..5f2b14741df03 --- /dev/null +++ b/tests/models/encoder_decoder/language/language/conftest.py @@ -0,0 +1,143 @@ +from transformers import AutoModelForSeq2SeqLM +from ....conftest import (DecoderPromptType, HfRunner, VllmRunner, + ExplicitEncoderDecoderPrompt) +from typing import List, Optional, Tuple, Type, Dict, Any +from ...utils import check_logprobs_close +from .utils import vllm_to_hf_output + + +def compare_hf_vllm_logprobs( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, + vllm_runner_kwargs: Optional[Dict[str, Any]] = dict(), + hf_tokens_to_skip: int = 0) -> None: + ''' + Test the provided model for a variety of encoder/decoder input prompts, + by validating it against corresponding HuggingFace (HF). + + Arguments: + + * hf_runner: HuggingFace (HF) test model runner + * vllm_runner: vLLM test model runner + * example_encoder_decoder_prompts: test fixture which provides a + dictionary of dummy prompts + * model: the HF ID of the specific BART variant under test + * dtype: the tensor datatype to employ + * max_tokens + * num_logprobs + * decoder_prompt_type: key into the example_encoder_decoder_prompts + dictionary; selects specific encoder/decoder + prompt scenarios to test + + A note on using HF BART as a baseline for validating vLLM BART, + specifically when the decoder prompt is None. 
+
+    The HF GenerationMixin's default behavior is to force the first
+    decoded token to be <BOS> if the prompt does not already contain
+    <BOS> (this is accomplished using a logit
+    processor setting.)
+
+    So when we use HF BART as our baseline for comparison, note that
+    when the user provides a request with a None decoder prompt
+    (i.e. a singleton encoder prompt, or else an explicit encoder/
+    decoder prompt with the decoder sub-prompt set to None), HF and
+    vLLM handle this in different ways:
+
+    * HF will (1) tokenize the None prompt as an empty token-list,
+      (2) append <decoder-start-token> to the beginning, yielding
+      [<decoder-start-token>], (3) pass this token list to the model, and
+      then (4) after computing logits during prefill, override the model
+      logits & force <BOS> to be the first generated token.
+
+    * vLLM will (1) tokenize the None prompt as [], (2) append decoder-
+      start-token to the beginning, yielding [<decoder-start-token>],
+      (3) pass these tokens to the model & proceed with generation.
+
+    The net effect is that compared to vLLM, the list of HF *decoded* tokens
+    will contain one more initial <BOS> than the vLLM generated tokens,
+    because vLLM's <BOS> token is injected into the prompt rather than into
+    the generated output. This is in spite of the fact that overall, the
+    complete sequences (prompt + decoded tokens) produced by vLLM will match
+    HF.
+
+    So when we use HF decoded token output to validate vLLM's decoded token
+    output, the testing process must account for the difference in decoded
+    token sequences between vLLM and HF specifically in the
+    decoder-prompt-is-None case.
+
+    One option is to disable the logit processor feature that forces the
+    <BOS> token to be decoded (forced_bos_token_id = None), eliminating
+    the problem entirely. However this is not "normal" BART usage.
+
+    The other option is - only in the decoder-prompt-is-None case - to
+    discard the first decoded token from the HF output before comparing it
+    to vLLM.
+
+    To that end, `hf_tokens_to_skip` must be set to the number of HF decoded
+    tokens to skip during the process of validating the vLLM decoded output.
+    '''
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default).
+
+    # Note: currently encoder/decoder models are only compatible with
+    # enforce_eager=True. Normally this is not a problem because
+    # for encoder/decoder models vLLM will
+    # default to enforce_eager=True if enforce_eager
+    # is left unspecified. However, the
+    # VllmRunner test fixture (which wraps around the LLM class) defaults to
+    # enforce_eager=False (a behavior which a number of already-existing
+    # decoder-only unit tests expect), so when testing an encoder/decoder
+    # model we must explicitly specify enforce_eager=True in the VllmRunner
+    # constructor.
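Separately from the duplicated file above, the conftest hunk earlier in this patch replaces the mutable `= dict()` default for `vllm_runner_kwargs` with `None` plus an explicit fallback. A small standalone illustration of the pitfall that change avoids (toy functions, not test code):

def bad(kwargs=dict()):        # one dict object shared across *all* calls
    kwargs.setdefault("hit", 0)
    kwargs["hit"] += 1
    return kwargs["hit"]

def good(kwargs=None):         # fresh dict per call, as in the conftest fix
    if not kwargs:
        kwargs = dict()
    kwargs.setdefault("hit", 0)
    kwargs["hit"] += 1
    return kwargs["hit"]

assert (bad(), bad()) == (1, 2)    # state leaks between calls
assert (good(), good()) == (1, 1)  # no leakage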
+ with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + **vllm_runner_kwargs) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_tokens_to_skip, + ) diff --git a/tests/models/encoder_decoder/language/language/test_bart.py b/tests/models/encoder_decoder/language/language/test_bart.py new file mode 100644 index 0000000000000..990bd96bed570 --- /dev/null +++ b/tests/models/encoder_decoder/language/language/test_bart.py @@ -0,0 +1,63 @@ +"""Compare the outputs of HF and vLLM for BART models using greedy sampling. + +Run `pytest tests/models/encoder_decoder/language/test_bart.py`. +""" +import pytest + +from ....conftest import DecoderPromptType +from ....utils import multi_gpu_test +from .conftest import compare_hf_vllm_logprobs + + +@pytest.mark.parametrize( + "model", + [ + pytest.param("facebook/bart-base", + marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param("facebook/bart-large-cnn"), + ], +) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: + + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE)) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE)) diff --git a/tests/models/encoder_decoder/language/language/test_t5.py 
b/tests/models/encoder_decoder/language/language/test_t5.py new file mode 100644 index 0000000000000..12fe2b246e1b8 --- /dev/null +++ b/tests/models/encoder_decoder/language/language/test_t5.py @@ -0,0 +1,305 @@ +"""Compare the outputs of HF and vLLM for T5 models using greedy sampling. +Based on tests/models/encoder_decoder/language/test_bart.py. + +Run `pytest tests/models/encoder_decoder/language/test_t5.py`. +""" +from typing import Optional +import pytest +from vllm.attention.selector import global_force_attn_backend_context_manager +from vllm.config import set_current_vllm_config + +from ....conftest import DecoderPromptType +from ....utils import multi_gpu_test +from .conftest import compare_hf_vllm_logprobs +import torch +from vllm.model_executor.models.t5 import T5Config +from vllm.platforms import current_platform +from vllm.attention.selector import _Backend + + +@pytest.mark.parametrize( + "model", + [ + pytest.param("google-t5/t5-small"), + pytest.param("google/flan-t5-base"), + ], +) +@pytest.mark.parametrize("vllm_kwargs", [{"max_model_len": 512}]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +# TODO custom prompt here generate high entropy output, causing +# differences in sampled tokens. +@pytest.mark.parametrize("decoder_prompt_type", + [DecoderPromptType.NONE, DecoderPromptType.EMPTY_STR]) +def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, + dtype, max_tokens, num_logprobs, decoder_prompt_type, + vllm_kwargs) -> None: + # Model only supported on xformers backend as of now. + with global_force_attn_backend_context_manager(_Backend.XFORMERS): + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + vllm_runner_kwargs=vllm_kwargs) + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +@pytest.fixture +def dist_init(): + from vllm.distributed import init_distributed_environment, cleanup_dist_env_and_memory, initialize_model_parallel + import tempfile + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield + cleanup_dist_env_and_memory() + + +# TODO more cases +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +def test_t5_bias_attention(dtype, dist_init) -> None: + import random + + seed = 0 + MAX_SEQ_LEN = 34 + block_size = 16 + NUM_BLOCKS = 4321 + current_platform.seed_everything(seed) + config = T5Config() + + # setup kv caches + head_size = config.d_kv + num_heads = (config.num_heads, config.num_heads) + num_seqs = 1 + + num_query_heads, num_kv_heads = num_heads + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + 
max_seq_len = max(seq_lens) + + # Create the KV caches. + kv_cache_dtype = 'auto' + from vllm.utils import create_kv_caches_with_random + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, + dtype, seed, 'cuda') + key_cache, value_cache = key_caches[0], value_caches[0] + + x = torch.randn(num_seqs, + max_seq_len, + config.d_model, + device='cuda', + dtype=torch.float) + with global_force_attn_backend_context_manager(_Backend.XFORMERS): + + from vllm.attention.backends.xformers import XFormersBackend + from vllm import LLM + + from vllm.forward_context import set_forward_context + from vllm.config import VllmConfig + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + encoder_seq_start_loc = torch.zeros(len(seq_lens) + 1, + dtype=torch.int32, + device='cuda') + meta = XFormersBackend.make_metadata( + seq_lens=None, #seq_lens, + max_decode_seq_len=0, + num_prefills=None, + num_prefill_tokens=None, + num_decode_tokens=0, + seq_lens_tensor=None, #torch.tensor(seq_lens), + slot_mapping=None, #torch.zeros(1), + multi_modal_placeholder_index_maps=None, + max_prefill_seq_len=None, #MAX_SEQ_LEN, + use_cuda_graph=False, + context_lens_tensor=None, + # no block tables on encoder forward + block_tables=torch.tensor([]).cuda(), + # block_tables=block_tables, + num_encoder_tokens=sum(seq_lens), + encoder_seq_lens=seq_lens, + encoder_seq_lens_tensor=torch.tensor(seq_lens).cuda(), + max_encoder_seq_len=max(seq_lens), + encoder_seq_start_loc=encoder_seq_start_loc) + # same weights should be loaded + # TODO load model without engine overhead + llm = LLM(model="google-t5/t5-small", + load_format='safetensors', + enforce_eager=True, + dtype='float') + model = llm.llm_engine.model_executor.driver_worker.model_runner.model + t5_attn = model.model.encoder.blocks[0].self_attn.SelfAttention + print("\nTYPE", type(t5_attn)) + # FIXME this is kinda close, maybe issue is not with xformers custom bias attn + # t5_attn = T5Attention(config, AttentionType.ENCODER, has_relative_attention_bias=True).cuda() + assert t5_attn.has_relative_attention_bias + from transformers import T5Tokenizer, T5ForConditionalGeneration + from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention + hfmodel = T5ForConditionalGeneration.from_pretrained( + 'google-t5/t5-small', return_dict=True) + print("My T5", t5_attn) + # this must be set to call attn.impl.forward + # vllm_config.compilation_config.static_forward_context[".attn"] = t5_attn.attn + vllm_config.compilation_config.static_forward_context[ + "model.encoder.blocks.0.self_attn.SelfAttention.attn"] = t5_attn.attn + hf_attn = hfmodel.encoder.block[0].layer[0].SelfAttention.cuda() + assert hf_attn.has_relative_attention_bias + # hf_attn = HFT5Attention(config, has_relative_attention_bias=True).cuda() + + with set_forward_context(meta, vllm_config): + # input to vllm is 1d flattened, assuming all sequences of same len + xin = x.reshape(-1, config.d_model) + # kv_cache for xformers [2, num_blocks, block_size * num_kv_heads * head_size] + kvc = torch.stack([ + key_cache.reshape(NUM_BLOCKS, -1), + value_cache.reshape(NUM_BLOCKS, -1) + ], 0) + output = t5_attn(xin, kvc, meta) + ref_output, *_ = hf_attn(x) + + atol, rtol = 1e-3, 1e-5 + torch.testing.assert_close(output, + ref_output.squeeze(), + atol=atol, + rtol=rtol) + + # **decoder attn, first xformer forward** + t5_attn = model.model.decoder.blocks[0].self_attn.SelfAttention + assert t5_attn.has_relative_attention_bias + 
vllm_config.compilation_config.static_forward_context[ + "model.decoder.blocks.0.self_attn.SelfAttention.attn"] = t5_attn.attn + hf_attn = hfmodel.decoder.block[0].layer[0].SelfAttention.cuda() + assert hf_attn.has_relative_attention_bias + + num_decoding_input_ids = 2 # 1 + x = torch.randn(num_seqs, + num_decoding_input_ids, + config.d_model, + device='cuda', + dtype=torch.float) + prefill_seqlens = [num_decoding_input_ids] * len(seq_lens) + meta = XFormersBackend.make_metadata( + seq_lens=prefill_seqlens, + max_decode_seq_len=0, + num_prefills=len(seq_lens), + num_prefill_tokens=sum(prefill_seqlens), + num_decode_tokens=0, + seq_lens_tensor=torch.tensor(prefill_seqlens), + slot_mapping=torch.zeros(1, dtype=torch.long), + # slot_mapping=torch.tensor(slot_mapping_list, dtype=torch.long,device="cuda"), + multi_modal_placeholder_index_maps=None, + max_prefill_seq_len=max(prefill_seqlens), + use_cuda_graph=False, + context_lens_tensor=None, + block_tables=torch.tensor([]).cuda(), + # block_tables=block_tables, + # num_encoder_tokens=sum(seq_lens), encoder_seq_lens=seq_lens,encoder_seq_lens_tensor=torch.tensor(seq_lens).cuda(), + # max_encoder_seq_len=max(seq_lens), encoder_seq_start_loc=encoder_seq_start_loc + ) + + with set_forward_context(meta, vllm_config): + xin = x.reshape(-1, config.d_model) + kvc = torch.stack([ + key_cache.reshape(NUM_BLOCKS, -1), + value_cache.reshape(NUM_BLOCKS, -1) + ], 0) + output = t5_attn(xin, kvc, meta) + ref_output, *_ = hf_attn(x) + torch.testing.assert_close(output.squeeze(), + ref_output.squeeze(), + atol=atol, + rtol=rtol) + return + # **cross attn** + t5_attn = model.model.decoder.blocks[0].cross_attn.EncDecAttention + print("\nTYPE", type(t5_attn)) + assert not t5_attn.has_relative_attention_bias + vllm_config.compilation_config.static_forward_context[ + "model.decoder.blocks.0.cross_attn.EncDecAttention.attn"] = t5_attn.attn + hf_attn = hfmodel.decoder.block[0].layer[1].EncDecAttention.cuda() + assert not hf_attn.has_relative_attention_bias + + meta = XFormersBackend.make_metadata( + seq_lens=seq_lens, + max_decode_seq_len=MAX_SEQ_LEN, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=1, + max_prefill_seq_len=None, + seq_lens_tensor=torch.tensor(seq_lens), + slot_mapping= + None, #torch.tensor(slot_mapping_list, dtype=torch.long,device="cuda"), + multi_modal_placeholder_index_maps=None, + use_cuda_graph=False, + context_lens_tensor=None, + block_tables=torch.tensor([]).cuda(), + # block_tables=block_tables + ) + + with set_forward_context(meta, vllm_config): + output = t5_attn(x, kvc, meta) + ref_output, *_ = hf_attn(x) + + torch.testing.assert_close(output, + ref_output.squeeze(), + atol=atol, + rtol=rtol) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", ["google/t5-small"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + compare_hf_vllm_logprobs( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + 
distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/encoder_decoder/language/language/utils.py b/tests/models/encoder_decoder/language/language/utils.py new file mode 100644 index 0000000000000..fd8c81c38b471 --- /dev/null +++ b/tests/models/encoder_decoder/language/language/utils.py @@ -0,0 +1,17 @@ +from typing import List, Optional, Tuple +from ....conftest import (DecoderPromptType) +from vllm.sequence import SampleLogprobs + + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 6e8aebf00f4f7..30f48c80a3dba 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -4,8 +4,9 @@ """ import pytest +from tests.utils import multi_gpu_test # type: ignore[attr-defined] + from ....conftest import DecoderPromptType -from ....utils import multi_gpu_test from .conftest import compare_hf_vllm_logprobs diff --git a/tests/models/encoder_decoder/language/test_t5.py b/tests/models/encoder_decoder/language/test_t5.py index 873e4c314d5b5..aebfbc728aa44 100644 --- a/tests/models/encoder_decoder/language/test_t5.py +++ b/tests/models/encoder_decoder/language/test_t5.py @@ -4,13 +4,13 @@ Run `pytest tests/models/encoder_decoder/language/test_t5.py`. """ import pytest -from vllm.attention.selector import global_force_attn_backend_context_manager + +from tests.utils import multi_gpu_test +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) from ....conftest import DecoderPromptType -from ....utils import multi_gpu_test from .conftest import compare_hf_vllm_logprobs -import torch -from vllm.attention.selector import _Backend @pytest.mark.parametrize( @@ -46,24 +46,6 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, vllm_runner_kwargs=vllm_kwargs) -@pytest.fixture -def dist_init(): - from vllm.distributed import init_distributed_environment, cleanup_dist_env_and_memory, initialize_model_parallel - import tempfile - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend="nccl", - ) - initialize_model_parallel(1, 1) - yield - cleanup_dist_env_and_memory() - - - @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) @pytest.mark.parametrize("model", ["google/t5-small"]) diff --git a/tests/models/encoder_decoder/language/utils.py b/tests/models/encoder_decoder/language/utils.py index fd8c81c38b471..7c675b7d7a60c 100644 --- a/tests/models/encoder_decoder/language/utils.py +++ b/tests/models/encoder_decoder/language/utils.py @@ -1,7 +1,9 @@ from typing import List, Optional, Tuple -from ....conftest import (DecoderPromptType) + from vllm.sequence import SampleLogprobs +from ....conftest import DecoderPromptType + def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 
79b69c9e351a3..edb9f2f5ae0ee 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -605,7 +605,9 @@ def forward( if attn_bias: assert len( attn_bias - ) == 1, "PagedAttention expects a single bias to be provided for all input sequences." + ) == 1, "PagedAttention expects a single bias to be provided\ + for all input sequences." + attn_bias = attn_bias[0] output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, @@ -713,7 +715,7 @@ def _run_memory_efficient_xformers_forward( attn_metadata.seq_lens) else: raise ValueError("Unknown AttentionType: %s", attn_type) - + assert isinstance(attn_bias, BlockDiagonalMask) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( @@ -750,7 +752,8 @@ def _run_memory_efficient_xformers_forward( # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) - seq_lens = attn_metadata.encoder_seq_lens if attn_type == AttentionType.ENCODER else attn_metadata.seq_lens + seq_lens = attn_metadata.encoder_seq_lens \ + if attn_type == AttentionType.ENCODER else attn_metadata.seq_lens assert seq_lens start = 0 for i, seq_len in enumerate(seq_lens): diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index b0af947ada7eb..a8327a6e8128a 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -17,51 +17,55 @@ """PyTorch T5 model.""" import math -from typing import List, Optional, Tuple, Set, Iterable import re +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn +from transformers import T5Config +# TODO best way to handle xformers imports? 
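The xformers hunk above notes that, with a custom attention bias, prompts are processed one at a time because xformers does not support dynamic sequence lengths together with a tensor bias. A minimal reference version of that per-sequence loop (naive attention stand-in and toy shapes; names here are assumptions, not the backend code):

import torch

def naive_bias_attention(q, k, v, bias):
    # q, k, v: [seq_len, num_heads, head_dim]; bias: [num_heads, seq_len, seq_len]
    scores = torch.einsum("qhd,khd->hqk", q, k) / q.shape[-1] ** 0.5
    probs = torch.softmax(scores + bias, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v)

def per_sequence_attention(query, key, value, biases, seq_lens):
    # query/key/value are flattened over all sequences: [sum(seq_lens), H, D];
    # each sequence gets its own bias, so each slice is attended separately.
    out = torch.empty_like(query)
    start = 0
    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        out[start:end] = naive_bias_attention(query[start:end],
                                              key[start:end],
                                              value[start:end], biases[i])
        start = end
    return out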
+from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias + +# TODO func should be in backend interface +from vllm.attention.backends.xformers import (XFormersMetadata, _get_attn_bias, + _set_attn_bias) +from vllm.attention.layer import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.layers.activation import get_act_fn -from vllm.attention.layer import Attention, AttentionType, AttentionMetadata -from vllm.config import CacheConfig -from transformers import T5Config -from vllm.attention.backends.xformers import XFormersMetadata -from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import get_sampler, SamplerOutput -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import IntermediateTensors from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.sequence import IntermediateTensors + from .utils import maybe_prefix -# TODO best way to handle xformers imports? -from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias -# TODO func should be in backend interface -from vllm.attention.backends.xformers import _get_attn_bias, _set_attn_bias class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ - Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + Construct a layernorm module in the T5 style. + No bias and no subtraction of mean. """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states) -> torch.Tensor: - # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean - # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated - # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for - # half-precision inputs is done in fp32 + # T5 uses a layer_norm which only scales and doesn't shift, which is + # also known as Root Mean Square Layer Normalization + # https://arxiv.org/abs/1910.07467 thus variance is calculated w/o mean + # and there is no bias. Additionally we want to make sure that the + # accumulation for half-precision inputs is done in fp32. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) @@ -115,6 +119,8 @@ def __init__(self, config.d_ff, bias=False, quant_config=quant_config) + # Should not run in fp16 unless mixed-precision is used, + # see https://github.com/huggingface/transformers/issues/20287. 
self.wo = RowParallelLinear(config.d_ff, config.d_model, bias=False, @@ -125,18 +131,6 @@ def forward(self, hidden_states) -> torch.Tensor: hidden_gelu = self.act(self.wi_0(hidden_states)[0]) hidden_linear, _ = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear - - # TODO - # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. - # See https://github.com/huggingface/transformers/issues/20287 - # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` - # if ( - # isinstance(self.wo.weight, torch.Tensor) - # and hidden_states.dtype != self.wo.weight.dtype - # and self.wo.weight.dtype != torch.int8 - # ): - # hidden_states = hidden_states.to(self.wo.weight.dtype) - hidden_states, _ = self.wo(hidden_states) return hidden_states @@ -178,8 +172,10 @@ def __init__(self, # Cross-attention has no relative pos encoding anyway self.is_decoder = attn_type == AttentionType.DECODER self.has_relative_attention_bias = has_relative_attention_bias - self.relative_attention_num_buckets = config.relative_attention_num_buckets - self.relative_attention_max_distance = config.relative_attention_max_distance + self.relative_attention_num_buckets = \ + config.relative_attention_num_buckets + self.relative_attention_max_distance = \ + config.relative_attention_max_distance self.d_model = config.d_model self.key_value_proj_dim = config.d_kv @@ -211,8 +207,12 @@ def __init__(self, # Only the first SelfAttention block in encoder decoder has this # embedding layer, the others reuse its output. if self.has_relative_attention_bias: - self.relative_attention_bias = VocabParallelEmbedding(self.relative_attention_num_buckets,\ - self.n_heads, org_num_embeddings=self.relative_attention_num_buckets, quant_config=quant_config) + self.relative_attention_bias = \ + VocabParallelEmbedding(self.relative_attention_num_buckets, + self.n_heads, + org_num_embeddings=\ + self.relative_attention_num_buckets, + quant_config=quant_config) self.out_proj = RowParallelLinear( self.inner_dim, self.d_model, @@ -229,12 +229,16 @@ def _relative_position_bucket(relative_position, Adapted from Mesh Tensorflow: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 - Translate relative position to a bucket number for relative attention. The relative position is defined as - memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to - position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for - small absolute relative_position and larger buckets for larger absolute relative_positions. All relative - positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. - This should allow for more graceful generalization to longer sequences than the model has been trained on + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, + i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative + positions are invalid. We use smaller buckets for small absolute + relative_position and larger buckets for larger absolute + relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same + bucket. 
This should allow for more graceful generalization to longer + sequences than the model has been trained on Args: relative_position: an int32 Tensor @@ -243,8 +247,9 @@ def _relative_position_bucket(relative_position, max_distance: an integer Returns: - a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) - """ + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """# noqa: E501 relative_buckets = 0 if bidirectional: num_buckets //= 2 @@ -260,7 +265,8 @@ def _relative_position_bucket(relative_position, max_exact = num_buckets // 2 is_small = relative_position < max_exact - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + # The other half of the buckets are for logarithmically bigger bins + # in positions up to max_distance relative_position_if_large = max_exact + ( torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * @@ -288,7 +294,7 @@ def compute_bias(self, dtype=torch.long, device=device)[None, :] # max_seq_len, nh - relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position = memory_position - context_position relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=(not self.is_decoder), @@ -359,7 +365,8 @@ def forward( seq_len = attn_metadata.max_encoder_seq_len padded_seq_len = (seq_len + align_to - 1) // align_to * align_to - # TODO (NickLucche) avoid extra copy on repeat, provide multiple slices of same memory + # TODO (NickLucche) avoid extra copy on repeat, + # provide multiple slices of same memory position_bias = self.compute_bias(seq_len, padded_seq_len).repeat( num_seqs, 1, 1, 1) @@ -698,7 +705,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: if self.config.tie_word_embeddings: # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501 hidden_states = hidden_states * (self.model_dim**-0.5) logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) From dc25e4de2e156ed13a1d2b0ea1859121bbe9e13e Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Jan 2025 14:16:06 +0000 Subject: [PATCH 15/17] sync with custom attn bias pr Signed-off-by: NickLucche --- csrc/attention/attention_kernels.cuh | 15 --------------- tests/kernels/test_attention.py | 11 +---------- 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index a83bbf8b7648f..08f9882f65f09 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -155,21 +155,6 @@ __device__ void paged_attention_kernel( const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - // TODO check if indexing still makes sense - // seq_len indexes on 'max_seq_lens' dim, - // it's like renaming dim you get attn_bias: seq_len x num_kv_heads x seq_len - // TODO each seq can have different len (seq_lens) but only one bias!! 
- // NOTE (NickLucche) `max_seq_len` bias values for current sequence and current head - const float* attn_bias_vec = - attn_bias == nullptr - ? nullptr - : attn_bias + seq_idx * num_heads * num_seq_blocks * BLOCK_SIZE + - head_idx * num_seq_blocks * BLOCK_SIZE; - // : attn_bias + seq_idx * num_kv_heads * num_seq_blocks * BLOCK_SIZE + - // const float* attn_bias_vec = attn_bias == nullptr - // ? nullptr - // : attn_bias + seq_idx * num_kv_heads * seq_len + - // kv_head_idx * seq_len; // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence // and current head. diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index de5aba65d50d0..b9cfe6437183f 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -18,8 +18,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -MAX_SEQ_LEN = 16 +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -30,7 +29,6 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing -# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # FlashAttention forward only supports head dimension at most 128 @@ -142,10 +140,6 @@ def test_paged_attention( seed: int, device: str, ) -> None: - # num_heads = (2, 2) - # num_seqs = 2 - # head_size = 32 - if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -213,7 +207,6 @@ def test_paged_attention( # Call the paged attention kernel. 
output = torch.empty_like(query) - # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -232,7 +225,6 @@ def test_paged_attention( k_scale, v_scale, ) - # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, @@ -242,7 +234,6 @@ def test_paged_attention( and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): - assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape From 3eae4f616b89e0087a320ee829aac613f4af786c Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Jan 2025 15:23:56 +0000 Subject: [PATCH 16/17] remove spurious files Signed-off-by: NickLucche --- .../language/language/conftest.py | 143 -------- .../language/language/test_bart.py | 63 ---- .../language/language/test_t5.py | 305 ------------------ .../language/language/utils.py | 17 - .../encoder_decoder/language/test_bart.py | 2 +- 5 files changed, 1 insertion(+), 529 deletions(-) delete mode 100644 tests/models/encoder_decoder/language/language/conftest.py delete mode 100644 tests/models/encoder_decoder/language/language/test_bart.py delete mode 100644 tests/models/encoder_decoder/language/language/test_t5.py delete mode 100644 tests/models/encoder_decoder/language/language/utils.py diff --git a/tests/models/encoder_decoder/language/language/conftest.py b/tests/models/encoder_decoder/language/language/conftest.py deleted file mode 100644 index 5f2b14741df03..0000000000000 --- a/tests/models/encoder_decoder/language/language/conftest.py +++ /dev/null @@ -1,143 +0,0 @@ -from transformers import AutoModelForSeq2SeqLM -from ....conftest import (DecoderPromptType, HfRunner, VllmRunner, - ExplicitEncoderDecoderPrompt) -from typing import List, Optional, Tuple, Type, Dict, Any -from ...utils import check_logprobs_close -from .utils import vllm_to_hf_output - - -def compare_hf_vllm_logprobs( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt[str, str]], - decoder_prompt_type: DecoderPromptType, - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, - vllm_runner_kwargs: Optional[Dict[str, Any]] = dict(), - hf_tokens_to_skip: int = 0) -> None: - ''' - Test the provided model for a variety of encoder/decoder input prompts, - by validating it against corresponding HuggingFace (HF). - - Arguments: - - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - * model: the HF ID of the specific BART variant under test - * dtype: the tensor datatype to employ - * max_tokens - * num_logprobs - * decoder_prompt_type: key into the example_encoder_decoder_prompts - dictionary; selects specific encoder/decoder - prompt scenarios to test - - A note on using HF BART as a baseline for validating vLLM BART, - specifically when the decoder prompt is None. - - The HF GenerationMixin's default behavior is to force the first - decoded token to be if the prompt does not already contain - (this is accomplished using a logit - processor setting.) - - So when we use HF BART as our baseline for comparison, note that - when the user provides a request with a None decoder prompt - (i.e. 
a singleton encoder prompt, or else an explicit encoder/ - decoder prompt with the decoder sub-prompt set to None), HF and - vLLM handle this in different ways: - - * HF will (1) tokenize the None prompt as an empty token-list, - (2) append to the beginning, yielding - [], (3) pass this token list to the model, and - then (4) after computing logits during prefill, override the model - logits & force to be the first generated token. - - * vLLM will (1) tokenize the None prompt as [], (2) append decoder- - start-token to the beginning, yielding [], - (3) pass these tokens to the model & proceed with generation. - - The net effect is that compared to vLLM, the list of HF *decoded* tokens - will contain one more initial than the vLLM generated tokens, - because vLLM's token is injected into the prompt rather than into - the generated output. This is in spite of the fact that overall, the - complete sequences (prompt + decoded tokens) produced by vLLM will match - HF. - - So when we use HF decoded token output to validate vLLM's decoded token - output, the testing process must account for the difference in decoded - token sequences between vLLM and HF specifically in the - decoder-prompt-is-None case. - - One option is to disable the logit processor feature that forces the - token to be decoded (forced_bos_token_id = None), eliminating - the problem entirely. However this is not "normal" BART usage. - - The other option is - only in the decoder-prompt-is-None case - to - discard the first decoded token from the HF output before comparing it - to vLLM. - - To that end, `hf_tokens_to_skip` must be set to the number of HF decoded - tokens to skip during the process of validating the vLLM decoded output. - ''' - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default). - - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-exisitng - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. 
- with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - **vllm_runner_kwargs) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_tokens_to_skip, - ) diff --git a/tests/models/encoder_decoder/language/language/test_bart.py b/tests/models/encoder_decoder/language/language/test_bart.py deleted file mode 100644 index 990bd96bed570..0000000000000 --- a/tests/models/encoder_decoder/language/language/test_bart.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Compare the outputs of HF and vLLM for BART models using greedy sampling. - -Run `pytest tests/models/encoder_decoder/language/test_bart.py`. -""" -import pytest - -from ....conftest import DecoderPromptType -from ....utils import multi_gpu_test -from .conftest import compare_hf_vllm_logprobs - - -@pytest.mark.parametrize( - "model", - [ - pytest.param("facebook/bart-base", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("facebook/bart-large-cnn"), - ], -) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: - - compare_hf_vllm_logprobs( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE)) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) -def test_models_distributed(hf_runner, vllm_runner, - example_encoder_decoder_prompts, - distributed_executor_backend, model, dtype, - max_tokens, num_logprobs, - decoder_prompt_type) -> None: - compare_hf_vllm_logprobs( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE)) diff --git a/tests/models/encoder_decoder/language/language/test_t5.py 
b/tests/models/encoder_decoder/language/language/test_t5.py deleted file mode 100644 index 12fe2b246e1b8..0000000000000 --- a/tests/models/encoder_decoder/language/language/test_t5.py +++ /dev/null @@ -1,305 +0,0 @@ -"""Compare the outputs of HF and vLLM for T5 models using greedy sampling. -Based on tests/models/encoder_decoder/language/test_bart.py. - -Run `pytest tests/models/encoder_decoder/language/test_t5.py`. -""" -from typing import Optional -import pytest -from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.config import set_current_vllm_config - -from ....conftest import DecoderPromptType -from ....utils import multi_gpu_test -from .conftest import compare_hf_vllm_logprobs -import torch -from vllm.model_executor.models.t5 import T5Config -from vllm.platforms import current_platform -from vllm.attention.selector import _Backend - - -@pytest.mark.parametrize( - "model", - [ - pytest.param("google-t5/t5-small"), - pytest.param("google/flan-t5-base"), - ], -) -@pytest.mark.parametrize("vllm_kwargs", [{"max_model_len": 512}]) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -# TODO custom prompt here generate high entropy output, causing -# differences in sampled tokens. -@pytest.mark.parametrize("decoder_prompt_type", - [DecoderPromptType.NONE, DecoderPromptType.EMPTY_STR]) -def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, - dtype, max_tokens, num_logprobs, decoder_prompt_type, - vllm_kwargs) -> None: - # Model only supported on xformers backend as of now. - with global_force_attn_backend_context_manager(_Backend.XFORMERS): - compare_hf_vllm_logprobs( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - vllm_runner_kwargs=vllm_kwargs) - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - -@pytest.fixture -def dist_init(): - from vllm.distributed import init_distributed_environment, cleanup_dist_env_and_memory, initialize_model_parallel - import tempfile - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend="nccl", - ) - initialize_model_parallel(1, 1) - yield - cleanup_dist_env_and_memory() - - -# TODO more cases -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) -def test_t5_bias_attention(dtype, dist_init) -> None: - import random - - seed = 0 - MAX_SEQ_LEN = 34 - block_size = 16 - NUM_BLOCKS = 4321 - current_platform.seed_everything(seed) - config = T5Config() - - # setup kv caches - head_size = config.d_kv - num_heads = (config.num_heads, config.num_heads) - num_seqs = 1 - - num_query_heads, num_kv_heads = num_heads - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - - seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - seq_lens[-1] = MAX_SEQ_LEN - 
max_seq_len = max(seq_lens) - - # Create the KV caches. - kv_cache_dtype = 'auto' - from vllm.utils import create_kv_caches_with_random - key_caches, value_caches = create_kv_caches_with_random( - NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, - dtype, seed, 'cuda') - key_cache, value_cache = key_caches[0], value_caches[0] - - x = torch.randn(num_seqs, - max_seq_len, - config.d_model, - device='cuda', - dtype=torch.float) - with global_force_attn_backend_context_manager(_Backend.XFORMERS): - - from vllm.attention.backends.xformers import XFormersBackend - from vllm import LLM - - from vllm.forward_context import set_forward_context - from vllm.config import VllmConfig - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - encoder_seq_start_loc = torch.zeros(len(seq_lens) + 1, - dtype=torch.int32, - device='cuda') - meta = XFormersBackend.make_metadata( - seq_lens=None, #seq_lens, - max_decode_seq_len=0, - num_prefills=None, - num_prefill_tokens=None, - num_decode_tokens=0, - seq_lens_tensor=None, #torch.tensor(seq_lens), - slot_mapping=None, #torch.zeros(1), - multi_modal_placeholder_index_maps=None, - max_prefill_seq_len=None, #MAX_SEQ_LEN, - use_cuda_graph=False, - context_lens_tensor=None, - # no block tables on encoder forward - block_tables=torch.tensor([]).cuda(), - # block_tables=block_tables, - num_encoder_tokens=sum(seq_lens), - encoder_seq_lens=seq_lens, - encoder_seq_lens_tensor=torch.tensor(seq_lens).cuda(), - max_encoder_seq_len=max(seq_lens), - encoder_seq_start_loc=encoder_seq_start_loc) - # same weights should be loaded - # TODO load model without engine overhead - llm = LLM(model="google-t5/t5-small", - load_format='safetensors', - enforce_eager=True, - dtype='float') - model = llm.llm_engine.model_executor.driver_worker.model_runner.model - t5_attn = model.model.encoder.blocks[0].self_attn.SelfAttention - print("\nTYPE", type(t5_attn)) - # FIXME this is kinda close, maybe issue is not with xformers custom bias attn - # t5_attn = T5Attention(config, AttentionType.ENCODER, has_relative_attention_bias=True).cuda() - assert t5_attn.has_relative_attention_bias - from transformers import T5Tokenizer, T5ForConditionalGeneration - from transformers.models.t5.modeling_t5 import T5Attention as HFT5Attention - hfmodel = T5ForConditionalGeneration.from_pretrained( - 'google-t5/t5-small', return_dict=True) - print("My T5", t5_attn) - # this must be set to call attn.impl.forward - # vllm_config.compilation_config.static_forward_context[".attn"] = t5_attn.attn - vllm_config.compilation_config.static_forward_context[ - "model.encoder.blocks.0.self_attn.SelfAttention.attn"] = t5_attn.attn - hf_attn = hfmodel.encoder.block[0].layer[0].SelfAttention.cuda() - assert hf_attn.has_relative_attention_bias - # hf_attn = HFT5Attention(config, has_relative_attention_bias=True).cuda() - - with set_forward_context(meta, vllm_config): - # input to vllm is 1d flattened, assuming all sequences of same len - xin = x.reshape(-1, config.d_model) - # kv_cache for xformers [2, num_blocks, block_size * num_kv_heads * head_size] - kvc = torch.stack([ - key_cache.reshape(NUM_BLOCKS, -1), - value_cache.reshape(NUM_BLOCKS, -1) - ], 0) - output = t5_attn(xin, kvc, meta) - ref_output, *_ = hf_attn(x) - - atol, rtol = 1e-3, 1e-5 - torch.testing.assert_close(output, - ref_output.squeeze(), - atol=atol, - rtol=rtol) - - # **decoder attn, first xformer forward** - t5_attn = model.model.decoder.blocks[0].self_attn.SelfAttention - assert t5_attn.has_relative_attention_bias - 
vllm_config.compilation_config.static_forward_context[ - "model.decoder.blocks.0.self_attn.SelfAttention.attn"] = t5_attn.attn - hf_attn = hfmodel.decoder.block[0].layer[0].SelfAttention.cuda() - assert hf_attn.has_relative_attention_bias - - num_decoding_input_ids = 2 # 1 - x = torch.randn(num_seqs, - num_decoding_input_ids, - config.d_model, - device='cuda', - dtype=torch.float) - prefill_seqlens = [num_decoding_input_ids] * len(seq_lens) - meta = XFormersBackend.make_metadata( - seq_lens=prefill_seqlens, - max_decode_seq_len=0, - num_prefills=len(seq_lens), - num_prefill_tokens=sum(prefill_seqlens), - num_decode_tokens=0, - seq_lens_tensor=torch.tensor(prefill_seqlens), - slot_mapping=torch.zeros(1, dtype=torch.long), - # slot_mapping=torch.tensor(slot_mapping_list, dtype=torch.long,device="cuda"), - multi_modal_placeholder_index_maps=None, - max_prefill_seq_len=max(prefill_seqlens), - use_cuda_graph=False, - context_lens_tensor=None, - block_tables=torch.tensor([]).cuda(), - # block_tables=block_tables, - # num_encoder_tokens=sum(seq_lens), encoder_seq_lens=seq_lens,encoder_seq_lens_tensor=torch.tensor(seq_lens).cuda(), - # max_encoder_seq_len=max(seq_lens), encoder_seq_start_loc=encoder_seq_start_loc - ) - - with set_forward_context(meta, vllm_config): - xin = x.reshape(-1, config.d_model) - kvc = torch.stack([ - key_cache.reshape(NUM_BLOCKS, -1), - value_cache.reshape(NUM_BLOCKS, -1) - ], 0) - output = t5_attn(xin, kvc, meta) - ref_output, *_ = hf_attn(x) - torch.testing.assert_close(output.squeeze(), - ref_output.squeeze(), - atol=atol, - rtol=rtol) - return - # **cross attn** - t5_attn = model.model.decoder.blocks[0].cross_attn.EncDecAttention - print("\nTYPE", type(t5_attn)) - assert not t5_attn.has_relative_attention_bias - vllm_config.compilation_config.static_forward_context[ - "model.decoder.blocks.0.cross_attn.EncDecAttention.attn"] = t5_attn.attn - hf_attn = hfmodel.decoder.block[0].layer[1].EncDecAttention.cuda() - assert not hf_attn.has_relative_attention_bias - - meta = XFormersBackend.make_metadata( - seq_lens=seq_lens, - max_decode_seq_len=MAX_SEQ_LEN, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=1, - max_prefill_seq_len=None, - seq_lens_tensor=torch.tensor(seq_lens), - slot_mapping= - None, #torch.tensor(slot_mapping_list, dtype=torch.long,device="cuda"), - multi_modal_placeholder_index_maps=None, - use_cuda_graph=False, - context_lens_tensor=None, - block_tables=torch.tensor([]).cuda(), - # block_tables=block_tables - ) - - with set_forward_context(meta, vllm_config): - output = t5_attn(x, kvc, meta) - ref_output, *_ = hf_attn(x) - - torch.testing.assert_close(output, - ref_output.squeeze(), - atol=atol, - rtol=rtol) - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", ["google/t5-small"]) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) -def test_models_distributed(hf_runner, vllm_runner, - example_encoder_decoder_prompts, - distributed_executor_backend, model, dtype, - max_tokens, num_logprobs, - decoder_prompt_type) -> None: - compare_hf_vllm_logprobs( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts[decoder_prompt_type], - decoder_prompt_type, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=2, - 
distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/models/encoder_decoder/language/language/utils.py b/tests/models/encoder_decoder/language/language/utils.py deleted file mode 100644 index fd8c81c38b471..0000000000000 --- a/tests/models/encoder_decoder/language/language/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import List, Optional, Tuple -from ....conftest import (DecoderPromptType) -from vllm.sequence import SampleLogprobs - - -def vllm_to_hf_output( - vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], - decoder_prompt_type: DecoderPromptType, -): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "" - if decoder_prompt_type == DecoderPromptType.NONE: - hf_output_str = "" + hf_output_str - - return output_ids, hf_output_str, out_logprobs diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py index 30f48c80a3dba..4a26a63fce33f 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -4,7 +4,7 @@ """ import pytest -from tests.utils import multi_gpu_test # type: ignore[attr-defined] +from tests.utils import multi_gpu_test from ....conftest import DecoderPromptType from .conftest import compare_hf_vllm_logprobs From 455d0cb7288e50fd1037b7393b2d85f4053d2da4 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 9 Jan 2025 15:26:44 +0000 Subject: [PATCH 17/17] update to use new attention_type interface Signed-off-by: NickLucche --- vllm/model_executor/models/t5.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py index a8327a6e8128a..50c7c26659507 100644 --- a/vllm/model_executor/models/t5.py +++ b/vllm/model_executor/models/t5.py @@ -202,7 +202,8 @@ def __init__(self, 1.0, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=self.attn_type) # Only the first SelfAttention block in encoder decoder has this # embedding layer, the others reuse its output. @@ -418,12 +419,7 @@ def forward( # Encoder/Decoder Self-Attention Layer, attn bias already cached. assert attn_bias is not None - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=self.attn_type) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output
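A note on PATCH 17/17 above: the attention type is now fixed when the Attention layer is constructed (attn_type=self.attn_type in T5Attention.__init__) instead of being passed on every forward call. The snippet below is only a toy sketch of that calling convention; ToyAttention, its enum, and its shapes are invented for illustration and are not vLLM's real Attention class.

import torch
from enum import Enum, auto


class AttentionType(Enum):
    # Hypothetical stand-in for the attention-type enum referenced in the diff.
    DECODER = auto()
    ENCODER = auto()
    ENCODER_DECODER = auto()


class ToyAttention:
    """Toy layer: the attention role is chosen once, at construction time."""

    def __init__(self, scale: float, attn_type: AttentionType) -> None:
        self.scale = scale
        self.attn_type = attn_type  # previously this travelled with each call

    def __call__(self, q: torch.Tensor, k: torch.Tensor,
                 v: torch.Tensor) -> torch.Tensor:
        scores = self.scale * q @ k.transpose(-1, -2)
        if self.attn_type is AttentionType.DECODER:
            # Decoder self-attention is causal; encoder attention is not.
            n = scores.size(-1)
            causal = torch.triu(torch.ones(n, n, dtype=torch.bool), 1)
            scores = scores.masked_fill(causal, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v


if __name__ == "__main__":
    attn = ToyAttention(scale=0.125, attn_type=AttentionType.ENCODER)
    q = k = v = torch.randn(2, 4, 8)
    print(attn(q, k, v).shape)  # torch.Size([2, 4, 8])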
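A note on the T5 relative attention bias earlier in this series: the _relative_position_bucket docstring describes mapping signed token distances (memory_position - query_position) to a fixed number of buckets, with exact buckets for small distances and logarithmically sized bins up to max_distance. The standalone sketch below re-expresses that logic for illustration only; it mirrors the computation visible in the diff, and the defaults of 32 buckets and max_distance=128 are T5's usual values, assumed here rather than taken from the patch.

import math

import torch


def relative_position_bucket(relative_position: torch.Tensor,
                             bidirectional: bool = True,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> torch.Tensor:
    """Map signed token distances to bucket ids, as described above."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        # Half of the buckets encode the sign (attending forward vs backward).
        num_buckets //= 2
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:
        # Causal case: only non-positive (past) distances are valid.
        relative_position = -torch.min(relative_position,
                                       torch.zeros_like(relative_position))
    # First half of the remaining buckets: one bucket per exact distance.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # Second half: logarithmically larger bins up to max_distance; everything
    # at or beyond max_distance falls into the last bucket.
    rel_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact) /
        math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
    rel_if_large = torch.min(rel_if_large,
                             torch.full_like(rel_if_large, num_buckets - 1))
    return relative_buckets + torch.where(is_small, relative_position,
                                          rel_if_large)


if __name__ == "__main__":
    # Distances -4..4 from a query token: small distances keep exact buckets,
    # positive (future) distances are offset into the upper half.
    print(relative_position_bucket(torch.arange(-4, 5)))
    # tensor([ 4,  3,  2,  1,  0, 17, 18, 19, 20])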